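"""High-level public API: derived parameters, CMB Cls, matter P(k),
distances, and the tSZ Cl^yy power spectrum, bandpower covariance and
trispectrum. The public functions take a
:class:`classy_szlite.params.CosmoParams`; :func:`cl_TTTEEE_jax` also
accepts a plain ``dict`` of emulator inputs.
"""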
from __future__ import annotations
import numpy as np
import jax
import jax.numpy as jnp
from ._registry import get_emulator, DEFAULT_COSMO
from .params import CosmoParams, ProfileParamsA10
from .cosmology import build as build_cosmo_grids, get_pk, get_pknl, get_distances
from .hmf import build_halo_grids
from .power_spectrum import cl_yy_1h_2h, cl_yy_1h_trispectrum
# Enable 64-bit (float64) arithmetic in JAX. This is a process-wide setting
# applied as a side effect of importing this module; the explicit
# ``dtype=jnp.float64`` requests below (e.g. in ``cl_TTTEEE_jax``) rely on it.
# NOTE(review): this runs after the ``._registry`` / ``.cosmology`` / ``.hmf`` /
# ``.power_spectrum`` imports above; if any of those create JAX arrays at
# import time, they were built before x64 was enabled — confirm, or move
# this call ahead of those imports.
jax.config.update("jax_enable_x64", True)
[docs]
def cosmo_to_dict(cosmo: CosmoParams) -> dict:
    """Map a ``CosmoParams`` onto the emulator-style input dict.

    The emulators expect the amplitude under the curly-brace key
    ``"ln10^{10}A_s"`` (attribute ``ln10_10_As``); every other key has the
    same name as the attribute it is read from. Values are passed through
    untouched (plain floats or ``jax.Array`` tracers), so the returned dict
    stays JAX-traceable — a ``CosmoParams`` of tracers can be fed through
    ``jax.grad``.
    """
    # (emulator key, CosmoParams attribute) in the emulator's key order.
    key_to_attr = (
        ("omega_b", "omega_b"),
        ("omega_cdm", "omega_cdm"),
        ("H0", "H0"),
        ("tau_reio", "tau_reio"),
        ("ln10^{10}A_s", "ln10_10_As"),
        ("n_s", "n_s"),
        ("m_ncdm", "m_ncdm"),
        ("N_ur", "N_ur"),
        ("fEDE", "fEDE"),
        ("log10z_c", "log10z_c"),
        ("thetai_scf", "thetai_scf"),
        ("r", "r"),
    )
    return {key: getattr(cosmo, attr) for key, attr in key_to_attr}
# ---------------------------------------------------------------------------
# Derived parameters (σ8, Ω_m, S8)
# ---------------------------------------------------------------------------
[docs]
def derived(cosmo: CosmoParams) -> dict:
    """Derived parameters ``sigma_8``, ``Omega_m`` and ``S8``.

    ``sigma_8`` comes from the ``'der'`` emulator, evaluated on
    ``DEFAULT_COSMO`` overridden by ``cosmo``. ``Omega_m`` is computed
    analytically from ``omega_b``, ``omega_cdm``, ``m_ncdm`` and ``H0``, and
    ``S8 = sigma_8 * sqrt(Omega_m / 0.3)``. All three are returned as Python
    floats, so this function is not JAX-traceable.

    The full 17-element DER emulator output (after undoing its log10) is
    also returned as ``'der_full'``. σ8 is at index 1; consult the
    CosmoPower DER training script for the full list.
    """
    emulator = get_emulator('der')
    merged = dict(DEFAULT_COSMO)
    merged.update(cosmo_to_dict(cosmo))
    batch = {name: [merged[name]] for name in emulator.parameters}
    # The DER emulator predicts log10 of each derived quantity.
    der_full = 10.0 ** np.asarray(emulator.predict(batch)).flatten()

    sigma_8 = float(der_full[1])
    h = float(cosmo.H0) / 100.0
    # ede-v2 convention: three degenerate neutrinos of mass m_ncdm each, so
    # omega_nu = (3 * m_ncdm) / 93.14 eV.
    omega_nu = 3.0 * float(cosmo.m_ncdm) / 93.14
    Omega_m = (float(cosmo.omega_b) + float(cosmo.omega_cdm) + omega_nu) / h ** 2
    S8 = sigma_8 * (Omega_m / 0.3) ** 0.5
    return {"sigma_8": sigma_8, "Omega_m": Omega_m, "S8": S8, "der_full": der_full}
# ---------------------------------------------------------------------------
# CMB angular power spectra
# ---------------------------------------------------------------------------
# Whether each CMB emulator predicts log10 of its target. TT, EE and PP
# predict log10 of the prefactored Cl and must be exponentiated; TE predicts
# the prefactored Cl directly because it can be negative. For ede-v2 the raw
# Cl is then recovered by multiplying by 1/ell².
_CMB_LOG_CONVENTION = dict(tt=True, ee=True, pp=True, te=False)
[docs]
def cl_TTTEEE(cosmo: CosmoParams,
              spectra: tuple[str, ...] = ("tt", "te", "ee"),
              ell_factor: bool = True,
              ell_convention: str = "classy_szfast") -> dict:
    """CMB angular power spectra.

    Returns a dict with the requested spectra (any of ``'tt','te','ee','pp'``)
    followed by ``'ell'``. Values are **dimensionless** — multiply by
    ``Tcmb_uK² = (2.7255e6)²`` to convert to μK².

    Parameters
    ----------
    cosmo
        ``CosmoParams`` instance (the ede-v2 input vector). Emulator inputs
        not produced by ``cosmo_to_dict(cosmo)`` fall back to
        ``DEFAULT_COSMO``.
    spectra
        Requested spectra, any non-empty subset of
        ``("tt", "te", "ee", "pp")``. Repeated names are ignored; each
        spectrum is computed and rescaled exactly once.
    ell_factor
        If ``True`` (default), return ``D_ℓ = ℓ(ℓ+1) Cℓ / (2π)``. If
        ``False``, return the raw ``Cℓ``.
    ell_convention
        - ``"classy_szfast"`` (default) — output ``ℓ ∈ [2, ..., 9500]``, with
          ``Cℓ(ℓ) = pred[ℓ−2] / ℓ²``. This is what ``classy_szfast`` does
          inside ``calculate_cmb``; using it gives bit-identical Cls to the
          ``cobaya + classy_sz fast-mode`` stack that every published EDE
          chain (including ACT-DR6 + Planck) was fit with.
        - ``"emulator_modes"`` — output ``ℓ ∈ [1, ..., 9499]`` exactly as
          stored in the ede-v2 ``modes`` metadata, with
          ``Cℓ(ℓ) = pred[ℓ−1] / ℓ²``. Use this only if you intentionally
          want to disagree with the published chains.

    Raises
    ------
    ValueError
        If ``ell_convention`` or any name in ``spectra`` is unknown, or if
        ``spectra`` is empty.
    """
    if ell_convention not in ("classy_szfast", "emulator_modes"):
        raise ValueError(
            f"Unknown ell_convention {ell_convention!r}; expected "
            "'classy_szfast' or 'emulator_modes'."
        )
    # De-duplicate while preserving order. The rescaling loops below iterate
    # over the requested names, so a repeated name would otherwise have its
    # spectrum divided by ell**2 (and multiplied by the D_ell factor) twice.
    requested = tuple(dict.fromkeys(spectra))
    if not requested:
        # Without at least one emulator there is no ``modes`` grid for ell.
        raise ValueError("spectra must name at least one CMB spectrum.")
    # Validate every name up front so no emulator is loaded for a bad request.
    for spec in requested:
        if spec not in _CMB_LOG_CONVENTION:
            raise ValueError(
                f"Unknown spectrum {spec!r}. Pick from {tuple(_CMB_LOG_CONVENTION)}."
            )

    full = dict(DEFAULT_COSMO)
    full.update(cosmo_to_dict(cosmo))
    out = {}
    ell = None
    for spec in requested:
        em = get_emulator(spec)
        p_in = {k: [full[k]] for k in em.parameters}
        pred = np.asarray(em.predict(p_in)).flatten()
        if _CMB_LOG_CONVENTION[spec]:
            pred = 10.0 ** pred
        out[spec] = pred
        if ell is None:
            # The ell grid is taken from the first emulator evaluated.
            # NOTE(review): assumes all CMB emulators share ``modes`` — confirm.
            modes = np.asarray(em.modes)
            if ell_convention == "classy_szfast":
                # pred[i] is treated as Cℓ × ℓ² with ℓ = i + 2 (drops ℓ = 1
                # from the emulator's modes list).
                ell = modes + 1
            else:
                ell = modes
    # ede-v2 recovery factor: raw prediction → Cℓ.
    # NOTE(review): 'pp' gets the same 1/ℓ² recovery (and the same D_ℓ factor
    # below) as TT/TE/EE; confirm this matches the pp emulator's convention.
    factor_to_Cl = 1.0 / (ell ** 2)
    for s in requested:
        out[s] = out[s] * factor_to_Cl
    out["ell"] = ell
    if ell_factor:
        fac_dl = ell * (ell + 1) / (2.0 * np.pi)
        for s in requested:
            out[s] = out[s] * fac_dl
    return out
def cl_TTTEEE_jax(cosmo: CosmoParams | dict,
                  spectra: tuple[str, ...] = ("tt", "te", "ee"),
                  ell_factor: bool = True,
                  ell_convention: str = "classy_szfast"):
    """JAX-traceable, JIT-able variant of :func:`cl_TTTEEE`.

    Accepts either a :class:`CosmoParams` (possibly holding traced JAX
    scalars) or a plain ``dict`` of emulator inputs whose values are scalars
    or jnp arrays. An ``ln10_10_As`` key in such a dict is renamed to the
    emulators' ``ln10^{10}A_s``; the caller's dict itself is not modified.
    Inputs missing from ``cosmo`` fall back to ``DEFAULT_COSMO``.

    Returns a dict of jnp arrays with the same keys, ell grid and
    normalisation as :func:`cl_TTTEEE`. The emulator forward pass is already
    pure JAX (see :class:`classy_szlite._emulator.Emulator.predict`); this
    wrapper just avoids the ``np.asarray`` round-trip in :func:`cl_TTTEEE`
    that would break tracing.

    See :func:`cl_TTTEEE` for the ``spectra``, ``ell_factor`` and
    ``ell_convention`` arguments. Repeated names in ``spectra`` are ignored.

    Raises
    ------
    ValueError
        If ``ell_convention`` or any requested spectrum is unknown, or if
        ``spectra`` is empty.
    """
    if ell_convention not in ("classy_szfast", "emulator_modes"):
        raise ValueError(
            f"Unknown ell_convention {ell_convention!r}; expected "
            "'classy_szfast' or 'emulator_modes'."
        )
    # De-duplicate (order-preserving) so that no spectrum is rescaled twice
    # by the loops at the end.
    requested = tuple(dict.fromkeys(spectra))
    if not requested:
        # Without at least one emulator there is no ``modes`` grid for ell.
        raise ValueError("spectra must name at least one CMB spectrum.")
    for spec in requested:
        if spec not in _CMB_LOG_CONVENTION:
            raise ValueError(
                f"Unknown spectrum {spec!r}. Pick from {tuple(_CMB_LOG_CONVENTION)}."
            )

    # NamedTuple / dataclass instances go through cosmo_to_dict; anything
    # else is treated as a mapping and copied, so the key rename below never
    # mutates the caller's dict.
    if hasattr(cosmo, "_asdict") or hasattr(cosmo, "__dataclass_fields__"):
        cosmo_dict = cosmo_to_dict(cosmo)
    else:
        cosmo_dict = dict(cosmo)
    # Canonicalise the A_s key (callers commonly use ``ln10_10_As``).
    if "ln10_10_As" in cosmo_dict and "ln10^{10}A_s" not in cosmo_dict:
        cosmo_dict["ln10^{10}A_s"] = cosmo_dict.pop("ln10_10_As")
    full = dict(DEFAULT_COSMO)
    full.update(cosmo_dict)

    out = {}
    ell = None
    for spec in requested:
        em = get_emulator(spec)
        # Feed every input as an at-least-1-D float64 array.
        p_in = {k: jnp.atleast_1d(jnp.asarray(full[k], dtype=jnp.float64))
                for k in em.parameters}
        pred = jnp.atleast_1d(em.predict(p_in)).reshape(-1)
        if _CMB_LOG_CONVENTION[spec]:
            pred = 10.0 ** pred
        out[spec] = pred
        if ell is None:
            modes = jnp.asarray(em.modes, dtype=jnp.float64)
            # Same ell conventions as cl_TTTEEE ("classy_szfast" shifts by 1).
            ell = modes + 1 if ell_convention == "classy_szfast" else modes
    # ede-v2 recovery factor: raw prediction → Cℓ.
    inv_ell2 = 1.0 / (ell ** 2)
    for s in requested:
        out[s] = out[s] * inv_ell2
    out["ell"] = ell
    if ell_factor:
        fac_dl = ell * (ell + 1) / (2.0 * jnp.pi)
        for s in requested:
            out[s] = out[s] * fac_dl
    return out
# ---------------------------------------------------------------------------
# Pk, Pnl, distances
# ---------------------------------------------------------------------------
[docs]
def Pk(cosmo: CosmoParams, z_arr):
    """Linear matter power spectrum.

    Thin wrapper around ``get_pk``: ``cosmo`` is converted with
    :func:`cosmo_to_dict` and ``z_arr`` is forwarded unchanged. Returns
    ``(k, pk)`` with ``pk`` laid out as ``pk(z, k)``.
    """
    emulator_inputs = cosmo_to_dict(cosmo)
    k_and_pk = get_pk(emulator_inputs, z_arr)
    return k_and_pk
[docs]
def Pnl(cosmo: CosmoParams, z_arr):
    """Non-linear (HMcode) matter power spectrum.

    Thin wrapper around ``get_pknl``: ``cosmo`` is converted with
    :func:`cosmo_to_dict` and ``z_arr`` is forwarded unchanged. Returns
    ``(k, pk)`` with ``pk`` laid out as ``pk(z, k)``.
    """
    emulator_inputs = cosmo_to_dict(cosmo)
    k_and_pk = get_pknl(emulator_inputs, z_arr)
    return k_and_pk
[docs]
def distances(cosmo: CosmoParams, z_arr):
    """Background quantities at redshifts ``z_arr``.

    Thin wrapper around ``get_distances``. Returns ``(Hz, chi, Da)``, where
    ``Hz`` is H(z)/c in 1/Mpc and both distances are in Mpc.
    """
    emulator_inputs = cosmo_to_dict(cosmo)
    background = get_distances(emulator_inputs, z_arr)
    return background
# ---------------------------------------------------------------------------
# tSZ Cl^yy (halo-model, Arnaud 2010 profile)
# ---------------------------------------------------------------------------
def _a10_setup_sbt(x_outSZ: float, c500_fiducial: float):
    """Build the truncated spherical-Bessel-transform (SBT) machinery.

    Runs eagerly (not under ``jax.jit``) — once per :func:`cl_yy_factory`
    call, or once per :func:`cl_yy` call. The u-grid has 256 log-spaced
    points from ``1e-5`` to ``u_max = c500_fiducial * x_outSZ``, so the GNFW
    kernel is only sampled up to the truncation radius. Combined with
    ``extrap=False`` in :func:`_a10_compute_g`, this mimics classy_sz's
    ``x_outSZ`` truncation.

    Returns
    -------
    (u_grid_jax, sb, log_s_grid)
        ``u_grid_jax``: the u-grid as a JAX array.
        ``sb``: an ``mcfit.SphericalBessel`` built with ``backend='jax'``,
        whose ``__call__`` is JIT-traceable.
        ``log_s_grid``: natural log of the transform's output grid ``sb.y``.

    Raises
    ------
    ValueError
        If ``c500_fiducial * x_outSZ`` is not a finite number greater than
        the fixed lower grid edge ``1e-5``.
    """
    u_min = 1e-5
    u_max = float(c500_fiducial) * float(x_outSZ)
    # Without this guard np.geomspace would silently return a NaN-filled
    # (u_max < 0), reversed (u_max < u_min) or otherwise unusable grid,
    # which mcfit would then transform into garbage. ``not (...)`` also
    # rejects NaN.
    if not (u_min < u_max < np.inf):
        raise ValueError(
            f"c500_fiducial * x_outSZ must be a finite value > {u_min:g}; "
            f"got {u_max!r} (c500_fiducial={c500_fiducial!r}, x_outSZ={x_outSZ!r})."
        )

    import warnings as _warnings
    import mcfit as _mcfit

    u_grid = np.geomspace(u_min, u_max, 256)
    with _warnings.catch_warnings():
        # mcfit warns about the backend choice even though 'jax' is explicit.
        _warnings.filterwarnings("ignore", message="use backend='jax' if desired")
        sb = _mcfit.SphericalBessel(u_grid, backend='jax')
    return jnp.asarray(u_grid), sb, jnp.asarray(np.log(np.asarray(sb.y)))
def _a10_compute_g(sb, u_jax, gamma, alpha, beta):
    """Transform the GNFW shape kernel with ``sb`` (JIT-traceable).

    The kernel ``u**(-gamma) * (1 + u**alpha)**((gamma - beta) / alpha)`` is
    evaluated on ``u_jax`` and passed through ``sb`` with
    ``extrap=False``. Returns the transformed values (on ``sb``'s output
    grid) scaled by ``sqrt(pi / 2)``.
    """
    outer_slope = (gamma - beta) / alpha
    kernel = u_jax ** (-gamma) * (1.0 + u_jax ** alpha) ** outer_slope
    _, transformed = sb(kernel, extrap=False)
    return jnp.sqrt(jnp.pi / 2.0) * transformed
[docs]
def cl_yy(cosmo: CosmoParams, profile: ProfileParamsA10, ell,
          z_grid: jax.Array | None = None,
          n_z: int = 100, m_min: float = 1e10, m_max: float = 3.5e15,
          n_m: int = 200, delta_crit: float = 500.0,
          x_outSZ: float = 4.0, c500_fiducial: float = 1.156):
    """Halo-model tSZ angular power spectrum, rebuilt from scratch per call.

    Builds the cosmology grids, the halo grids and the x_outSZ-truncated GNFW
    transform table, then evaluates the 1-halo and 2-halo terms.

    Returns ``(cl_1h, cl_2h)`` — dimensionless C_ell. Multiply by
    ``ell*(ell+1)/(2π)*1e12`` to get ``D_ell × 1e12``.

    The GNFW pressure profile is truncated at x = x_outSZ * r_500c (default
    4.0, matching the classy_sz / Arnaud+2010 convention); ``c500_fiducial``
    turns that radius into the upper edge of the transform's u-grid.
    ``gamma`` / ``alpha`` / ``beta`` missing from ``profile`` fall back to
    ``_A10_GAMMA`` / ``_A10_ALPHA`` / ``_A10_BETA`` from ``power_spectrum``.

    For MCMC sampling only profile parameters at fixed cosmology, use
    :func:`cl_yy_factory` instead — ~3× faster.
    """
    z = jnp.geomspace(0.005, 3.0, n_z) if z_grid is None else z_grid
    params = cosmo_to_dict(cosmo)
    cosmo_grids = build_cosmo_grids(params, z_grid=z)
    halo_grids = build_halo_grids(cosmo_grids, params, delta_crit=delta_crit,
                                  m_min=m_min, m_max=m_max, n_m=n_m)

    from .power_spectrum import _A10_ALPHA, _A10_BETA, _A10_GAMMA

    shape = profile._asdict()
    u_grid, transform, log_s = _a10_setup_sbt(x_outSZ, c500_fiducial)
    g_table = _a10_compute_g(
        transform,
        u_grid,
        shape.get('gamma', _A10_GAMMA),
        shape.get('alpha', _A10_ALPHA),
        shape.get('beta', _A10_BETA),
    )
    profile_params = {**shape, '_g_table': g_table, '_log_s_grid': log_s}
    one_halo, two_halo = cl_yy_1h_2h(
        jnp.asarray(ell), cosmo_grids, halo_grids, params,
        profile='arnaud10', profile_params=profile_params,
    )
    return one_halo, two_halo
[docs]
def cl_yy_factory(cosmo: CosmoParams, ell,
                  z_grid: jax.Array | None = None,
                  n_z: int = 100, m_min: float = 1e10, m_max: float = 3.5e15,
                  n_m: int = 200, delta_crit: float = 500.0,
                  x_outSZ: float = 4.0, c500_fiducial: float = 1.156):
    """Fixed-cosmology fast path: precompute the heavy bits, return a closure.

    Builds ``CosmoGrids`` (emulators → P_lin, distances, σ(R)), ``HaloGrids``
    (Tinker 08 HMF, bias) and the truncated spherical-Bessel-transform
    set-up **once**, then returns a ``jax.jit``-compiled function::

        ev(profile) -> (cl_1h, cl_2h)

    Each ``ev(profile)`` call only re-evaluates the GNFW kernel transform for
    the profile's ``gamma``/``alpha``/``beta`` and runs the ``cl_yy_1h_2h``
    halo-model integration — typically ~5 ms per call. Intended for MCMC
    over profile / nuisance parameters with fixed cosmology.

    Parameters
    ----------
    cosmo, ell
        Cosmology and multipoles, fixed for the lifetime of the closure.
    z_grid, n_z, m_min, m_max, n_m, delta_crit
        Forwarded to the cosmology / halo-model grid builders. If ``z_grid``
        is None, a geometric grid of ``n_z`` points in [0.005, 3] is used.
    x_outSZ : float, optional
        Outer truncation radius of the GNFW pressure profile in units of
        r_500c (i.e. x = r / r_500c). The FT look-up table u-grid runs from
        1e-5 to ``c500_fiducial * x_outSZ`` (in u = c500 * x units), matching
        the classy_sz ``x_outSZ`` convention. Default 4.0 (literature /
        ACT-DR6 may26 convention).
    c500_fiducial : float, optional
        c500 used to convert x_outSZ → u_max at table-build time. Should
        match the c500 you will pass in ProfileParamsA10. Default 1.156
        (Arnaud et al. 2010).

    Returns
    -------
    callable
        ``ev(profile: ProfileParamsA10) -> (cl_1h, cl_2h)``.
    """
    # classy_sz sets the GNFW profile to exactly zero for r / r_500c > x_outSZ
    # before Fourier-transforming it. We replicate this with a truncated u-grid
    # (1e-5 to c500_fiducial * x_outSZ, 256 points) transformed with
    # extrap=False — see _a10_setup_sbt and _a10_compute_g. This reproduces
    # classy_sz's truncation to within <1% (beta > 5), <3% (beta ~ 3) and
    # <8% (beta ~ 1.7).
    if z_grid is None:
        z_grid = jnp.geomspace(0.005, 3.0, n_z)
    cosmo_dict = cosmo_to_dict(cosmo)
    cg = build_cosmo_grids(cosmo_dict, z_grid=z_grid)
    hg = build_halo_grids(cg, cosmo_dict, delta_crit=delta_crit,
                          m_min=m_min, m_max=m_max, n_m=n_m)
    ell_jax = jnp.asarray(ell)
    from .power_spectrum import _A10_GAMMA, _A10_ALPHA, _A10_BETA
    # Build the truncated SBT machinery once (outside @jax.jit); the JIT'd
    # evaluate then only does the kernel + SBT call, both jax-primitive.
    u_jax, sb, log_s = _a10_setup_sbt(x_outSZ, c500_fiducial)

    @jax.jit
    def evaluate(profile: ProfileParamsA10):
        pp = profile._asdict()
        g_tab = _a10_compute_g(
            sb, u_jax,
            pp.get('gamma', _A10_GAMMA),
            pp.get('alpha', _A10_ALPHA),
            pp.get('beta', _A10_BETA),
        )
        pp_ext = dict(pp, _g_table=g_tab, _log_s_grid=log_s)
        cl_1h, cl_2h = cl_yy_1h_2h(
            ell_jax, cg, hg, cosmo_dict,
            profile='arnaud10', profile_params=pp_ext,
        )
        return cl_1h, cl_2h

    return evaluate
[docs]
def cl_yy_trispectrum(cosmo: CosmoParams, profile: ProfileParamsA10, ell,
                      z_grid: jax.Array | None = None,
                      n_z: int = 100, m_min: float = 1e10, m_max: float = 3.5e15,
                      n_m: int = 200, delta_crit: float = 500.0) -> jax.Array:
    r"""1-halo connected tSZ trispectrum :math:`T^{1h}(\ell, \ell')`.

    Returns a symmetric ``(n_ell, n_ell)`` matrix, used to build the
    non-Gaussian part of the bandpower covariance (see
    :func:`cl_yy_covariance`).

    The cosmology and halo grids are rebuilt on every call (same cost as in
    :func:`cl_yy`), plus an extra ``O(n_ell² n_z n_m)`` integral for the
    trispectrum contraction. The profile parameters are forwarded as
    ``profile._asdict()``; no x_outSZ-truncated transform table is attached.
    """
    z = jnp.geomspace(0.005, 3.0, n_z) if z_grid is None else z_grid
    params = cosmo_to_dict(cosmo)
    cosmo_grids = build_cosmo_grids(params, z_grid=z)
    halo_grids = build_halo_grids(cosmo_grids, params, delta_crit=delta_crit,
                                  m_min=m_min, m_max=m_max, n_m=n_m)
    ell_arr = jnp.asarray(ell)
    profile_params = profile._asdict()
    return cl_yy_1h_trispectrum(ell_arr, cosmo_grids, halo_grids, params,
                                profile='arnaud10',
                                profile_params=profile_params)
[docs]
def cl_yy_covariance(cosmo: CosmoParams, profile: ProfileParamsA10, ell,
                     delta_ell, fsky: float = 1.0,
                     include_trispectrum: bool = True,
                     z_grid: jax.Array | None = None,
                     n_z: int = 100, m_min: float = 1e10,
                     m_max: float = 3.5e15, n_m: int = 200,
                     delta_crit: float = 500.0,
                     x_outSZ: float = 4.0,
                     c500_fiducial: float = 1.156) -> jax.Array:
    r"""Bandpower covariance for tSZ :math:`C_\ell^{yy}`.

    .. math::

        \mathrm{Cov}(C_\ell, C_{\ell'}) =
        \frac{2\,C_\ell^2}{(2\ell + 1)\,\Delta\ell\,f_\mathrm{sky}}
        \,\delta_{\ell\ell'}
        \;+\; \frac{T^{1h}(\ell, \ell')}{4\pi\,f_\mathrm{sky}}

    :math:`C_\ell` in the Gaussian term is ``cl_1h + cl_2h``, computed exactly
    as :func:`cl_yy` does (including the x_outSZ-truncated GNFW transform).
    :math:`T^{1h}` is computed as in :func:`cl_yy_trispectrum`.

    Returns the full :math:`(n_\ell, n_\ell)` covariance matrix, suitable
    for a Cholesky decomposition to generate synthetic bandpower
    realisations::

        L = jnp.linalg.cholesky(cov)
        y_synth = y_fid + L @ jax.random.normal(key, (len(ell),))

    The covariance is on the dimensionless :math:`C_\ell`. If the data vector
    is in :math:`D_\ell\times 10^{12}` units, rescale by the outer product of
    ``ell(ell+1)/(2π) × 1e12`` before Cholesky.

    Parameters
    ----------
    cosmo, profile, ell : as for :func:`cl_yy`.
    delta_ell : float or array_like
        Bandpower width(s) :math:`\Delta\ell`. A scalar broadcasts to all
        bins. Widths stay in floating point even when ``ell`` is an integer
        array, so fractional widths are not truncated.
    fsky : float
        Observed sky fraction; the Gaussian variance scales as
        :math:`1/f_\mathrm{sky}` and the trispectrum term as
        :math:`1/(4\pi f_\mathrm{sky})`.
    include_trispectrum : bool
        If False, return the Gaussian variance only (diagonal).
    z_grid, n_z, m_min, m_max, n_m, delta_crit
        Forwarded to the cosmology / halo-model grid builders.
    x_outSZ, c500_fiducial : float, optional
        GNFW truncation settings for the :math:`C_\ell` of the Gaussian term,
        with the same meaning and defaults as in :func:`cl_yy`.

    Returns
    -------
    jax.Array
        ``(n_ell, n_ell)`` covariance matrix.
    """
    from .power_spectrum import _A10_GAMMA, _A10_ALPHA, _A10_BETA

    ell_arr = jnp.asarray(ell)
    # Casting the widths to an integer ell dtype would silently truncate
    # fractional values (e.g. 50.5 -> 50), so fall back to float64 there.
    width_dtype = (ell_arr.dtype if jnp.issubdtype(ell_arr.dtype, jnp.floating)
                   else jnp.float64)
    delta_ell_arr = jnp.broadcast_to(jnp.asarray(delta_ell, dtype=width_dtype),
                                     ell_arr.shape)
    # Build the cosmology and halo grids ONCE and reuse them for both C_ell
    # and the trispectrum. This avoids the 2× emulator + HMF build cost that
    # the naive composition cl_yy(...) + cl_yy_trispectrum(...) would pay.
    if z_grid is None:
        z_grid = jnp.geomspace(0.005, 3.0, n_z)
    cosmo_dict = cosmo_to_dict(cosmo)
    cg = build_cosmo_grids(cosmo_dict, z_grid=z_grid)
    hg = build_halo_grids(cg, cosmo_dict, delta_crit=delta_crit,
                          m_min=m_min, m_max=m_max, n_m=n_m)
    pp_dict = profile._asdict()

    # Attach the same x_outSZ-truncated GNFW transform table that cl_yy uses,
    # so the Gaussian term is built from the C_ell that cl_yy returns.
    u_jax, sb, log_s = _a10_setup_sbt(x_outSZ, c500_fiducial)
    g_tab = _a10_compute_g(
        sb, u_jax,
        pp_dict.get('gamma', _A10_GAMMA),
        pp_dict.get('alpha', _A10_ALPHA),
        pp_dict.get('beta', _A10_BETA),
    )
    pp_cl = dict(pp_dict, _g_table=g_tab, _log_s_grid=log_s)
    cl_1h, cl_2h = cl_yy_1h_2h(ell_arr, cg, hg, cosmo_dict,
                               profile='arnaud10', profile_params=pp_cl)
    cl_tot = cl_1h + cl_2h
    gaussian_diag = 2.0 * cl_tot ** 2 / ((2.0 * ell_arr + 1.0)
                                         * delta_ell_arr * fsky)
    cov = jnp.diag(gaussian_diag)
    if include_trispectrum:
        # Same profile-parameter dict that cl_yy_trispectrum forwards.
        T = cl_yy_1h_trispectrum(ell_arr, cg, hg, cosmo_dict,
                                 profile='arnaud10', profile_params=pp_dict)
        cov = cov + T / (4.0 * jnp.pi * fsky)
    return cov