diff --git a/pyext/src/restraints/parameters.py b/pyext/src/restraints/parameters.py index c085d563..fed1d753 100644 --- a/pyext/src/restraints/parameters.py +++ b/pyext/src/restraints/parameters.py @@ -6,64 +6,30 @@ import IMP.container import IMP.isd import IMP.pmi.tools +import IMP.pmi.restraints -class WeightRestraint(object): - def __init__(self, weight, lower, upper, kappa): +class WeightRestraint(IMP.pmi.restraints.RestraintBase): - self.weight = weight - self.m = self.weight.get_model() - self.label = "None" - self.rs = IMP.RestraintSet(self.m, 'weight_restraint') + def __init__(self, w, lower, upper, kappa, label=None, weight=1.): + self.w = w + m = self.w.get_model() + super(WeightRestraint, self).__init__(m, label=label, weight=weight) self.lower = lower self.upper = upper self.kappa = kappa self.rs.add_restraint( IMP.isd.WeightRestraint( - self.weight, + self.w, self.lower, self.upper, self.kappa)) - def get_restraint(self, label): - return self.rs - def add_to_model(self): - IMP.pmi.tools.add_restraint_to_model(self.m, self.rs) +class JeffreysPrior(IMP.pmi.restraints.RestraintBase): - def set_label(self, label): - self.label = label - - def get_output(self): - self.m.update() - output = {} - score = self.rs.unprotected_evaluate(None) - output["_TotalScore"] = str(score) - output["WeightRestraint_" + self.label] = str(score) - return output - - -class JeffreysPrior(object): - - def __init__(self, nuisance): - - self.m = nuisance.get_model() - self.label = "None" - self.rs = IMP.RestraintSet(self.m, 'jeffrey_prior') + def __init__(self, nuisance, label=None, weight=1.): + m = nuisance.get_model() + super(JeffreysPrior, self).__init__(m, label=label, weight=weight) jp = IMP.isd.JeffreysRestraint(self.m, nuisance) self.rs.add_restraint(jp) - - def add_to_model(self): - IMP.pmi.tools.add_restraint_to_model(self.m, self.rs) - - def set_label(self, label): - self.label = label - - def get_output(self): - output = {} - score = self.rs.unprotected_evaluate(None) - output["_TotalScore"] = str(score) - output["JeffreyPrior_" + self.label] = str(score) - return output - -#