Skip to content

Commit

Permalink
Smaller network seems to improve quality quite a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Sep 16, 2023
1 parent 76d2cfb commit 4c111e8
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions example/InjectionRecovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,44 +20,48 @@ class InjectionRecoveryParser(Tap):

# Noise parameters
seed: int = 0
f_sampling: int = 2048
f_sampling: int = 4096
duration: int = 4
fmin: float = 20.0
ifos: list[str] = ["H1", "L1", "V1"]

# Injection parameters
m1: float = 30.0
m2: float = 25.0
s1_theta: float = 0.
s1_phi: float = 0.
s1_mag: float = 0.01
s2_theta: float = 0.
s2_phi: float = 0.
s2_mag: float = 0.02
s1_theta: float = 0.04
s1_phi: float = 0.02
s1_mag: float = 0.1
s2_theta: float = 0.01
s2_phi: float = 0.03
s2_mag: float = 0.05
dist_mpc: float = 400.
tc: float = 0.
phic: float = 0.
inclination: float = 0.3
phic: float = 0.1
inclination: float = 0.5
polarization_angle: float = 0.7
ra: float = 1.1
ra: float = 1.2
dec: float = 0.3

# Sampler parameters
n_dim: int = 15
n_chains: int = 500
n_loop_training: int = 10
n_loop_training: int = 200
n_loop_production: int = 10
n_local_steps: int = 300
n_global_steps: int = 300
learning_rate: float = 0.001
max_samples: int = 50000
max_samples: int = 60000
momentum: float = 0.9
num_epochs: int = 500
batch_size: int = 50000
num_epochs: int = 300
batch_size: int = 60000
stepsize: float = 0.01
use_global: bool = True
keep_quantile: float = 0.1
train_thinning: int = 10
keep_quantile: float = 0.0
train_thinning: int = 1
output_thinning: int = 30
num_layers: int = 6
hidden_size: list[int] = [64,64]
num_bins: int = 8

# Output parameters
output_path: str = "./"
Expand All @@ -75,7 +79,7 @@ class InjectionRecoveryParser(Tap):

print("Injection signals")

freqs = jnp.linspace(args.fmin, args.f_sampling/2, args.duration*args.f_sampling//2)
freqs = jnp.linspace(args.fmin, args.f_sampling/2, args.duration*args.f_sampling)

Mc, eta = ms_to_Mc_eta(jnp.array([args.m1, args.m2]))
f_ref = args.fmin
Expand Down Expand Up @@ -133,8 +137,12 @@ class InjectionRecoveryParser(Tap):
use_global=args.use_global,
keep_quantile= args.keep_quantile,
train_thinning = args.train_thinning,
output_thinning = args.output_thinning,
local_sampler_arg = local_sampler_arg,
seed = args.seed,
num_layers = args.num_layers,
hidden_size = args.hidden_size,
num_bins = args.num_bins
)

key, subkey = jax.random.split(key)
Expand Down

0 comments on commit 4c111e8

Please sign in to comment.