Skip to content

Commit

Permalink
Merge pull request #49 from aldro61/custom_tiebreakers
Browse files Browse the repository at this point in the history
Support for user-specified tiebreakers
  • Loading branch information
aldro61 authored Mar 29, 2018
2 parents e950695 + 0035ace commit 7367be2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
22 changes: 22 additions & 0 deletions examples/tiebreaker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
An example showing how to use a custom tiebreaker function.
"""
import numpy as np


from pyscm.scm import SetCoveringMachineClassifier
from sklearn.datasets import make_classification

# Size of the synthetic dataset: number of examples and number of features.
n_examples = 200
n_features = 1000

# Generate a two-class synthetic classification problem. The random state is
# seeded with a fixed value so the same data is produced on every run.
X,y = make_classification(n_samples=n_examples, n_features=n_features, n_classes=2,
random_state=np.random.RandomState(42))

def my_tiebreaker(model_type, feature_idx, thresholds, kind):
    """
    Example tiebreaker: report how many rules are tied and pick the first one.

    Parameters
    ----------
    model_type:
        The model type handed over by the classifier; only used in the
        printed message.
    feature_idx: sequence
        Feature indices of the equivalent rules. Its length is reported as the
        number of tied rules.
    thresholds: sequence
        Thresholds of the equivalent rules (unused in this example).
    kind: sequence
        Types of the equivalent rules (unused in this example).

    Returns
    -------
    int
        The index of the rule to keep among the equivalent rules. This example
        always returns 0, i.e. the first candidate.
    """
    print("Hello from the tiebreaker! Got {0:d} equivalent rules for this {1!s} model.".format(len(feature_idx), model_type))
    return 0

# Build a classifier with its default parameters and fit it on the synthetic
# data, passing the custom tiebreaker through the `tiebreaker` keyword of `fit`.
clf = SetCoveringMachineClassifier()
clf.fit(X, y, tiebreaker=my_tiebreaker)
19 changes: 15 additions & 4 deletions pyscm/scm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def set_params(self, **parameters):
setattr(self, parameter, value)
return self

def fit(self, X, y, iteration_callback=None, **fit_params):
def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params):
"""
Fit a SCM model.
Expand All @@ -59,6 +59,15 @@ def fit(self, X, y, iteration_callback=None, **fit_params):
The feature of the input examples.
y : array-like, shape = [n_samples]
The labels of the input examples.
tiebreaker: function(model_type, feature_idx, thresholds, rule_type)
A function that takes in the model type and information about the
equivalent rules and outputs the index of the rule to use. The lists
respectively contain the feature indices, thresholds and types
of the equivalent rules. If None, the rule that most
decreases the training error is selected. Note: the model type is
provided because the rules that are added to disjunction models
correspond to the inverse of the rules that are handled during
training. Handle this case with care.
iteration_callback: function(model)
A function that is called each time a rule is added to the model.
Expand Down Expand Up @@ -128,11 +137,13 @@ def fit(self, X, y, iteration_callback=None, **fit_params):
opti_P_bar = self._get_best_utility_rules(X, y, X_argsort_by_feature_T, remaining_example_idx.copy(),
**utility_function_additional_args)

# TODO: Support user specified tiebreaker
logging.debug("Tiebreaking. Found {0:d} optimal rules".format(len(opti_feat_idx)))
if len(opti_feat_idx) > 1:
trainig_risk_decrease = 1.0 * opti_N - opti_P_bar
keep_idx = np.where(trainig_risk_decrease == trainig_risk_decrease.max())[0][0]
if tiebreaker is None:
training_risk_decrease = 1.0 * opti_N - opti_P_bar
keep_idx = np.where(training_risk_decrease == training_risk_decrease.max())[0][0]
else:
keep_idx = tiebreaker(self.model_type, opti_feat_idx, opti_threshold, opti_kind)
else:
keep_idx = 0
stump = DecisionStump(feature_idx=opti_feat_idx[keep_idx], threshold=opti_threshold[keep_idx],
Expand Down

0 comments on commit 7367be2

Please sign in to comment.