Source code for classy_szlite.api

"""High-level public API: derived params, CMB Cls, matter Pk, distances,
tSZ Cl^yy. All functions take a :class:`classy_szlite.params.CosmoParams`.
"""
from __future__ import annotations

import numpy as np
import jax
import jax.numpy as jnp

from ._registry import get_emulator, DEFAULT_COSMO
from .params import CosmoParams, ProfileParamsA10
from .cosmology import build as build_cosmo_grids, get_pk, get_pknl, get_distances
from .hmf import build_halo_grids
from .power_spectrum import cl_yy_1h_2h, cl_yy_1h_trispectrum

jax.config.update("jax_enable_x64", True)


def cosmo_to_dict(cosmo: CosmoParams) -> dict:
    """Map a ``CosmoParams`` onto the emulator-style parameter dict.

    Every key equals the attribute it is read from, except the amplitude,
    which the emulator spells with curly braces (``"ln10^{10}A_s"``) while
    the attribute is ``ln10_10_As``.

    Values are forwarded untouched (float or ``jax.Array``), so the returned
    dict stays JAX-traceable — a ``CosmoParams`` of tracers can be passed
    through ``jax.grad``.
    """
    # (emulator key, CosmoParams attribute) pairs, in output order.
    key_to_attr = (
        ("omega_b", "omega_b"),
        ("omega_cdm", "omega_cdm"),
        ("H0", "H0"),
        ("tau_reio", "tau_reio"),
        ("ln10^{10}A_s", "ln10_10_As"),
        ("n_s", "n_s"),
        ("m_ncdm", "m_ncdm"),
        ("N_ur", "N_ur"),
        ("fEDE", "fEDE"),
        ("log10z_c", "log10z_c"),
        ("thetai_scf", "thetai_scf"),
        ("r", "r"),
    )
    return {key: getattr(cosmo, attr) for key, attr in key_to_attr}
# --------------------------------------------------------------------------- # Derived parameters (σ8, Ω_m, S8) # ---------------------------------------------------------------------------
def derived(cosmo: CosmoParams) -> dict:
    """Derived parameters: ``sigma_8``, ``Omega_m``, ``S8``.

    The result also carries the complete DER emulator output (documented by
    the package as 17 elements) under ``'der_full'``; σ8 is read from index 1
    of that vector. Consult the CosmoPower DER training script for the full
    list of entries.

    Returns
    -------
    dict
        ``{"sigma_8": float, "Omega_m": float, "S8": float,
        "der_full": numpy.ndarray}``.
    """
    emulator = get_emulator('der')

    # Defaults first, then the user's cosmology on top.
    params = dict(DEFAULT_COSMO)
    params.update(cosmo_to_dict(cosmo))
    inputs = {name: [params[name]] for name in emulator.parameters}

    # The emulator predicts log10 of the derived quantities.
    log10_der = np.asarray(emulator.predict(inputs)).flatten()
    der = 10.0 ** log10_der
    sigma_8 = float(der[1])

    h = float(cosmo.H0) / 100.0
    # Σmν = 3 × m_ncdm for ede-v2 ν convention (3 degenerate ν)
    sum_mnu = 3.0 * float(cosmo.m_ncdm)
    omega_m_phys = (float(cosmo.omega_b) + float(cosmo.omega_cdm)
                    + sum_mnu / 93.14)
    Omega_m = omega_m_phys / h ** 2
    S8 = sigma_8 * (Omega_m / 0.3) ** 0.5

    return {
        "sigma_8": sigma_8,
        "Omega_m": Omega_m,
        "S8": S8,
        "der_full": der,
    }
# --------------------------------------------------------------------------- # CMB angular power spectra # --------------------------------------------------------------------------- # TT, EE, PP emulators output log10(prefactored Cl); TE outputs Cl directly # (it can be negative). Recovery factor: ede-v2 uses 1/ell² to get raw Cl. _CMB_LOG_CONVENTION = {"tt": True, "ee": True, "pp": True, "te": False}
def cl_TTTEEE(cosmo: CosmoParams,
              spectra: tuple[str, ...] = ("tt", "te", "ee"),
              ell_factor: bool = True,
              ell_convention: str = "classy_szfast") -> dict:
    """CMB angular power spectra.

    Returns a dict with keys ``'ell'`` and the requested spectra
    (``'tt','te','ee'``). Values are **dimensionless** — multiply by
    ``Tcmb_uK² = (2.7255e6)²`` to convert to μK².

    Parameters
    ----------
    cosmo
        ``CosmoParams`` instance (the ede-v2 input vector).
    spectra
        Requested spectra, any subset of ``("tt", "te", "ee", "pp")``.
        A single name may also be given as a bare string. Repeated names
        are computed (and rescaled) only once.
    ell_factor
        If ``True`` (default), return ``D_ℓ = ℓ(ℓ+1) Cℓ / (2π)``.
        If ``False``, return the raw ``Cℓ``.
    ell_convention
        - ``"classy_szfast"`` (default) — output ``ℓ ∈ [2, ..., 9500]``,
          with ``Cℓ(ℓ) = pred[ℓ−2] / ℓ²``. This is what ``classy_szfast``
          does inside ``calculate_cmb``; using it gives bit-identical Cls to
          the ``cobaya + classy_sz fast-mode`` stack that every published
          EDE chain (including ACT-DR6 + Planck) was fit with.
        - ``"emulator_modes"`` — output ``ℓ ∈ [1, ..., 9499]`` exactly as
          stored in the ede-v2 ``modes`` metadata, with
          ``Cℓ(ℓ) = pred[ℓ−1] / ℓ²``. Use this only if you intentionally
          want to disagree with the published chains.

    Raises
    ------
    ValueError
        If ``ell_convention`` is not recognised, if ``spectra`` is empty,
        or if it contains an unknown spectrum name.
    """
    if ell_convention not in ("classy_szfast", "emulator_modes"):
        raise ValueError(
            f"Unknown ell_convention {ell_convention!r}; expected "
            "'classy_szfast' or 'emulator_modes'."
        )

    # Iterating a bare string would yield single characters.
    if isinstance(spectra, str):
        spectra = (spectra,)
    # Drop duplicates (order preserved): the rescaling loops below run once
    # per entry, so a repeated name would otherwise be rescaled repeatedly.
    spectra = tuple(dict.fromkeys(spectra))
    if not spectra:
        raise ValueError("spectra is empty; request at least one spectrum.")
    # Validate every name before running any emulator.
    for spec in spectra:
        if spec not in _CMB_LOG_CONVENTION:
            raise ValueError(
                f"Unknown spectrum {spec!r}. "
                f"Pick from {tuple(_CMB_LOG_CONVENTION)}."
            )

    full = dict(DEFAULT_COSMO)
    full.update(cosmo_to_dict(cosmo))

    out = {}
    ell = None
    for spec in spectra:
        em = get_emulator(spec)
        p_in = {k: [full[k]] for k in em.parameters}
        pred = np.asarray(em.predict(p_in)).flatten()
        if _CMB_LOG_CONVENTION[spec]:
            pred = 10.0 ** pred
        out[spec] = pred
        if ell is None:
            # The multipole axis is taken from the first emulator only.
            modes = np.asarray(em.modes)
            if ell_convention == "classy_szfast":
                # pred[i] is treated as Cℓ × ℓ² with ℓ = i + 2 (drops ℓ = 1
                # from the emulator's modes list).
                ell = modes + 1
            else:
                ell = modes

    # ede-v2 recovery factor: raw → Cl. Kept as a separate multiplication
    # from the D_ℓ factor so rounding matches the reference pipeline.
    factor_to_Cl = 1.0 / (ell ** 2)
    for spec in spectra:
        out[spec] = out[spec] * factor_to_Cl
    out["ell"] = ell

    if ell_factor:
        fac_dl = ell * (ell + 1) / (2.0 * np.pi)
        for spec in spectra:
            out[spec] = out[spec] * fac_dl
    return out
def cl_TTTEEE_jax(cosmo: CosmoParams | dict,
                  spectra: tuple[str, ...] = ("tt", "te", "ee"),
                  ell_factor: bool = True,
                  ell_convention: str = "classy_szfast"):
    """JAX-traceable, JIT-able variant of :func:`cl_TTTEEE`.

    Accepts either a :class:`CosmoParams` (with possibly traced JAX scalars)
    or a plain ``dict`` whose values are jnp arrays. Returns a dict of jnp
    arrays keyed by the requested spectra plus ``'ell'``.

    The emulator forward pass is already pure-JAX (see
    :class:`classy_szlite._emulator.Emulator.predict`); this wrapper just
    avoids the ``np.asarray`` round-trip in :func:`cl_TTTEEE` that would
    break tracing.

    Parameters
    ----------
    cosmo
        ``CosmoParams``-like object (anything exposing ``_asdict`` or
        ``__dataclass_fields__``) or a mapping of parameter name → value.
        In the mapping form, ``ln10_10_As`` is accepted as an alias for
        ``ln10^{10}A_s``.
    spectra
        Requested spectra; a bare string is treated as a single name and
        repeated names are computed (and rescaled) only once.
    ell_factor
        If ``True``, return ``D_ℓ = ℓ(ℓ+1) Cℓ / (2π)``; otherwise raw ``Cℓ``.
    ell_convention
        ``"classy_szfast"`` (ℓ = modes + 1) or ``"emulator_modes"``
        (ℓ = modes). See :func:`cl_TTTEEE`.

    Raises
    ------
    ValueError
        If ``ell_convention`` is not recognised, if ``spectra`` is empty,
        or if it contains an unknown spectrum name.
    """
    if ell_convention not in ("classy_szfast", "emulator_modes"):
        raise ValueError(
            f"Unknown ell_convention {ell_convention!r}."
        )

    # Iterating a bare string would yield single characters.
    if isinstance(spectra, str):
        spectra = (spectra,)
    # Drop duplicates (order preserved): the rescaling loops below run once
    # per entry, so a repeated name would otherwise be rescaled repeatedly.
    spectra = tuple(dict.fromkeys(spectra))
    if not spectra:
        raise ValueError("spectra is empty; request at least one spectrum.")
    # Validate every name before running any emulator.
    for spec in spectra:
        if spec not in _CMB_LOG_CONVENTION:
            raise ValueError(f"Unknown spectrum {spec!r}.")

    if hasattr(cosmo, "_asdict") or hasattr(cosmo, "__dataclass_fields__"):
        cosmo_dict = cosmo_to_dict(cosmo)
    else:
        cosmo_dict = dict(cosmo)  # copy: the caller's mapping is not mutated
    # Canonicalise the A_s key (callers commonly use ``ln10_10_As``).
    if "ln10_10_As" in cosmo_dict and "ln10^{10}A_s" not in cosmo_dict:
        cosmo_dict["ln10^{10}A_s"] = cosmo_dict.pop("ln10_10_As")

    full = dict(DEFAULT_COSMO)
    full.update(cosmo_dict)

    out = {}
    ell = None
    for spec in spectra:
        em = get_emulator(spec)
        p_in = {k: jnp.atleast_1d(jnp.asarray(full[k], dtype=jnp.float64))
                for k in em.parameters}
        pred = em.predict(p_in)
        pred = jnp.atleast_1d(pred).reshape(-1)
        if _CMB_LOG_CONVENTION[spec]:
            pred = 10.0 ** pred
        out[spec] = pred
        if ell is None:
            # The multipole axis is taken from the first emulator only.
            modes = jnp.asarray(em.modes, dtype=jnp.float64)
            ell = modes + 1 if ell_convention == "classy_szfast" else modes

    # ede-v2 recovery factor: raw → Cl.
    inv_ell2 = 1.0 / (ell ** 2)
    for spec in spectra:
        out[spec] = out[spec] * inv_ell2
    out["ell"] = ell

    if ell_factor:
        fac_dl = ell * (ell + 1) / (2.0 * jnp.pi)
        for spec in spectra:
            out[spec] = out[spec] * fac_dl
    return out


# ---------------------------------------------------------------------------
# Pk, Pnl, distances
# ---------------------------------------------------------------------------
def Pk(cosmo: CosmoParams, z_arr):
    """Linear P(k, z) — returns ``(k, pk(z, k))``.

    Converts ``cosmo`` with :func:`cosmo_to_dict` and delegates to
    ``get_pk`` for the redshifts in ``z_arr``.
    """
    params = cosmo_to_dict(cosmo)
    return get_pk(params, z_arr)
def Pnl(cosmo: CosmoParams, z_arr):
    """Non-linear P(k, z) (HMcode) — returns ``(k, pk(z, k))``.

    Converts ``cosmo`` with :func:`cosmo_to_dict` and delegates to
    ``get_pknl`` for the redshifts in ``z_arr``.
    """
    params = cosmo_to_dict(cosmo)
    return get_pknl(params, z_arr)
def distances(cosmo: CosmoParams, z_arr):
    """Returns ``(Hz, chi, Da)``.

    ``Hz`` is H(z)/c in 1/Mpc; distances in Mpc. Converts ``cosmo`` with
    :func:`cosmo_to_dict` and delegates to ``get_distances`` for the
    redshifts in ``z_arr``.
    """
    params = cosmo_to_dict(cosmo)
    return get_distances(params, z_arr)
# --------------------------------------------------------------------------- # tSZ Cl^yy (halo-model, Arnaud 2010 profile) # --------------------------------------------------------------------------- def _a10_setup_sbt(x_outSZ: float, c500_fiducial: float): """Build the truncated SBT machinery (non-JIT). Runs once per factory. Returns ``(u_grid_jax, sb, log_s_grid)`` where ``sb`` is an mcfit SphericalBessel object whose ``__call__`` uses jax primitives and is JIT-traceable. """ import numpy as _np import warnings as _warnings import mcfit as _mcfit u_max = float(c500_fiducial) * float(x_outSZ) u_grid = _np.geomspace(1e-5, u_max, 256) with _warnings.catch_warnings(): _warnings.filterwarnings("ignore", message="use backend='jax' if desired") sb = _mcfit.SphericalBessel(u_grid, backend='jax') return jnp.asarray(u_grid), sb, jnp.asarray(_np.log(_np.asarray(sb.y))) def _a10_compute_g(sb, u_jax, gamma, alpha, beta): """JIT-traceable kernel evaluation + SBT. Returns g_table on log_s_grid.""" kernel = u_jax ** (-gamma) * (1.0 + u_jax ** alpha) ** ((gamma - beta) / alpha) _, g = sb(kernel, extrap=False) return g * jnp.sqrt(jnp.pi / 2.0)
def cl_yy(cosmo: CosmoParams,
          profile: ProfileParamsA10,
          ell,
          z_grid: jax.Array | None = None,
          n_z: int = 100,
          m_min: float = 1e10,
          m_max: float = 3.5e15,
          n_m: int = 200,
          delta_crit: float = 500.0,
          x_outSZ: float = 4.0,
          c500_fiducial: float = 1.156):
    """Halo-model tSZ angular power spectrum (full pipeline per call).

    Returns ``(cl_1h, cl_2h)`` — dimensionless C_ell. Multiply by
    ``ell*(ell+1)/(2π)*1e12`` to get ``D_ell × 1e12``.

    The GNFW pressure profile is truncated at x = x_outSZ * r_500c
    (default 4.0, matching classy_sz / Arnaud+2010 convention).

    Every call rebuilds the cosmology grids, the halo grids and the
    truncated profile transform table. For MCMC sampling only profile
    parameters at fixed cosmology, use :func:`cl_yy_factory` instead —
    ~3× faster.

    Parameters
    ----------
    z_grid
        Redshift grid; when ``None`` a geometric grid of ``n_z`` points
        between 0.005 and 3.0 is used (``n_z`` is ignored otherwise).
    m_min, m_max, n_m, delta_crit
        Forwarded to ``build_halo_grids``.
    x_outSZ, c500_fiducial
        Forwarded to ``_a10_setup_sbt``; their product sets the upper edge
        of the profile transform grid.
    """
    from .power_spectrum import _A10_GAMMA, _A10_ALPHA, _A10_BETA

    redshifts = jnp.geomspace(0.005, 3.0, n_z) if z_grid is None else z_grid

    params = cosmo_to_dict(cosmo)
    cosmo_grids = build_cosmo_grids(params, z_grid=redshifts)
    halo_grids = build_halo_grids(
        cosmo_grids, params,
        delta_crit=delta_crit, m_min=m_min, m_max=m_max, n_m=n_m,
    )

    # Profile parameters, falling back to the module defaults for any shape
    # parameter the profile dict does not carry.
    prof = profile._asdict()
    shape_gamma = prof.get('gamma', _A10_GAMMA)
    shape_alpha = prof.get('alpha', _A10_ALPHA)
    shape_beta = prof.get('beta', _A10_BETA)

    u_grid, sbt, log_s_grid = _a10_setup_sbt(x_outSZ, c500_fiducial)
    g_table = _a10_compute_g(sbt, u_grid, shape_gamma, shape_alpha, shape_beta)

    # Hand the precomputed transform table to the halo-model integrator.
    prof_ext = dict(prof, _g_table=g_table, _log_s_grid=log_s_grid)
    cl_1h, cl_2h = cl_yy_1h_2h(
        jnp.asarray(ell), cosmo_grids, halo_grids, params,
        profile='arnaud10', profile_params=prof_ext,
    )
    return cl_1h, cl_2h
def cl_yy_factory(cosmo: CosmoParams,
                  ell,
                  z_grid: jax.Array | None = None,
                  n_z: int = 100,
                  m_min: float = 1e10,
                  m_max: float = 3.5e15,
                  n_m: int = 200,
                  delta_crit: float = 500.0,
                  x_outSZ: float = 4.0,
                  c500_fiducial: float = 1.156):
    """Fixed-cosmology fast-path: precompute the heavy bits, get a closure.

    Builds ``CosmoGrids`` (emulators → P_lin, distances, σ(R)) and
    ``HaloGrids`` (Tinker 08 HMF, bias) **once**, then returns:

        ev(profile) -> (cl_1h, cl_2h)

    A subsequent ``ev(profile)`` call only runs the profile transform and
    the ``cl_yy_1h_2h`` halo-model integration — typically ~5 ms per call.
    Intended for MCMC over profile / nuisance parameters with fixed
    cosmology.

    Parameters
    ----------
    x_outSZ : float, optional
        Outer truncation radius of the GNFW pressure profile in units of
        r_500c (i.e. x = r / r_500c). The FT look-up table u-grid runs
        from 1e-5 to ``c500_fiducial * x_outSZ`` (in u = c500 * x units),
        matching the classy_sz ``x_outSZ`` convention. Default 4.0
        (literature / ACT-DR6 may26 convention).
    c500_fiducial : float, optional
        c500 used to convert x_outSZ → u_max at table-build time. Should
        match the c500 you will pass in ProfileParamsA10. Default 1.156
        (Arnaud et al. 2010).
    """
    from .power_spectrum import _A10_GAMMA, _A10_ALPHA, _A10_BETA

    redshifts = jnp.geomspace(0.005, 3.0, n_z) if z_grid is None else z_grid

    params = cosmo_to_dict(cosmo)
    cosmo_grids = build_cosmo_grids(params, z_grid=redshifts)
    halo_grids = build_halo_grids(
        cosmo_grids, params,
        delta_crit=delta_crit, m_min=m_min, m_max=m_max, n_m=n_m,
    )
    ell_jax = jnp.asarray(ell)

    # classy_sz sets the GNFW profile to exactly zero for r/r_500c > x_outSZ
    # before FFT-transforming. We replicate this by building a truncated
    # u-grid (1e-5 to c500_fiducial * x_outSZ, 256 pts) in _a10_setup_sbt and
    # transforming with extrap=False in _a10_compute_g. This replicates
    # classy_sz's x_outSZ truncation to within <1% (beta>5), <3% (beta~3),
    # <8% (beta~1.7).
    #
    # The SBT machinery is built once here, outside @jax.jit; the JIT'd
    # closure then only does the kernel + SBT call, both jax-primitive.
    u_grid, sbt, log_s_grid = _a10_setup_sbt(x_outSZ, c500_fiducial)

    @jax.jit
    def evaluate(profile: ProfileParamsA10):
        prof = profile._asdict()
        shape_gamma = prof.get('gamma', _A10_GAMMA)
        shape_alpha = prof.get('alpha', _A10_ALPHA)
        shape_beta = prof.get('beta', _A10_BETA)
        g_table = _a10_compute_g(sbt, u_grid,
                                 shape_gamma, shape_alpha, shape_beta)
        prof_ext = dict(prof, _g_table=g_table, _log_s_grid=log_s_grid)
        cl_1h, cl_2h = cl_yy_1h_2h(
            ell_jax, cosmo_grids, halo_grids, params,
            profile='arnaud10', profile_params=prof_ext,
        )
        return cl_1h, cl_2h

    return evaluate
def cl_yy_trispectrum(cosmo: CosmoParams,
                      profile: ProfileParamsA10,
                      ell,
                      z_grid: jax.Array | None = None,
                      n_z: int = 100,
                      m_min: float = 1e10,
                      m_max: float = 3.5e15,
                      n_m: int = 200,
                      delta_crit: float = 500.0) -> jax.Array:
    """1-halo connected tSZ trispectrum :math:`T^{1h}(\\ell, \\ell')`.

    Symmetric ``(n_ell, n_ell)`` matrix. Used to construct the
    non-Gaussian part of the bandpower covariance; see
    :func:`cl_yy_covariance`.

    Same per-call cost as :func:`cl_yy` for the cosmology grids, plus an
    extra ``O(n_ell² n_z n_m)`` integral for the trispectrum contraction.

    Parameters
    ----------
    z_grid
        Redshift grid; when ``None`` a geometric grid of ``n_z`` points
        between 0.005 and 3.0 is used.
    m_min, m_max, n_m, delta_crit
        Forwarded to ``build_halo_grids``.
    """
    redshifts = jnp.geomspace(0.005, 3.0, n_z) if z_grid is None else z_grid

    params = cosmo_to_dict(cosmo)
    cosmo_grids = build_cosmo_grids(params, z_grid=redshifts)
    halo_grids = build_halo_grids(
        cosmo_grids, params,
        delta_crit=delta_crit, m_min=m_min, m_max=m_max, n_m=n_m,
    )

    ell_jax = jnp.asarray(ell)
    prof = profile._asdict()
    return cl_yy_1h_trispectrum(
        ell_jax, cosmo_grids, halo_grids, params,
        profile='arnaud10', profile_params=prof,
    )
def cl_yy_covariance(cosmo: CosmoParams,
                     profile: ProfileParamsA10,
                     ell,
                     delta_ell,
                     fsky: float = 1.0,
                     include_trispectrum: bool = True,
                     z_grid: jax.Array | None = None,
                     n_z: int = 100,
                     m_min: float = 1e10,
                     m_max: float = 3.5e15,
                     n_m: int = 200,
                     delta_crit: float = 500.0) -> jax.Array:
    """Bandpower covariance for tSZ :math:`C_\\ell^{yy}`.

    .. math::

        \\mathrm{Cov}(C_\\ell, C_{\\ell'}) =
            \\frac{2\\,C_\\ell^2}{(2\\ell + 1)\\,\\Delta\\ell\\,f_\\mathrm{sky}}
            \\,\\delta_{\\ell\\ell'}
            \\;+\\; \\frac{T^{1h}(\\ell, \\ell')}{4\\pi\\,f_\\mathrm{sky}}

    Returns the full :math:`(n_\\ell, n_\\ell)` covariance matrix suitable
    for a Cholesky decomposition to generate synthetic bandpower
    realisations:

    >>> L = jnp.linalg.cholesky(cov)
    >>> y_synth = y_fid + L @ jax.random.normal(key, (len(ell),))

    The covariance is on the dimensionless :math:`C_\\ell`. If the data
    vector is in :math:`D_\\ell\\times 10^{12}` units, rescale by the outer
    product of ``ell(ell+1)/(2π) × 1e12`` before Cholesky.

    Parameters
    ----------
    cosmo, profile, ell :
        as for :func:`cl_yy`.
    delta_ell : float or array_like
        Bandpower width(s) :math:`\\Delta\\ell`. Scalar broadcasts to all
        bins. Fractional widths are preserved even when ``ell`` is an
        integer array.
    fsky : float
        Observed sky fraction; the Gaussian variance scales as
        :math:`1/f_\\mathrm{sky}` and the trispectrum term as
        :math:`1/(4\\pi f_\\mathrm{sky})`.
    include_trispectrum : bool
        If False, return Gaussian variance only (diagonal).
    z_grid, n_z, m_min, m_max, n_m, delta_crit
        Forwarded to the cosmology / halo-model grid builders.
    """
    ell_arr = jnp.asarray(ell)
    # Bandpower widths must be floating point: casting them to an integer
    # ``ell`` dtype would silently truncate e.g. 2.5 → 2 (or 0.5 → 0, giving
    # an infinite variance). A floating ``ell`` dtype is kept as-is.
    if jnp.issubdtype(ell_arr.dtype, jnp.floating):
        width_dtype = ell_arr.dtype
    else:
        width_dtype = jnp.result_type(float)
    delta_ell_arr = jnp.broadcast_to(
        jnp.asarray(delta_ell, dtype=width_dtype), ell_arr.shape
    )

    # Build the cosmology and halo grids ONCE and reuse for both cl_yy
    # and the trispectrum. Avoids the 2× emulator + HMF build cost the
    # naive composition cl_yy(...) + cl_yy_trispectrum(...) would pay.
    if z_grid is None:
        z_grid = jnp.geomspace(0.005, 3.0, n_z)
    cosmo_dict = cosmo_to_dict(cosmo)
    cg = build_cosmo_grids(cosmo_dict, z_grid=z_grid)
    hg = build_halo_grids(cg, cosmo_dict, delta_crit=delta_crit,
                          m_min=m_min, m_max=m_max, n_m=n_m)

    # NOTE(review): unlike cl_yy, no truncated ``_g_table`` / ``_log_s_grid``
    # is passed here, so the C_ell entering the Gaussian term may differ from
    # what cl_yy returns — confirm this is intended.
    pp_dict = profile._asdict()
    cl_1h, cl_2h = cl_yy_1h_2h(ell_arr, cg, hg, cosmo_dict,
                               profile='arnaud10', profile_params=pp_dict)
    cl_tot = cl_1h + cl_2h

    # Gaussian (diagonal) part.
    gaussian_diag = 2.0 * cl_tot ** 2 / (
        (2.0 * ell_arr + 1.0) * delta_ell_arr * fsky
    )
    cov = jnp.diag(gaussian_diag)

    # Non-Gaussian part from the 1-halo trispectrum.
    if include_trispectrum:
        T = cl_yy_1h_trispectrum(ell_arr, cg, hg, cosmo_dict,
                                 profile='arnaud10', profile_params=pp_dict)
        cov = cov + T / (4.0 * jnp.pi * fsky)
    return cov