Skip to content

Commit

Permalink
Merge pull request #27 from kazewong/10-refactoring-jim-for-an-easier…
Browse files Browse the repository at this point in the history
…-adoption-for-production-in-lvk

10 refactoring jim for an easier adoption for production in lvk
  • Loading branch information
kazewong authored Sep 18, 2023
2 parents 741f24f + d691798 commit 844fd3b
Show file tree
Hide file tree
Showing 56 changed files with 1,649 additions and 2,805 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,8 @@ data

slurm_script*
build*
log*
log*
*.swp
H1.txt
L1.txt
V1.txt
1 change: 1 addition & 0 deletions docs/examples
31 changes: 31 additions & 0 deletions docs/gen_ref_pages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Generate the code reference pages."""

from pathlib import Path

import mkdocs_gen_files

nav = mkdocs_gen_files.Nav()


for path in sorted(Path("src").rglob("*.py")): #
module_path = path.relative_to("src").with_suffix("") #
doc_path = path.relative_to("src").with_suffix(".md") #
full_doc_path = Path("reference", doc_path) #

parts = list(module_path.parts)

if parts[-1] == "__init__": #
parts = parts[:-1]
elif parts[-1] == "__main__":
continue

nav[parts] = doc_path.as_posix()

with mkdocs_gen_files.open(full_doc_path, "w") as fd: #
identifier = ".".join(parts) #
print("::: " + identifier, file=fd) #

mkdocs_gen_files.set_edit_path(full_doc_path, path) #

with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: #
nav_file.writelines(nav.build_literate_nav())
Empty file added docs/gotchas.md
Empty file.
17 changes: 17 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Welcome to MkDocs

For full documentation visit [mkdocs.org](https://www.mkdocs.org).

## Commands

* `mkdocs new [dir-name]` - Create a new project.
* `mkdocs serve` - Start the live-reloading docs server.
* `mkdocs build` - Build the documentation site.
* `mkdocs -h` - Print help message and exit.

## Project layout

mkdocs.yml # The configuration file.
docs/
index.md # The documentation homepage.
... # Other markdown pages, images and other files.
Empty file added docs/jax.md
Empty file.
7 changes: 7 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mkdocs==1.4.3 # Main documentation generator.
mkdocs-material==9.1.18 # Theme
pymdown-extensions==10.1 # Markdown extensions e.g. to handle LaTeX.
mkdocstrings[python]==0.22.0 # Autogenerate documentation from docstrings.
mkdocs-jupyter==0.24.2 # Turn Jupyter Lab notebooks into webpages.
mkdocs-gen-files==0.5.0
mkdocs-literate-nav=0.6.0
Empty file.
17 changes: 0 additions & 17 deletions example/DataProcessing.py

This file was deleted.

80 changes: 80 additions & 0 deletions example/GW150914.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import time
from jimgw.jim import Jim
from jimgw.detector import H1, L1
from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.waveform import RippleIMRPhenomD
from jimgw.prior import Uniform
import jax.numpy as jnp
import jax

jax.config.update("jax_enable_x64", True)

###########################################
########## First we grab data #############
###########################################

total_time_start = time.time()

# first, fetch a 4s segment centered on GW150914
gps = 1126259462.4
start = gps - 2
end = gps + 2
fmin = 20.0
fmax = 1024.0

ifos = ["H1", "L1"]

H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)

prior = Uniform(
xmin=[10, 0.125, -1.0, -1.0, 0.0, -0.05, 0.0, -1, 0.0, 0.0, -1.0],
xmax=[80.0, 1.0, 1.0, 1.0, 2000.0, 0.05, 2 * jnp.pi, 1.0, jnp.pi, 2 * jnp.pi, 1.0],
naming=[
"M_c",
"q",
"s1_z",
"s2_z",
"d_L",
"t_c",
"phase_c",
"cos_iota",
"psi",
"ra",
"sin_dec",
],
transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2),
"cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)),
"sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec
)
likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2)
# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=[prior.xmin, prior.xmax], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2)


mass_matrix = jnp.eye(11)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}

jim = Jim(
likelihood,
prior,
n_loop_training=200,
n_loop_production=10,
n_local_steps=150,
n_global_steps=150,
n_chains=500,
n_epochs=50,
learning_rate=0.001,
max_samples=45000,
momentum=0.9,
batch_size=50000,
use_global=True,
keep_quantile=0.0,
train_thinning=1,
output_thinning=10,
local_sampler_arg=local_sampler_arg,
)

jim.maximize_likelihood([prior.xmin, prior.xmax])
jim.sample(jax.random.PRNGKey(42))
79 changes: 79 additions & 0 deletions example/GW150914_PV2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import time
from jimgw.jim import Jim
from jimgw.detector import H1, L1
from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD
from jimgw.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2
from jimgw.prior import Uniform
import jax.numpy as jnp
import jax

jax.config.update("jax_enable_x64", True)

###########################################
########## First we grab data #############
###########################################

total_time_start = time.time()

# first, fetch a 4s segment centered on GW150914
gps = 1126259462.4
start = gps - 2
end = gps + 2
fmin = 20.0
fmax = 1024.0

ifos = ["H1", "L1"]

H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)

waveform = RippleIMRPhenomPv2(f_ref=20)
prior = Uniform(
xmin = [10, 0.125, 0, 0, 0, 0, 0, 0, 0., -0.05, 0., -1, 0., 0.,-1.],
xmax = [80., 1., jnp.pi, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1., 2000., 0.05, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1.],
naming = ["M_c", "q", "s1_theta", "s1_phi", "s1_mag", "s2_theta", "s2_phi", "s2_mag", "d_L", "t_c", "phase_c", "cos_iota", "psi", "ra", "sin_dec"],
transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2),
"s1_theta": ("s1_x", lambda params: jnp.sin(params['s1_theta'])*jnp.cos(params['s1_phi'])*params['s1_mag']),
"s1_phi": ("s1_y", lambda params: jnp.sin(params['s1_theta'])*jnp.sin(params['s1_phi'])*params['s1_mag']),
"s1_mag": ("s1_z", lambda params: jnp.cos(params['s1_theta'])*params['s1_mag']),
"s2_theta": ("s2_x", lambda params: jnp.sin(params['s2_theta'])*jnp.cos(params['s2_phi'])*params['s2_mag']),
"s2_phi": ("s2_y", lambda params: jnp.sin(params['s2_theta'])*jnp.sin(params['s2_phi'])*params['s2_mag']),
"s2_mag": ("s2_z", lambda params: jnp.cos(params['s2_theta'])*params['s2_mag']),
"cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)),
"sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec
)
likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2)
# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=[prior.xmin, prior.xmax], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2)


mass_matrix = jnp.eye(prior.n_dim)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[9, 9].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}

jim = Jim(
likelihood,
prior,
n_loop_training=400,
n_loop_production=10,
n_local_steps=300,
n_global_steps=300,
n_chains=500,
n_epochs=300,
learning_rate=0.001,
max_samples = 60000,
momentum=0.9,
batch_size=30000,
use_global=True,
keep_quantile=0.,
train_thinning=1,
output_thinning=30,
local_sampler_arg=local_sampler_arg,
num_layers = 4,
hidden_size = [32,32],
num_bins = 8
)

jim.maximize_likelihood([prior.xmin, prior.xmax])
# initial_guess = jnp.array(jnp.load('initial.npz')['chain'])
jim.sample(jax.random.PRNGKey(42))
Loading

0 comments on commit 844fd3b

Please sign in to comment.