diff --git a/example/InjectionRecovery.py b/example/InjectionRecovery.py index 22ec3c20..a390b76d 100644 --- a/example/InjectionRecovery.py +++ b/example/InjectionRecovery.py @@ -20,7 +20,7 @@ class InjectionRecoveryParser(Tap): # Noise parameters seed: int = 0 - f_sampling: int = 2048 + f_sampling: int = 4096 duration: int = 4 fmin: float = 20.0 ifos: list[str] = ["H1", "L1", "V1"] @@ -28,36 +28,40 @@ class InjectionRecoveryParser(Tap): # Injection parameters m1: float = 30.0 m2: float = 25.0 - s1_theta: float = 0. - s1_phi: float = 0. - s1_mag: float = 0.01 - s2_theta: float = 0. - s2_phi: float = 0. - s2_mag: float = 0.02 + s1_theta: float = 0.04 + s1_phi: float = 0.02 + s1_mag: float = 0.1 + s2_theta: float = 0.01 + s2_phi: float = 0.03 + s2_mag: float = 0.05 dist_mpc: float = 400. tc: float = 0. - phic: float = 0. - inclination: float = 0.3 + phic: float = 0.1 + inclination: float = 0.5 polarization_angle: float = 0.7 - ra: float = 1.1 + ra: float = 1.2 dec: float = 0.3 # Sampler parameters n_dim: int = 15 n_chains: int = 500 - n_loop_training: int = 10 + n_loop_training: int = 200 n_loop_production: int = 10 n_local_steps: int = 300 n_global_steps: int = 300 learning_rate: float = 0.001 - max_samples: int = 50000 + max_samples: int = 60000 momentum: float = 0.9 - num_epochs: int = 500 - batch_size: int = 50000 + num_epochs: int = 300 + batch_size: int = 60000 stepsize: float = 0.01 use_global: bool = True - keep_quantile: float = 0.1 - train_thinning: int = 10 + keep_quantile: float = 0.0 + train_thinning: int = 1 + output_thinning: int = 30 + num_layers: int = 6 + hidden_size: list[int] = [64,64] + num_bins: int = 8 # Output parameters output_path: str = "./" @@ -75,7 +79,7 @@ class InjectionRecoveryParser(Tap): print("Injection signals") -freqs = jnp.linspace(args.fmin, args.f_sampling/2, args.duration*args.f_sampling//2) +freqs = jnp.linspace(args.fmin, args.f_sampling/2, args.duration*args.f_sampling) Mc, eta = ms_to_Mc_eta(jnp.array([args.m1, args.m2])) f_ref = args.fmin @@ -133,8 +137,12 @@ class InjectionRecoveryParser(Tap): use_global=args.use_global, keep_quantile= args.keep_quantile, train_thinning = args.train_thinning, + output_thinning = args.output_thinning, local_sampler_arg = local_sampler_arg, seed = args.seed, + num_layers = args.num_layers, + hidden_size = args.hidden_size, + num_bins = args.num_bins ) key, subkey = jax.random.split(key)