diff --git a/sklearn/fml/constants.py b/sklearn/fml/constants.py index 1aa1bd3e7639c..fb98faf6b8bad 100644 --- a/sklearn/fml/constants.py +++ b/sklearn/fml/constants.py @@ -1,5 +1,6 @@ class URI: - _SERVER = 'https://fmlearn.herokuapp.com/' + _SERVER = 'https://fmlearn.herokuapp.com' + _LOCAL = 'http://127.0.0.1:5000' _METRIC = '/metric' _RETRIEVE = '/retrieve' @@ -7,6 +8,10 @@ class URI: _MIN = '/min' _ALL = '/all' + def __init__(self, debug=False): + if debug: + self._SERVER = self._LOCAL + def post_metric(self): return self._SERVER + self._METRIC diff --git a/sklearn/fml/fml_client.py b/sklearn/fml/fml_client.py index 282acf686a8ff..44028a802ad01 100644 --- a/sklearn/fml/fml_client.py +++ b/sklearn/fml/fml_client.py @@ -22,7 +22,7 @@ def _post_msg(self, uri, data): print(res.status_code) return res.json() - def publish(self, model, metric_name, metric_value, dataset): + def publish(self, model, metric_name, metric_value, dataset, params=None): """ Publishes the data collected to the federated meta learning API """ @@ -35,6 +35,16 @@ def publish(self, model, metric_name, metric_value, dataset): data['metric_name'] = metric_name data['metric_value'] = metric_value data['dataset_hash'] = dataset_hash + if params != None: + model_params = [] + for key, value in params.items(): + new_param = {} + new_param['param_name'] = str(key) + new_param['param_value'] = str(value) + model_params.append(new_param) + data['params'] = model_params + else: + data['params'] = "" return self._post_msg(URI().post_metric(), data)