Skip to content

Commit

Permalink
Merge pull request #219 from 21cmfast/add_radio_background_emu
Browse files Browse the repository at this point in the history
feat: Add radio background emulator
  • Loading branch information
DanielaBreitman authored Jul 24, 2024
2 parents df2634d + 437fc5e commit 752bb43
Show file tree
Hide file tree
Showing 18 changed files with 2,018 additions and 163 deletions.
36 changes: 36 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

[paths]
source =
src/
*/site-packages/

omit =
*/models/radio_background/model.py
tests/
*/tmp/

[run]
branch = True
source = py21cmemu
omit =
*/models/radio_background/model.py
tests/
*/tmp/

[report]
# Regexes for lines to exclude from consideration
exclude_lines =
# Have to re-enable the standard pragma
pragma: no cover

# Don't complain about missing debug-only code:
def __repr__
if self\.debug

# Don't complain if tests don't hit defensive assertion code:
raise AssertionError
raise NotImplementedError

# Don't complain if non-runnable code isn't run:
if 0:
if __name__ == .__main__.:
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ sphinx>=5.2.0
sphinx-click>=4.4.0
tensorflow>=2.6.0
toml==0.10.2
torch>=1.9.0
1 change: 1 addition & 0 deletions docs/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ The following introductory tutorials will help you get started with ``21cmEMU``.
tutorials/basic_usage
tutorials/21cmFAST_tau_UVLFs
tutorials/21cmFAST_lightcone
tutorials/radio_emulator
Binary file added docs/tutorials/Radio_Test_data_sample.npz
Binary file not shown.
Binary file removed docs/tutorials/lightcone_example.npz
Binary file not shown.
815 changes: 815 additions & 0 deletions docs/tutorials/radio_emulator.ipynb

Large diffs are not rendered by default.

315 changes: 313 additions & 2 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ click = ">=8.0.1"
numpy = "^1.22.0"
scipy = "^1.10.1"
tensorflow = ">=2.4.0, <= 2.14.0"
torch = ">=1.9.0"
appdirs = "^1.4.4"
toml = "^0.10.2"
GitPython = "^3.1.31"
Expand Down
7 changes: 3 additions & 4 deletions src/py21cmemu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
__version__ = "1.0.8"
from .emulator import Emulator
from .get_emulator import get_emu_data
from .inputs import EmulatorInput
from .inputs import DefaultEmulatorInput
from .inputs import RadioEmulatorInput
from .outputs import EmulatorOutput
from .outputs import RawEmulatorOutput
from .properties import COSMO_PARAMS
from .properties import FLAG_OPTIONS
from .properties import USER_PARAMS
from .properties import EmulatorProperties
from .properties import emulator_properties
101 changes: 82 additions & 19 deletions src/py21cmemu/emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

import numpy as np
import tensorflow as tf

from .config import CONFIG
from .get_emulator import get_emu_data
from .inputs import EmulatorInput
from .inputs import DefaultEmulatorInput
from .inputs import ParamVecType
from .inputs import RadioEmulatorInput
from .outputs import DefaultRawEmulatorOutput
from .outputs import EmulatorOutput
from .outputs import RawEmulatorOutput
from .outputs import RadioRawEmulatorOutput
from .properties import emulator_properties


Expand All @@ -27,16 +29,60 @@ class Emulator:
----------
version : str, optional
Emulator version to use/download, default is 'latest'.
emulator : str, optional
Emulator to use. Options are: 'radio_background' and 'default'.
The radio background emulator is the emulator used in Cang+24
It is a model that predicts the radio background
temperature :math:`T_{\rm r} \rm{[K]}`,
the global IGM neutral fraction :math:`\overline{x}_{\rm HI}`,
the global 21-cm brightness temperature :math:`T{\rm b} \rm{[mK]}`,
the 21-cm spherically-averaged power spectrum :math:`P(k) \rm{[mK^2]}`, and
the Thomson scattering optical depth :math:`\tau`.
It has five input parameters:
["fR_mini", "L_X_MINI", "F_STAR7_MINI", "F_ESC7_MINI", "A_LW"]
See 21cmFAST documentation for more information about the input parameters.
The default emulator is the emulator described in Breitman+23.
It emulates six summary statistics with 9 input astrophysical parameters.
"""

def __init__(self, version: str = "latest"):
get_emu_data(version=version)
def __init__(self, emulator: str = "default", version: str = "latest"):

emu = tf.keras.models.load_model(CONFIG.emu_path, compile=False)
self.which_emulator = emulator
if self.which_emulator == "default":
import tensorflow as tf

self.model = emu
self.inputs = EmulatorInput()
self.properties = emulator_properties
get_emu_data(version=version)
model = tf.keras.models.load_model(CONFIG.emu_path, compile=False)
self.inputs = DefaultEmulatorInput()

elif self.which_emulator == "radio_background":
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)
from .models.radio_background.model import Radio_Emulator

here = Path(__file__).parent
model = Radio_Emulator()
model.load_state_dict(
torch.load(
here / "models/radio_background/Radio_Background_Emu_Weights",
map_location=device,
),
)
model.eval()
self.inputs = RadioEmulatorInput()

else:
raise ValueError(
"Please supply one of the following emulator names:"
+ "'default' or 'radio_background'. "
+ f"{emulator} is not a valid emulator name."
)

self.model = model
self.properties = emulator_properties(emulator=emulator)

def __getattr__(self, name: str) -> Any:
"""Allow access to emulator properties directly from the emulator object."""
Expand Down Expand Up @@ -67,7 +113,15 @@ def predict(
The mean error on the test set (i.e. independent of theta).
"""
theta = self.inputs.make_param_array(astro_params, normed=True)
emu = RawEmulatorOutput(self.model.predict(theta, verbose=verbose))
if self.which_emulator == "default":
emu = DefaultRawEmulatorOutput(self.model.predict(theta, verbose=verbose))
if self.which_emulator == "radio_background":
import torch

emu = RadioRawEmulatorOutput(
self.model(torch.Tensor(theta)).detach().cpu().numpy()
)

emu = emu.get_renormalized()

errors = self.get_errors(emu, theta)
Expand All @@ -94,12 +148,21 @@ def get_errors(
# For now, we return the mean emulator error (obtained from the test set) for
# each summary. All errors are the median absolute difference between test set
# and prediction AFTER units have been restored AND log has been removed.
return {
"PS_err": self.PS_err,
"Tb_err": self.Tb_err,
"xHI_err": self.xHI_err,
"Ts_err": self.Ts_err,
"UVLFs_err": self.UVLFs_err,
"UVLFs_logerr": self.UVLFs_logerr,
"tau_err": self.tau_err,
}
if self.which_emulator == "default":
return {
"PS_err": self.PS_err,
"Tb_err": self.Tb_err,
"xHI_err": self.xHI_err,
"Ts_err": self.Ts_err,
"UVLFs_err": self.UVLFs_err,
"UVLFs_logerr": self.UVLFs_logerr,
"tau_err": self.tau_err,
}
else:
return {
"PS_err": self.PS_err,
"Tb_err": self.Tb_err,
"xHI_err": self.xHI_err,
"Tr_err": self.Tr_err,
"tau_err": self.tau_err,
}
94 changes: 78 additions & 16 deletions src/py21cmemu/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from .properties import emulator_properties as properties
from .properties import emulator_properties


SingleParamVecType = Union[Dict[str, float], np.ndarray, Sequence[float]]
Expand All @@ -18,17 +18,8 @@
class EmulatorInput:
"""Class for handling emulator inputs."""

astro_param_keys = (
"F_STAR10",
"ALPHA_STAR",
"F_ESC10",
"ALPHA_ESC",
"M_TURN",
"t_STAR",
"L_X",
"NU_X_THRESH",
"X_RAY_SPEC_INDEX",
)
def __init__(self, emulator: str = "default"):
self.properties = emulator_properties(emulator=emulator)

def _format_single_theta_vector(self, theta: SingleParamVecType) -> np.ndarray:
if len(theta) != len(self.astro_param_keys):
Expand Down Expand Up @@ -113,6 +104,25 @@ def make_list_of_dicts(
theta = self.make_param_array(theta, normed=normed)
return [dict(zip(self.astro_param_keys, theta[i])) for i in range(len(theta))]


class DefaultEmulatorInput(EmulatorInput):
"""Class for handling emulator inputs."""

def __init__(self):
"""Class for handling emulator inputs."""
self.astro_param_keys = (
"F_STAR10",
"ALPHA_STAR",
"F_ESC10",
"ALPHA_ESC",
"M_TURN",
"t_STAR",
"L_X",
"NU_X_THRESH",
"X_RAY_SPEC_INDEX",
)
super().__init__(emulator="default")

def normalize(self, theta: np.ndarray) -> np.ndarray:
"""Normalize the parameters.
Expand All @@ -129,8 +139,8 @@ def normalize(self, theta: np.ndarray) -> np.ndarray:
"""
theta_woutdims = theta.copy()
theta_woutdims[:, 7] /= 1000
theta_woutdims -= properties.limits[:, 0]
theta_woutdims /= properties.limits[:, 1] - properties.limits[:, 0]
theta_woutdims -= self.properties.limits[:, 0]
theta_woutdims /= self.properties.limits[:, 1] - self.properties.limits[:, 0]
return theta_woutdims

def undo_normalization(self, theta: np.ndarray) -> np.ndarray:
Expand All @@ -148,7 +158,59 @@ def undo_normalization(self, theta: np.ndarray) -> np.ndarray:
Un-normalized parameters, with shape (n_batch, n_params).
"""
theta_wdims = theta.copy()
theta_wdims *= properties.limits[:, 1] - properties.limits[:, 0]
theta_wdims += properties.limits[:, 0]
theta_wdims *= self.properties.limits[:, 1] - self.properties.limits[:, 0]
theta_wdims += self.properties.limits[:, 0]
theta_wdims[:, 7] *= 1000
return theta_wdims


class RadioEmulatorInput(EmulatorInput):
"""Class for handling radio background emulator inputs."""

def __init__(self):
self.astro_param_keys = (
"fR_mini",
"L_X_MINI",
"F_STAR7_MINI",
"F_ESC7_MINI",
"A_LW",
)
super().__init__(emulator="radio_background")

def normalize(self, theta: np.ndarray) -> np.ndarray:
"""Normalize the parameters.
Parameters
----------
theta : np.ndarray
Input parameters, strictly in 2D array format, with shape
(n_batch, n_params).
Returns
-------
np.ndarray
Normalized parameters, with shape (n_batch, n_params).
"""
theta_woutdims = theta.copy()
theta_woutdims -= self.properties.limits[:, 0]
theta_woutdims /= self.properties.limits[:, 1] - self.properties.limits[:, 0]
return theta_woutdims

def undo_normalization(self, theta: np.ndarray) -> np.ndarray:
"""Undo the normalization of the parameters.
Parameters
----------
theta : np.ndarray
Input parameters, strictly in 2D array format, with shape
(n_batch, n_params).
Returns
-------
np.ndarray
Un-normalized parameters, with shape (n_batch, n_params).
"""
theta_wdims = theta.copy()
theta_wdims *= self.properties.limits[:, 1] - self.properties.limits[:, 0]
theta_wdims += self.properties.limits[:, 0]
return theta_wdims
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 752bb43

Please sign in to comment.