Skip to content

Commit

Permalink
Merge branch 'main' into add_flare_client_context
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Jan 17, 2025
2 parents 001fe0f + e6bbd43 commit adf3c12
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions nvflare/app_common/widgets/validation_json_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

import json
import os.path
from functools import singledispatch

import numpy as np

from nvflare.apis.dxo import DataKind, from_shareable, get_leaf_dxos
from nvflare.apis.event_type import EventType
Expand All @@ -23,6 +26,17 @@
from nvflare.widgets.widget import Widget


@singledispatch
def to_serializable(val):
"""Default json serializable method."""
return str(val)


@to_serializable.register(np.float32)
def ts_float32(val):
return np.float64(val)


class ValidationJsonGenerator(Widget):
def __init__(self, results_dir=AppConstants.CROSS_VAL_DIR, json_file_name="cross_val_results.json"):
"""Catches VALIDATION_RESULT_RECEIVED event and generates a results.json containing accuracy of each
Expand Down Expand Up @@ -58,7 +72,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
if val_results:
try:
dxo = from_shareable(val_results)
dxo.validate()

if dxo.data_kind == DataKind.METRICS:
if data_client not in self._val_results:
Expand All @@ -71,7 +84,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
for err in errors:
self.log_error(fl_ctx, f"Bad result from {data_client}: {err}")
for _sub_data_client, _dxo in leaf_dxos.items():
_dxo.validate()
if _sub_data_client not in self._val_results:
self._val_results[_sub_data_client] = {}
self._val_results[_sub_data_client][model_owner] = _dxo.data
Expand All @@ -93,4 +105,4 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):

res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
with open(res_file_path, "w") as f:
json.dump(self._val_results, f)
json.dump(self._val_results, f, default=to_serializable)

0 comments on commit adf3c12

Please sign in to comment.