Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add radio background emulator #219

Merged
merged 47 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
1a13703
feat: Add radio background emulator
DanielaBreitman Jul 16, 2024
d3efecd
fix: properties
DanielaBreitman Jul 16, 2024
00a4b4c
test: fix old tests
DanielaBreitman Jul 19, 2024
5c667a3
test: fix old tests
DanielaBreitman Jul 19, 2024
745706a
test: fix old tests
DanielaBreitman Jul 22, 2024
d133b7b
test: fix old tests
DanielaBreitman Jul 22, 2024
4359f65
test: fix old tests
DanielaBreitman Jul 22, 2024
7e48a08
test: fix old tests
DanielaBreitman Jul 22, 2024
6761255
test: fix old tests
DanielaBreitman Jul 22, 2024
8c34115
test: fix old tests
DanielaBreitman Jul 22, 2024
389a532
test: fix old tests
DanielaBreitman Jul 22, 2024
e9e14c5
test: fix old tests
DanielaBreitman Jul 22, 2024
8e84d98
test: fix old tests
DanielaBreitman Jul 22, 2024
bd7466c
test: fix old tests
DanielaBreitman Jul 22, 2024
6001f11
test: fix old tests
DanielaBreitman Jul 22, 2024
4db0545
test: fix old tests
DanielaBreitman Jul 22, 2024
a75616b
test: adding new tests
DanielaBreitman Jul 22, 2024
16e9af3
test: adding new tests
DanielaBreitman Jul 22, 2024
e83c6d1
test: adding new tests
DanielaBreitman Jul 22, 2024
ba1d10b
test: adding new tests
DanielaBreitman Jul 22, 2024
d353e59
test: adding new tests
DanielaBreitman Jul 22, 2024
35807d0
test: adding new tests
DanielaBreitman Jul 22, 2024
0a862e1
test: adding new tests
DanielaBreitman Jul 22, 2024
5972745
test: adding new tests
DanielaBreitman Jul 22, 2024
3999c6b
test: adding new tests
DanielaBreitman Jul 22, 2024
5694aab
test: adding new tests
DanielaBreitman Jul 22, 2024
18f2200
test: excluding model.py from testing
DanielaBreitman Jul 23, 2024
48928b1
test: excluding model.py from testing
DanielaBreitman Jul 23, 2024
dbbf2e0
test: excluding model.py from testing
DanielaBreitman Jul 23, 2024
3b124c5
test: excluding model.py from testing
DanielaBreitman Jul 23, 2024
e91990b
test: excluding model.py from testing
DanielaBreitman Jul 23, 2024
dbfa2b8
test: excluding model.py from testing
DanielaBreitman Jul 23, 2024
420252b
test: excluding model.py from testing
DanielaBreitman Jul 23, 2024
1fd6a85
test: excluding model.py from testing
DanielaBreitman Jul 23, 2024
d8d061b
test: excluding model.py from testing
DanielaBreitman Jul 23, 2024
9ad672d
test: excluding model.py from testing
DanielaBreitman Jul 23, 2024
adfa10c
docs: updating tutorials
DanielaBreitman Jul 24, 2024
8f07b65
test: excluding model.py from testing
DanielaBreitman Jul 24, 2024
0ae16a8
test: excluding model.py from testing
DanielaBreitman Jul 24, 2024
fd41f91
test: excluding model.py from testing
DanielaBreitman Jul 24, 2024
e2c959c
test: excluding model.py from testing
DanielaBreitman Jul 24, 2024
eab7fba
test: excluding model.py from testing
DanielaBreitman Jul 24, 2024
417ab88
test: adding tests
DanielaBreitman Jul 24, 2024
2d0bc54
test: adding tests
DanielaBreitman Jul 24, 2024
4e9f3d5
test: adding tests
DanielaBreitman Jul 24, 2024
8f2f81b
docs:Revert to old tutorial
DanielaBreitman Jul 24, 2024
437fc5e
docs:Revert to old tutorial
DanielaBreitman Jul 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

[paths]
source =
src/
*/site-packages/

omit =
*/models/radio_background/model.py
tests/
*/tmp/

[run]
branch = True
source = py21cmemu
omit =
*/models/radio_background/model.py
tests/
*/tmp/

[report]
# Regexes for lines to exclude from consideration
exclude_lines =
# Have to re-enable the standard pragma
pragma: no cover

# Don't complain about missing debug-only code:
def __repr__
if self\.debug

# Don't complain if tests don't hit defensive assertion code:
raise AssertionError
raise NotImplementedError

# Don't complain if non-runnable code isn't run:
if 0:
if __name__ == .__main__.:
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ sphinx>=5.2.0
sphinx-click>=4.4.0
tensorflow>=2.6.0
toml==0.10.2
torch>=1.9.0
1 change: 1 addition & 0 deletions docs/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ The following introductory tutorials will help you get started with ``21cmEMU``.
tutorials/basic_usage
tutorials/21cmFAST_tau_UVLFs
tutorials/21cmFAST_lightcone
tutorials/radio_emulator
Binary file added docs/tutorials/Radio_Test_data_sample.npz
Binary file not shown.
Binary file removed docs/tutorials/lightcone_example.npz
Binary file not shown.
815 changes: 815 additions & 0 deletions docs/tutorials/radio_emulator.ipynb

Large diffs are not rendered by default.

315 changes: 313 additions & 2 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ click = ">=8.0.1"
numpy = "^1.22.0"
scipy = "^1.10.1"
tensorflow = ">=2.4.0, <= 2.14.0"
torch = ">=1.9.0"
appdirs = "^1.4.4"
toml = "^0.10.2"
GitPython = "^3.1.31"
Expand Down
7 changes: 3 additions & 4 deletions src/py21cmemu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
__version__ = "1.0.8"
from .emulator import Emulator
from .get_emulator import get_emu_data
from .inputs import EmulatorInput
from .inputs import DefaultEmulatorInput
from .inputs import RadioEmulatorInput
from .outputs import EmulatorOutput
from .outputs import RawEmulatorOutput
from .properties import COSMO_PARAMS
from .properties import FLAG_OPTIONS
from .properties import USER_PARAMS
from .properties import EmulatorProperties
from .properties import emulator_properties
101 changes: 82 additions & 19 deletions src/py21cmemu/emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

import numpy as np
import tensorflow as tf

from .config import CONFIG
from .get_emulator import get_emu_data
from .inputs import EmulatorInput
from .inputs import DefaultEmulatorInput
from .inputs import ParamVecType
from .inputs import RadioEmulatorInput
from .outputs import DefaultRawEmulatorOutput
from .outputs import EmulatorOutput
from .outputs import RawEmulatorOutput
from .outputs import RadioRawEmulatorOutput
from .properties import emulator_properties


Expand All @@ -27,16 +29,60 @@ class Emulator:
----------
version : str, optional
Emulator version to use/download, default is 'latest'.
emulator : str, optional
Emulator to use. Options are: 'radio_background' and 'default'.
The radio background emulator is the emulator used in Cang+24
It is a model that predicts the radio background
temperature :math:`T_{\rm r} \rm{[K]}`,
the global IGM neutral fraction :math:`\overline{x}_{\rm HI}`,
the global 21-cm brightness temperature :math:`T{\rm b} \rm{[mK]}`,
the 21-cm spherically-averaged power spectrum :math:`P(k) \rm{[mK^2]}`, and
the Thomson scattering optical depth :math:`\tau`.
It has five input parameters:
["fR_mini", "L_X_MINI", "F_STAR7_MINI", "F_ESC7_MINI", "A_LW"]
See 21cmFAST documentation for more information about the input parameters.

The default emulator is the emulator described in Breitman+23.
It emulates six summary statistics with 9 input astrophysical parameters.
"""

def __init__(self, version: str = "latest"):
get_emu_data(version=version)
def __init__(self, emulator: str = "default", version: str = "latest"):

emu = tf.keras.models.load_model(CONFIG.emu_path, compile=False)
self.which_emulator = emulator
if self.which_emulator == "default":
import tensorflow as tf

self.model = emu
self.inputs = EmulatorInput()
self.properties = emulator_properties
get_emu_data(version=version)
model = tf.keras.models.load_model(CONFIG.emu_path, compile=False)
self.inputs = DefaultEmulatorInput()

elif self.which_emulator == "radio_background":
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)
from .models.radio_background.model import Radio_Emulator

here = Path(__file__).parent
model = Radio_Emulator()
model.load_state_dict(
torch.load(
here / "models/radio_background/Radio_Background_Emu_Weights",
map_location=device,
),
)
model.eval()
self.inputs = RadioEmulatorInput()

else:
raise ValueError(
"Please supply one of the following emulator names:"
+ "'default' or 'radio_background'. "
+ f"{emulator} is not a valid emulator name."
)

self.model = model
self.properties = emulator_properties(emulator=emulator)

def __getattr__(self, name: str) -> Any:
"""Allow access to emulator properties directly from the emulator object."""
Expand Down Expand Up @@ -67,7 +113,15 @@ def predict(
The mean error on the test set (i.e. independent of theta).
"""
theta = self.inputs.make_param_array(astro_params, normed=True)
emu = RawEmulatorOutput(self.model.predict(theta, verbose=verbose))
if self.which_emulator == "default":
emu = DefaultRawEmulatorOutput(self.model.predict(theta, verbose=verbose))
if self.which_emulator == "radio_background":
import torch

emu = RadioRawEmulatorOutput(
self.model(torch.Tensor(theta)).detach().cpu().numpy()
)

emu = emu.get_renormalized()

errors = self.get_errors(emu, theta)
Expand All @@ -94,12 +148,21 @@ def get_errors(
# For now, we return the mean emulator error (obtained from the test set) for
# each summary. All errors are the median absolute difference between test set
# and prediction AFTER units have been restored AND log has been removed.
return {
"PS_err": self.PS_err,
"Tb_err": self.Tb_err,
"xHI_err": self.xHI_err,
"Ts_err": self.Ts_err,
"UVLFs_err": self.UVLFs_err,
"UVLFs_logerr": self.UVLFs_logerr,
"tau_err": self.tau_err,
}
if self.which_emulator == "default":
return {
"PS_err": self.PS_err,
"Tb_err": self.Tb_err,
"xHI_err": self.xHI_err,
"Ts_err": self.Ts_err,
"UVLFs_err": self.UVLFs_err,
"UVLFs_logerr": self.UVLFs_logerr,
"tau_err": self.tau_err,
}
else:
return {
"PS_err": self.PS_err,
"Tb_err": self.Tb_err,
"xHI_err": self.xHI_err,
"Tr_err": self.Tr_err,
"tau_err": self.tau_err,
}
94 changes: 78 additions & 16 deletions src/py21cmemu/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from .properties import emulator_properties as properties
from .properties import emulator_properties


SingleParamVecType = Union[Dict[str, float], np.ndarray, Sequence[float]]
Expand All @@ -18,17 +18,8 @@
class EmulatorInput:
"""Class for handling emulator inputs."""

astro_param_keys = (
"F_STAR10",
"ALPHA_STAR",
"F_ESC10",
"ALPHA_ESC",
"M_TURN",
"t_STAR",
"L_X",
"NU_X_THRESH",
"X_RAY_SPEC_INDEX",
)
def __init__(self, emulator: str = "default"):
self.properties = emulator_properties(emulator=emulator)

def _format_single_theta_vector(self, theta: SingleParamVecType) -> np.ndarray:
if len(theta) != len(self.astro_param_keys):
Expand Down Expand Up @@ -113,6 +104,25 @@ def make_list_of_dicts(
theta = self.make_param_array(theta, normed=normed)
return [dict(zip(self.astro_param_keys, theta[i])) for i in range(len(theta))]


class DefaultEmulatorInput(EmulatorInput):
"""Class for handling emulator inputs."""

def __init__(self):
"""Class for handling emulator inputs."""
self.astro_param_keys = (
"F_STAR10",
"ALPHA_STAR",
"F_ESC10",
"ALPHA_ESC",
"M_TURN",
"t_STAR",
"L_X",
"NU_X_THRESH",
"X_RAY_SPEC_INDEX",
)
super().__init__(emulator="default")

def normalize(self, theta: np.ndarray) -> np.ndarray:
"""Normalize the parameters.

Expand All @@ -129,8 +139,8 @@ def normalize(self, theta: np.ndarray) -> np.ndarray:
"""
theta_woutdims = theta.copy()
theta_woutdims[:, 7] /= 1000
theta_woutdims -= properties.limits[:, 0]
theta_woutdims /= properties.limits[:, 1] - properties.limits[:, 0]
theta_woutdims -= self.properties.limits[:, 0]
theta_woutdims /= self.properties.limits[:, 1] - self.properties.limits[:, 0]
return theta_woutdims

def undo_normalization(self, theta: np.ndarray) -> np.ndarray:
Expand All @@ -148,7 +158,59 @@ def undo_normalization(self, theta: np.ndarray) -> np.ndarray:
Un-normalized parameters, with shape (n_batch, n_params).
"""
theta_wdims = theta.copy()
theta_wdims *= properties.limits[:, 1] - properties.limits[:, 0]
theta_wdims += properties.limits[:, 0]
theta_wdims *= self.properties.limits[:, 1] - self.properties.limits[:, 0]
theta_wdims += self.properties.limits[:, 0]
theta_wdims[:, 7] *= 1000
return theta_wdims


class RadioEmulatorInput(EmulatorInput):
"""Class for handling radio background emulator inputs."""

def __init__(self):
self.astro_param_keys = (
"fR_mini",
"L_X_MINI",
"F_STAR7_MINI",
"F_ESC7_MINI",
"A_LW",
)
super().__init__(emulator="radio_background")

def normalize(self, theta: np.ndarray) -> np.ndarray:
"""Normalize the parameters.

Parameters
----------
theta : np.ndarray
Input parameters, strictly in 2D array format, with shape
(n_batch, n_params).

Returns
-------
np.ndarray
Normalized parameters, with shape (n_batch, n_params).
"""
theta_woutdims = theta.copy()
theta_woutdims -= self.properties.limits[:, 0]
theta_woutdims /= self.properties.limits[:, 1] - self.properties.limits[:, 0]
return theta_woutdims

def undo_normalization(self, theta: np.ndarray) -> np.ndarray:
"""Undo the normalization of the parameters.

Parameters
----------
theta : np.ndarray
Input parameters, strictly in 2D array format, with shape
(n_batch, n_params).

Returns
-------
np.ndarray
Un-normalized parameters, with shape (n_batch, n_params).
"""
theta_wdims = theta.copy()
theta_wdims *= self.properties.limits[:, 1] - self.properties.limits[:, 0]
theta_wdims += self.properties.limits[:, 0]
return theta_wdims
Binary file not shown.
Binary file not shown.
Loading
Loading