Skip to content

Commit

Permalink
more formatting fixes
Browse files Browse the repository at this point in the history
Signed-off-by: kta-intel <[email protected]>
  • Loading branch information
kta-intel committed Nov 15, 2024
1 parent 54cdc5e commit 51a0afa
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 69 deletions.
14 changes: 5 additions & 9 deletions openfl/federated/data/loader_xgb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from math import ceil

import numpy as np
import xgboost as xgb
from math import ceil


class XGBoostDataLoader:
"""A class used to represent a Data Loader for XGBoost models.
Expand Down Expand Up @@ -155,18 +157,12 @@ def get_train_dmatrix(self):
Returns:
xgb.DMatrix: The DMatrix object for the training data.
"""
return {
'dmatrix': self.get_dmatrix(self.X_train, self.y_train),
'labels': self.y_train
}
return {"dmatrix": self.get_dmatrix(self.X_train, self.y_train), "labels": self.y_train}

def get_valid_dmatrix(self):
    """Build the validation DMatrix and pair it with the validation labels.

    Returns:
        dict: ``"dmatrix"`` holds the result of ``self.get_dmatrix`` called
            with ``self.X_valid`` and ``self.y_valid``; ``"labels"`` holds
            ``self.y_valid`` itself.
    """
    validation_matrix = self.get_dmatrix(self.X_valid, self.y_valid)
    return {"dmatrix": validation_matrix, "labels": self.y_valid}
120 changes: 71 additions & 49 deletions openfl/federated/task/runner_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@
# from copy import deepcopy
# from typing import Iterator, Tuple

import numpy as np
import json

import numpy as np
import xgboost as xgb
from sklearn.metrics import accuracy_score

from openfl.federated.task.runner import TaskRunner
from openfl.utilities import Metric, TensorKey, change_tags
from openfl.utilities.split import split_tensor_dict_for_holdouts

import xgboost as xgb
import json
from sklearn.metrics import accuracy_score


class XGBoostTaskRunner(TaskRunner):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -46,8 +45,13 @@ def rebuild_model(self, input_tensor_dict):
Returns:
None
"""
if (isinstance(input_tensor_dict['local_tree'], np.ndarray) and input_tensor_dict['local_tree'].size != 0) \
or (not isinstance(input_tensor_dict['local_tree'], np.ndarray) and input_tensor_dict['local_tree'] is not None):
if (
isinstance(input_tensor_dict["local_tree"], np.ndarray)
and input_tensor_dict["local_tree"].size != 0
) or (
not isinstance(input_tensor_dict["local_tree"], np.ndarray)
and input_tensor_dict["local_tree"] is not None
):
self.set_tensor_dict(input_tensor_dict)

def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs):
Expand Down Expand Up @@ -179,50 +183,61 @@ def train_task(

# Return global_tensor_dict, local_tensor_dict
# import pdb; pdb.set_trace()
#TODO it is still decodable from here with .tobytes().decode('utf-8')
# TODO it is still decodable from here with .tobytes().decode('utf-8')
return global_tensor_dict, local_tensor_dict

def get_tensor_dict(self, with_opt_vars=False):
    """
    Retrieves the tensor dictionary containing the model's tree structure.

    The booster is serialized to JSON and the JSON bytes are stored one byte
    per element in a float32 numpy array under the key 'local_tree'.

    If the model has not been initialized (`self.bst` is None), it returns an
    empty numpy array. If the global model is not set or is empty, it returns
    the entire model. Otherwise, it returns only the trees added on top of the
    global model; when no trees were added, that is an empty JSON list.

    Parameters:
        with_opt_vars (bool): N/A for XGBoost (Default=False).

    Returns:
        dict: A dictionary with the key 'local_tree' containing the encoded
            model (or the encoded list of newly added trees) as a float32
            numpy array.
    """
    if self.bst is None:
        # For initializing tensor dict
        return {"local_tree": np.array([], dtype=np.float32)}

    booster_array = self.bst.save_raw("json")

    if self.global_model is None or (
        isinstance(self.global_model, np.ndarray) and self.global_model.size == 0
    ):
        # No global model yet: ship the whole booster. The raw bytes are used
        # directly, so the JSON does not need to be parsed on this path.
        booster_float32_array = np.frombuffer(booster_array, dtype=np.uint8).astype(np.float32)
        return {"local_tree": booster_float32_array}

    booster_model = json.loads(booster_array)["learner"]["gradient_booster"]["model"]
    global_model = json.loads(self.global_model)["learner"]["gradient_booster"]["model"]

    num_global_trees = int(global_model["gbtree_model_param"]["num_trees"])
    num_total_trees = int(booster_model["gbtree_model_param"]["num_trees"])

    # Calculate the number of trees added in the latest training
    num_latest_trees = num_total_trees - num_global_trees

    # Guard the slice: `trees[-0:]` is `trees[0:]`, i.e. *every* tree, so a
    # round that added nothing would otherwise resend the whole model as new.
    if num_latest_trees > 0:
        latest_trees = booster_model["trees"][-num_latest_trees:]
    else:
        latest_trees = []

    latest_trees_bytes = json.dumps(latest_trees).encode("utf-8")
    latest_trees_float32_array = np.frombuffer(latest_trees_bytes, dtype=np.uint8).astype(
        np.float32
    )

    return {"local_tree": latest_trees_float32_array}

def get_required_tensorkeys_for_function(self, func_name, **kwargs):
"""Get the required tensors for specified function that could be called
Expand Down Expand Up @@ -316,7 +331,7 @@ def set_tensor_dict(self, tensor_dict, with_opt_vars=False):
with_opt_vars (bool): N/A for XGBoost (Default=False).
"""
# The with_opt_vars argument is not used in this method
self.global_model = bytearray(tensor_dict['local_tree'].astype(np.uint8).tobytes())
self.global_model = bytearray(tensor_dict["local_tree"].astype(np.uint8).tobytes())
self.bst = xgb.Booster()
self.bst.load_model(self.global_model)

Expand All @@ -338,21 +353,28 @@ def save_native(

def train_(self, train_dataloader) -> Metric:
    """Run boosting on the training matrix and report the final training loss.

    Training continues from the current ``self.bst`` (passed to ``xgb.train``
    as ``xgb_model``), and the booster returned by ``xgb.train`` replaces it.

    Args:
        train_dataloader: Mapping whose ``"dmatrix"`` entry is the training
            matrix handed to ``xgb.train``.

    Returns:
        Metric: Named after ``self.loss_fn`` and carrying the last
            ``"logloss"`` value recorded for the ``"train"`` evaluation set.
    """
    training_matrix = train_dataloader["dmatrix"]
    history = {}
    train_options = {
        "xgb_model": self.bst,
        "evals": [(training_matrix, "train")],
        "evals_result": history,
        "verbose_eval": False,
    }

    self.bst = xgb.train(self.params, training_matrix, self.num_rounds, **train_options)

    # `history` was handed to xgb.train as `evals_result`; only the most
    # recent entry of the training logloss series is reported.
    final_loss = history["train"]["logloss"][-1]
    return Metric(name=self.loss_fn.__name__, value=np.array(final_loss))

def validate_(self, validation_dataloader) -> Metric:
"""Validate model."""

dtest = validation_dataloader['dmatrix']
y_test = validation_dataloader['labels']
dtest = validation_dataloader["dmatrix"]
y_test = validation_dataloader["labels"]
preds = self.bst.predict(dtest)
y_pred_binary = np.where(preds > 0.5, 1, 0)
acc = accuracy_score(y_test, y_pred_binary)
Expand Down
2 changes: 1 addition & 1 deletion openfl/interface/aggregation_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
)
from openfl.interface.aggregation_functions.adam_adaptive_aggregation import AdamAdaptiveAggregation
from openfl.interface.aggregation_functions.core import AggregationFunction
from openfl.interface.aggregation_functions.fed_bagging import FedBaggingXGBoost
from openfl.interface.aggregation_functions.fedcurv_weighted_average import FedCurvWeightedAverage
from openfl.interface.aggregation_functions.geometric_median import GeometricMedian
from openfl.interface.aggregation_functions.median import Median
from openfl.interface.aggregation_functions.weighted_average import WeightedAverage
from openfl.interface.aggregation_functions.yogi_adaptive_aggregation import YogiAdaptiveAggregation
from openfl.interface.aggregation_functions.fed_bagging import FedBaggingXGBoost
29 changes: 19 additions & 10 deletions openfl/interface/aggregation_functions/fed_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
"""Federated Boostrap Aggregation for XGBoost module."""

import json

import numpy as np

from openfl.interface.aggregation_functions.core import AggregationFunction


def get_global_model(iterator, target_round):
    """
    Retrieves the global model for the specific round from an iterator.

    Args:
        iterator: Iterable of dict-like records. A record is considered only
            when it has a "tags" entry equal to ('model',).
        target_round: Value the record's "round" entry must equal.

    Returns:
        The "nparray" entry of the first matching record.

    Raises:
        ValueError: If no record matches both the tag and the round.
    """
    model_tags = ("model",)
    not_found = object()  # private sentinel so a stored None is still returned

    # Items tagged with ('model',) are the global model of that round
    matching_arrays = (
        record["nparray"]
        for record in iterator
        if "tags" in record and record["tags"] == model_tags and record["round"] == target_round
    )
    result = next(matching_arrays, not_found)
    if result is not_found:
        raise ValueError(f"No item found with tag 'model' and round {target_round}")
    return result


Expand All @@ -37,7 +40,9 @@ def append_trees(global_model, local_trees):
Returns:
dict: The updated global model with the local trees appended.
"""
num_global_trees = int(global_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"])
num_global_trees = int(
global_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"]
)
num_local_trees = len(local_trees)

global_model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]["num_trees"] = str(
Expand All @@ -47,9 +52,9 @@ def append_trees(global_model, local_trees):
num_global_trees + num_local_trees
)
for new_tree in range(num_local_trees):
local_trees[new_tree]["id"] = num_global_trees + new_tree
global_model["learner"]["gradient_booster"]["model"]["trees"].append(local_trees[new_tree])
global_model["learner"]["gradient_booster"]["model"]["tree_info"].append(0)
local_trees[new_tree]["id"] = num_global_trees + new_tree
global_model["learner"]["gradient_booster"]["model"]["trees"].append(local_trees[new_tree])
global_model["learner"]["gradient_booster"]["model"]["tree_info"].append(0)

return global_model

Expand Down Expand Up @@ -93,18 +98,22 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):

global_model = get_global_model(db_iterator, fl_round)

if (isinstance(global_model, np.ndarray) and global_model.size == 0) or global_model is None:
if (
isinstance(global_model, np.ndarray) and global_model.size == 0
) or global_model is None:
for local_tensor in local_tensors:
local_tree_bytearray = bytearray(local_tensor.tensor.astype(np.uint8).tobytes())
local_tree_json = json.loads(local_tree_bytearray)

if (isinstance(global_model, np.ndarray) and global_model.size == 0) or global_model is None:
if (
isinstance(global_model, np.ndarray) and global_model.size == 0
) or global_model is None:
# the first tree becomes the global model
global_model = local_tree_json
else:
# append subsequent trees to global model
local_model = local_tree_json
local_trees = local_model['learner']['gradient_booster']['model']['trees']
local_trees = local_model["learner"]["gradient_booster"]["model"]["trees"]
global_model = append_trees(global_model, local_trees)
else:
global_model_bytearray = bytearray(global_model.astype(np.uint8).tobytes())
Expand All @@ -116,6 +125,6 @@ def call(self, local_tensors, db_iterator, tensor_name, fl_round, *_):
global_model = append_trees(global_model, local_trees)

global_model_json = json.dumps(global_model)
global_model_bytes = global_model_json.encode('utf-8')
global_model_bytes = global_model_json.encode("utf-8")

return np.frombuffer(global_model_bytes, dtype=np.uint8).astype(np.float32)

0 comments on commit 51a0afa

Please sign in to comment.