diff --git a/.flake8 b/.flake8 index e714a89..16b18d0 100644 --- a/.flake8 +++ b/.flake8 @@ -2,6 +2,8 @@ max-line-length = 120 max-complexity = 45 ignore = + C901, + # function is too complex E203, # missing whitespace around operator E225, diff --git a/enterprise_extensions/sampler.py b/enterprise_extensions/sampler.py index bce1ced..6fddf8d 100644 --- a/enterprise_extensions/sampler.py +++ b/enterprise_extensions/sampler.py @@ -166,6 +166,15 @@ def __init__(self, pta, snames=None, empirical_distr=None, f_stat_file=None, sav self.ndim = sum(p.size or 1 for p in pta.params) self.plist = [p.name for p in pta.params] + # parameter dictionary + self.params_dict = {} + for p in self.params: + if p.size: + for ii in range(0, p.size): + self.params_dict.update({p.name + "_{}".format(ii): p}) + else: + self.params_dict.update({p.name: p}) + # parameter map self.pmap = {} ct = 0 @@ -616,9 +625,7 @@ def draw_from_gwb_log_uniform_distribution(self, x, iter, beta): signal_name = [par for par in self.pnames if ('gw' in par and 'log10_A' in par)][0] - param_names = [par.name for par in self.params] - idx = list(param_names).index(signal_name) - param = self.params[idx] + param = self.params_dict[signal_name] q[self.pmap[str(param)]] = np.random.uniform(param.prior._defaults['pmin'], param.prior._defaults['pmax']) diff --git a/requirements.txt b/requirements.txt index 1e5167f..60263cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ scikit-sparse>=0.4.5 pint-pulsar>=0.8.2 libstempo>=2.4.0 enterprise-pulsar>=3.3.0 -scikit-learn==0.24 +scikit-learn==1.0.1 emcee ptmcmcsampler numdifftools