Skip to content

Commit

Permalink
Merge pull request #74 from 21cmfast/emu_h1c_lk
Browse files Browse the repository at this point in the history
Emulator integration + HERA H1C PS Upper limit likelihood
  • Loading branch information
DanielaBreitman authored Jul 27, 2023
2 parents 8d959c1 + 63512d2 commit 45805ab
Show file tree
Hide file tree
Showing 15 changed files with 549 additions and 96 deletions.
12 changes: 8 additions & 4 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ ignore =
W503
F403
F401
N803 # Naming upper/lowercase -- too hard right now.
N806 # Naming upper/lowercase -- too hard right now.
N802 # Naming upper/lowercase -- too hard right now.
D401 # Docstring in imperative mood. This should *not* be the case for @property's, but can't ignore them atm.
# Naming upper/lowercase -- too hard right now.
N803
# Naming upper/lowercase -- too hard right now.
N806
# Naming upper/lowercase -- too hard right now.
N802
# Docstring in imperative mood. This should *not* be the case for @property's, but can't ignore them atm.
D401
max-line-length = 88
max-complexity = 21
docstring-convention=numpy
Expand Down
6 changes: 5 additions & 1 deletion .github/workflows/test_suite.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
python-version: [3.7, 3.8, 3.9, "3.10"]
python-version: [3.8, 3.9, "3.10"]
defaults:
run:
# Adding -l {0} ensures conda can be found properly in each step
Expand Down Expand Up @@ -91,6 +91,10 @@ jobs:
# export PATH="$HOME/miniconda/bin:$PATH"
# source activate $ENV_NAME
# CC=gcc CFLAGS="-isysroot /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk" pip install .
- name: List Environment
run: |
conda list
- name: Run Tests
run: |
python -m pytest --cov=py21cmmc --cov-config=.coveragerc -vv --cov-report xml:./coverage.xml --durations=25
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ exclude: '(^docs/conf.py|^user_data/External_tables/)'

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
Expand All @@ -17,7 +17,7 @@ repos:
- id: mixed-line-ending
args: ['--fix=no']
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4 # pick a git hash / tag to point to
rev: 6.0.0 # pick a git hash / tag to point to
hooks:
- id: flake8
additional_dependencies:
Expand All @@ -29,10 +29,10 @@ repos:
- flake8-copyright
- flake8-docstrings
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 23.7.0
hooks:
- id: black
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
Empty file modified ci/install_conda.sh
100755 → 100644
Empty file.
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ def _find_version(*file_paths):
"click",
"numpy",
"cosmoHammer",
"scipy",
"scipy", # Astropy<5.2.1 breaks for scipy>=1.11. Can remove this later.
"matplotlib>=2.1",
"emcee<3",
"powerbox>=0.5.7",
"cached_property",
"21cmFAST",
"pymultinest",
"py21cmemu>=1.0.8",
"astropy>=5.2.1",
],
extras_require={
"samplers": ["pymultinest"],
Expand Down
2 changes: 2 additions & 0 deletions src/py21cmmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__version__ = "1.0.0dev3"
from .analyse import get_samples, load_primitive_chain
from .core import (
Core21cmEMU,
CoreCMB,
CoreCoevalModule,
CoreForest,
Expand All @@ -15,6 +16,7 @@
from .likelihood import (
Likelihood1DPowerCoeval,
Likelihood1DPowerLightcone,
Likelihood1DPowerLightconeUpper,
LikelihoodBaseFile,
LikelihoodEDGES,
LikelihoodForest,
Expand Down
144 changes: 141 additions & 3 deletions src/py21cmmc/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def __eq__(self, other):
args = tuple(set(args))

for arg in args + self._extra_defining_attributes:

if arg == "self" or arg in self._ignore_attributes:
continue

Expand Down Expand Up @@ -669,7 +668,6 @@ def convert_model_to_mock(self, ctx):
try:
lfunc[i] += np.random.normal(loc=0, scale=s(muv), size=len(lfunc[i]))
except TypeError:

lfunc[i] += np.random.normal(loc=0, scale=s, size=len(lfunc[i]))


Expand Down Expand Up @@ -941,7 +939,6 @@ def __init__(
global_params=None,
**io_options,
):

super().__init__(io_options.get("store", None))

if not use_21cmfast:
Expand Down Expand Up @@ -1151,3 +1148,144 @@ def get_cl(self, cosmo, l_max=-1):
elif key in ["tp", "ep"]:
cl[key] *= T * 1.0e6
return cl


class Core21cmEMU(CoreBase):
r"""A Core Module that loads 21cmEMU and uses it to obtain 21cmFAST summaries.
Notes
-----
This core calls 21cmEMU and uses it to evaluate 21cmFAST summaries (power spectrum, global signal, neutral fraction, spin temperature)
given a set of astro_params.
Parameters
----------
redshift : float or array_like
The redshift(s) at which to evaluate the summary statistics.
astro_params : dict or :class:`~py21cmfast.AstroParams`
Astrophysical parameters of reionization model according to Park+19 parametrization.
version : str, optional
Emulator version to use, defaults to 'latest'.
"""

def __init__(
self,
astro_params=None,
redshift=None,
k=None,
name="",
global_params=None,
ctx_variables=(
"Tb",
"Tb_err",
"Ts",
"Ts_err",
"xHI",
"xHI_err",
"redshifts",
"PS_redshifts",
"PS",
"PS_err",
"Muv",
"UVLFs",
"UVLFs_err",
"UVLF_redshifts",
"k",
"tau",
"tau_err",
),
cache_dir=None,
version="latest",
store=[],
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.name = str(name)
self.ctx_variables = ctx_variables

try:
from py21cmemu import Emulator, properties
except:
print("Could not load py21cmemu. Make sure it is installed properly.")
self.astro_param_keys = (
"F_STAR10",
"ALPHA_STAR",
"F_ESC10",
"ALPHA_ESC",
"M_TURN",
"t_STAR",
"L_X",
"NU_X_THRESH",
"X_RAY_SPEC_INDEX",
)
if astro_params is not None:
if isinstance(astro_params, p21.AstroParams):
self.astro_params = astro_params
else:
self.astro_params = p21.AstroParams(astro_params)
else:
self.astro_params = p21.AstroParams()

self.cosmo_params = p21.CosmoParams(properties.COSMO_PARAMS)
self.flag_options = p21.FlagOptions(properties.FLAG_OPTIONS)
self.user_params = p21.UserParams(properties.USER_PARAMS)
self.global_params = global_params or {}
self.io_options = {
"store": store, # which summaries to store
"cache_dir": cache_dir, # where the stored data will be written
}

self.emulator = Emulator(version=version)

def _update_params(self, params):
"""
Update all the parameter structures which get passed to the driver.
Parameters
----------
params :
Parameter object from cosmoHammer
"""
ap_dict = copy.copy(self.astro_params.self)

ap_dict.update(
**{
k: getattr(params, k)
for k, v in params.items()
if k in self.astro_params.defining_dict
}
)

return p21.AstroParams(**ap_dict)

def build_model_data(self, ctx):
"""Compute all data defined by this core and add it to the context."""
# Update parameters
logger.debug(f"Updating parameters: {ctx.getParams()}")
astro_params = self._update_params(ctx.getParams())
logger.debug(f"AstroParams: {astro_params}")
# Take only needed AstroParams
input_dict = {k: getattr(astro_params, k) for k in self.astro_param_keys}

# Call 21cmEMU wrapper which returns a dict
theta, outputs, errors = self.emulator.predict(astro_params=input_dict)
if self.io_options["cache_dir"] is not None:
par_vals = ["{:0.3e}".format(i) for i in list(input_dict.values())]
name = "_".join(par_vals)
outputs.write(
fname=self.io_options["cache_dir"] + name,
theta=theta,
store=self.io_options["store"],
)
logger.debug(f"Adding {self.ctx_variables} to context data")
for key in self.ctx_variables:
try:
ctx.add(key + self.name, getattr(outputs, key))
except AttributeError:
try:
ctx.add(key + self.name, errors[key])
except:
raise ValueError(
f"ctx_variable {key} not an attribute of EmulatorOutput or errors dict."
)
1 change: 0 additions & 1 deletion src/py21cmmc/cosmoHammer.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,6 @@ def get_last_sample(self):
return tuple(last)

def _check(self, coords, log_prob, blobs, accepted):

self._check_blobs(blobs[0])
nwalkers, ndim = self.shape

Expand Down
Binary file added src/py21cmmc/data/HERA_H1C_IDR3.npz
Binary file not shown.
Loading

0 comments on commit 45805ab

Please sign in to comment.