From cbd4124d8e2caf289959e6b5ba0a8b11a2527cfa Mon Sep 17 00:00:00 2001 From: Pariterre Date: Mon, 5 Aug 2024 14:24:40 -0400 Subject: [PATCH 01/13] Made possible to set log_level from outside in MULTIPROCESS_SERVER --- .../online_callback_multiprocess_server.py | 9 +- bioptim/gui/online_callback_server.py | 86 +++++++++++-------- bioptim/interfaces/interface_utils.py | 16 ++-- 3 files changed, 66 insertions(+), 45 deletions(-) diff --git a/bioptim/gui/online_callback_multiprocess_server.py b/bioptim/gui/online_callback_multiprocess_server.py index 72ff50209..49176b4cf 100644 --- a/bioptim/gui/online_callback_multiprocess_server.py +++ b/bioptim/gui/online_callback_multiprocess_server.py @@ -3,7 +3,7 @@ from .online_callback_server import PlottingServer, OnlineCallbackServer -def _start_as_multiprocess_internal(**kwargs): +def _start_server_internal(**kwargs): """ Starts the server (necessary for multiprocessing), this method should not be called directly, apart from run_as_multiprocess @@ -26,7 +26,12 @@ def __init__(self, *args, **kwargs): """ host = kwargs["host"] if "host" in kwargs else None port = kwargs["port"] if "port" in kwargs else None - process = Process(target=_start_as_multiprocess_internal, kwargs={"host": host, "port": port}) + log_level = None + if "log_level" in kwargs: + log_level = kwargs["log_level"] + del kwargs["log_level"] + + process = Process(target=_start_server_internal, kwargs={"host": host, "port": port, "log_level": log_level}) process.start() super(OnlineCallbackMultiprocessServer, self).__init__(*args, **kwargs) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 62804bb7d..6b52ea727 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -27,18 +27,6 @@ def _deserialize_show_options(show_options: bytes) -> dict: return json.loads(show_options.decode()) -def _start_as_multiprocess_internal(*args, **kwargs): - """ - Starts the server (necessary for multiprocessing), this method should not be called directly, apart from - run_as_multiprocess - - Parameters - ---------- - same as PlottingServer - """ - PlottingServer(*args, **kwargs) - - class _ServerMessages(IntEnum): INITIATE_CONNEXION = auto() NEW_DATA = auto() @@ -49,7 +37,7 @@ class _ServerMessages(IntEnum): class PlottingServer: - def __init__(self, host: str = None, port: int = None): + def __init__(self, host: str = None, port: int = None, log_level: int = logging.INFO): """ Initializes the server @@ -59,9 +47,11 @@ def __init__(self, host: str = None, port: int = None): The host to listen to, by default "localhost" port: int The port to listen to, by default 3050 + log_level: int + The log level (see logging), by default logging.INFO """ - self._prepare_logger() + self._prepare_logger(log_level) self._get_data_interval = 1.0 self._update_plot_interval = 0.01 self._is_drawing = False @@ -74,9 +64,14 @@ def __init__(self, host: str = None, port: int = None): self._run() - def _prepare_logger(self) -> None: + def _prepare_logger(self, log_level: int) -> None: """ Prepares the logger + + Parameters + ---------- + log_level: int + The log level """ name = "PlottingServer" @@ -90,7 +85,7 @@ def _prepare_logger(self) -> None: self._logger = logging.getLogger(name) self._logger.addHandler(console_handler) - self._logger.setLevel(logging.INFO) + self._logger.setLevel(log_level) def _run(self) -> None: """ @@ -108,7 +103,10 @@ def _run(self) -> None: self._logger.info(f"Connection from {addr}") self._wait_for_new_connexion(client_socket) except Exception as e: - self._logger.error(f"Error while running the server: {e}") + self._logger.error( + f"Fatal error while running the server" + f"{"" if self._logger.level == logging.DEBUG else ", for more information set log_level to DEBUG"}") + self._logger.debug(f"Error: {e}") finally: self._socket.close() @@ -176,7 +174,7 @@ def _recv_message_type_and_data_len( if not data: return _ServerMessages.EMPTY, None except: - self._logger.warning("Client closed connexion") + self._logger.error("Client closed connexion") client_socket.close() return _ServerMessages.CLOSE_CONNEXION, None @@ -184,7 +182,7 @@ def _recv_message_type_and_data_len( try: message_type = _ServerMessages(int(data_as_list[0])) except ValueError: - self._logger.warning("Unknown message type received") + self._logger.error("Unknown message type received") # Sends failure if send_confirmation: client_socket.sendall("NOK".encode()) @@ -197,8 +195,9 @@ def _recv_message_type_and_data_len( try: len_all_data = [int(len_data) for len_data in data_as_list[1][1:-1].split(",")] - except ValueError: - self._logger.warning("Length of data could not be extracted") + except Exception as e: + self._logger.error("Length of data could not be extracted") + self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: client_socket.sendall("NOK".encode()) @@ -236,8 +235,9 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: data_out.append(client_socket.recv(len_data)) if len(data_out[-1]) != len_data: data_out[-1] += client_socket.recv(len_data - len(data_out[-1])) - except: - self._logger.warning("Unknown message type received") + except Exception as e: + self._logger.error("Unknown message type received") + self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: client_socket.sendall("NOK".encode()) @@ -267,24 +267,31 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No for phase_times in data_json["dummy_phase_times"]: dummy_time_vector.append([DM(v) for v in phase_times]) del data_json["dummy_phase_times"] - except: - self._logger.warning("Error while extracting dummy time vector from OCP data, closing connexion") - return + except Exception as e: + self._logger.error("Error while extracting dummy time vector from OCP data, closing connexion") + client_socket.sendall("NOK".encode()) + raise e try: self.ocp = OcpSerializable.deserialize(data_json) - except: - client_socket.sendall("FAILED".encode()) - self._logger.warning("Error while deserializing OCP data from client, closing connexion") - return + except Exception as e: + self._logger.error("Error while deserializing OCP data from client, closing connexion") + client_socket.sendall("NOK".encode()) + raise e try: show_options = _deserialize_show_options(ocp_raw[1]) - except: - self._logger.warning("Error while extracting show options, closing connexion") - return + except Exception as e: + self._logger.error("Error while extracting show options, closing connexion") + client_socket.sendall("NOK".encode()) + raise e - self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) + try: + self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) + except Exception as e: + self._logger.error("Error while initializing the plotter, closing connexion") + client_socket.sendall("NOK".encode()) + raise e # Send the confirmation to the client client_socket.sendall("PLOT_READY".encode()) @@ -341,8 +348,10 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: # Give it some time time.sleep(self._update_plot_interval) client_socket.sendall("READY_FOR_NEXT_DATA".encode()) - except: - self._logger.warning("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") + except Exception as e: + self._logger.error("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") + self._logger.debug(f"Error: {e}") + client_socket.close() return should_continue = False @@ -351,8 +360,9 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: try: self._update_plot(data) should_continue = True - except: - self._logger.warning("Error while updating data from client, closing connexion") + except Exception as e: + self._logger.error("Error while updating data from client, closing connexion") + self._logger.debug(f"Error: {e}") client_socket.close() return diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index a0aa07afa..c273e1b83 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -60,15 +60,21 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): del show_options["port"] if online_optim == OnlineOptim.SERVER: - class_to_instantiate = OnlineCallbackServer + interface.options_common["iteration_callback"] = OnlineCallbackServer( + ocp, show_options=show_options, host=host, port=port + ) elif online_optim == OnlineOptim.MULTIPROCESS_SERVER: - class_to_instantiate = OnlineCallbackMultiprocessServer + log_level = None + if "log_level" in show_options: + log_level = show_options["log_level"] + del show_options["log_level"] + + interface.options_common["iteration_callback"] = OnlineCallbackMultiprocessServer( + ocp, show_options=show_options, host=host, port=port, log_level=log_level + ) else: raise NotImplementedError(f"show_options['type']={online_optim} is not implemented yet") - interface.options_common["iteration_callback"] = class_to_instantiate( - ocp, show_options=show_options, host=host, port=port - ) else: raise NotImplementedError(f"show_options['type']={online_optim} is not implemented yet") From 2905d3d5beef26862599208e1079701d39260138 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Mon, 5 Aug 2024 14:25:58 -0400 Subject: [PATCH 02/13] Removed the necessity to call "function" in plot to simplify serialization --- bioptim/__init__.py | 1 + bioptim/gui/plot.py | 48 +---- bioptim/gui/serializable_class.py | 194 +++++++----------- .../optimization/optimal_control_program.py | 67 ++++++ 4 files changed, 145 insertions(+), 165 deletions(-) diff --git a/bioptim/__init__.py b/bioptim/__init__.py index d07d8da3c..921bac294 100644 --- a/bioptim/__init__.py +++ b/bioptim/__init__.py @@ -231,4 +231,5 @@ from .optimization.problem_type import SocpType from .misc.casadi_expand import lt, le, gt, ge, if_else, if_else_zero +from .gui.plot import CustomPlot from .gui.online_callback_server import PlottingServer diff --git a/bioptim/gui/plot.py b/bioptim/gui/plot.py index 10bcfc223..55866e82c 100644 --- a/bioptim/gui/plot.py +++ b/bioptim/gui/plot.py @@ -320,7 +320,10 @@ def legend_without_duplicate_labels(ax): ax.legend(*zip(*unique)) variable_sizes = [] + + self.ocp.finalize_plot_phase_mappings() for i, nlp in enumerate(self.ocp.nlp): + variable_sizes.append({}) if nlp.plot: for key in nlp.plot: @@ -331,51 +334,6 @@ def legend_without_duplicate_labels(ax): if nlp.plot[key].node_idx is None: nlp.plot[key].node_idx = range(nlp.n_states_nodes) - # If the number of subplots is not known, compute it - if nlp.plot[key].phase_mappings is None: - node_index = nlp.plot[key].node_idx[0] - nlp.states.node_index = node_index - nlp.states_dot.node_index = node_index - nlp.controls.node_index = node_index - nlp.algebraic_states.node_index = node_index - - # If multi-node penalties = None, stays zero - size_x = nlp.states.shape - size_u = nlp.controls.shape - size_p = nlp.parameters.shape - size_a = nlp.algebraic_states.shape - size_d = nlp.numerical_timeseries.shape - if "penalty" in nlp.plot[key].parameters: - penalty = nlp.plot[key].parameters["penalty"] - - # As stated in penalty_option, the last controller is always supposed to be the right one - casadi_function = ( - penalty.function[0] if penalty.function[0] is not None else penalty.function[-1] - ) - if casadi_function is not None: - size_x = casadi_function.size_in("x")[0] - size_u = casadi_function.size_in("u")[0] - size_p = casadi_function.size_in("p")[0] - size_a = casadi_function.size_in("a")[0] - size_d = casadi_function.size_in("d")[0] - - size = ( - nlp.plot[key] - .function( - 0, # t0 - np.zeros(len(self.ocp.nlp)), # phases_dt - node_index, # node_idx - np.zeros((size_x, 1)), # states - np.zeros((size_u, 1)), # controls - np.zeros((size_p, 1)), # parameters - np.zeros((size_a, 1)), # algebraic_states - np.zeros((size_d, 1)), # numerical_timeseries - **nlp.plot[key].parameters, # parameters - ) - .shape[0] - ) - nlp.plot[key].phase_mappings = BiMapping(to_first=range(size), to_second=range(size)) - n_subplots = max(nlp.plot[key].phase_mappings.to_second.map_idx) + 1 if key not in variable_sizes[i]: diff --git a/bioptim/gui/serializable_class.py b/bioptim/gui/serializable_class.py index f53391500..e5f780355 100644 --- a/bioptim/gui/serializable_class.py +++ b/bioptim/gui/serializable_class.py @@ -1,6 +1,6 @@ -from typing import Any, Callable +from typing import Any -from casadi import DM, Function +from casadi import Function import numpy as np from ..dynamics.ode_solver import OdeSolver @@ -96,7 +96,7 @@ def from_mapping(cls, mapping): def serialize(self): return { - "map_idx": self.map_idx, + "map_idx": list(self.map_idx), "oppose": self.oppose, } @@ -194,7 +194,6 @@ def max(self): class CustomPlotSerializable: - _function: Callable type: PlotType phase_mappings: BiMappingSerializable legend: tuple | list @@ -207,12 +206,10 @@ class CustomPlotSerializable: label: list compute_derivative: bool integration_rule: QuadratureRule - parameters: dict[str, Any] all_variables_in_one_subplot: bool def __init__( self, - function: Callable, plot_type: PlotType, phase_mappings: BiMapping, legend: tuple | list, @@ -225,10 +222,8 @@ def __init__( label: list, compute_derivative: bool, integration_rule: QuadratureRule, - parameters: dict[str, Any], all_variables_in_one_subplot: bool, ): - self._function = function self.type = plot_type self.phase_mappings = phase_mappings self.legend = legend @@ -241,7 +236,6 @@ def __init__( self.label = label self.compute_derivative = compute_derivative self.integration_rule = integration_rule - self.parameters = parameters self.all_variables_in_one_subplot = all_variables_in_one_subplot @classmethod @@ -250,46 +244,9 @@ def from_custom_plot(cls, custom_plot): custom_plot: CustomPlot = custom_plot - _function = None - parameters = {} - for key in custom_plot.parameters.keys(): - if key == "penalty": - # This is a hack to emulate what PlotOcp._create_plots needs while not being able to actually serialize - # the function - parameters[key] = PenaltySerializable.from_penalty(custom_plot.parameters[key]) - - penalty = custom_plot.parameters[key] - - casadi_function = penalty.function[0] if penalty.function[0] is not None else penalty.function[-1] - size_x = casadi_function.size_in("x")[0] - size_dt = casadi_function.size_in("dt")[0] - size_u = casadi_function.size_in("u")[0] - size_p = casadi_function.size_in("p")[0] - size_a = casadi_function.size_in("a")[0] - size_d = casadi_function.size_in("d")[0] - _function = custom_plot.function( - 0, # t0 - np.zeros(size_dt), # phases_dt - custom_plot.node_idx[0], # node_idx - np.zeros((size_x, 1)), # states - np.zeros((size_u, 1)), # controls - np.zeros((size_p, 1)), # parameters - np.zeros((size_a, 1)), # algebraic_states - np.zeros((size_d, 1)), # numerical_timeseries - **custom_plot.parameters, # parameters - ) - - else: - raise NotImplementedError(f"Parameter {key} is not implemented in the serialization") - return cls( - function=_function, plot_type=custom_plot.type, - phase_mappings=( - None - if custom_plot.phase_mappings is None - else BiMappingSerializable.from_bimapping(custom_plot.phase_mappings) - ), + phase_mappings=BiMappingSerializable.from_bimapping(custom_plot.phase_mappings), legend=custom_plot.legend, combine_to=custom_plot.combine_to, color=custom_plot.color, @@ -300,45 +257,31 @@ def from_custom_plot(cls, custom_plot): label=custom_plot.label, compute_derivative=custom_plot.compute_derivative, integration_rule=custom_plot.integration_rule, - parameters=parameters, all_variables_in_one_subplot=custom_plot.all_variables_in_one_subplot, ) def serialize(self): return { - "function": None if self._function is None else np.array(self._function)[:, 0].tolist(), "type": self.type.value, - "phase_mappings": None if self.phase_mappings is None else self.phase_mappings.serialize(), + "phase_mappings": self.phase_mappings.serialize(), "legend": self.legend, "combine_to": self.combine_to, "color": self.color, "linestyle": self.linestyle, "ylim": self.ylim, "bounds": None if self.bounds is None else self.bounds.serialize(), - "node_idx": self.node_idx, + "node_idx": list(self.node_idx), "label": self.label, "compute_derivative": self.compute_derivative, "integration_rule": self.integration_rule.value, - "parameters": {key: param.serialize() for key, param in self.parameters.items()}, "all_variables_in_one_subplot": self.all_variables_in_one_subplot, } @classmethod def deserialize(cls, data): - - parameters = {} - for key in data["parameters"].keys(): - if key == "penalty": - parameters[key] = PenaltySerializable.deserialize(data["parameters"][key]) - else: - raise NotImplementedError(f"Parameter {key} is not implemented in the serialization") - return cls( - function=None if data["function"] is None else DM(data["function"]), plot_type=PlotType(data["type"]), - phase_mappings=( - None if data["phase_mappings"] is None else BiMappingSerializable.deserialize(data["phase_mappings"]) - ), + phase_mappings=BiMappingSerializable.deserialize(data["phase_mappings"]), legend=data["legend"], combine_to=data["combine_to"], color=data["color"], @@ -349,7 +292,6 @@ def deserialize(cls, data): label=data["label"], compute_derivative=data["compute_derivative"], integration_rule=QuadratureRule(data["integration_rule"]), - parameters=parameters, all_variables_in_one_subplot=data["all_variables_in_one_subplot"], ) @@ -438,6 +380,59 @@ def deserialize(cls, data): ) +class SaveIterationsInfoSerializable: + path_to_results: str + result_file_name: str | list[str] + nb_iter_save: int + current_iter: int + f_list: list[int] + + def __init__( + self, path_to_results: str, result_file_name: str, nb_iter_save: int, current_iter: int, f_list: list[int] + ): + self.path_to_results = path_to_results + self.result_file_name = result_file_name + self.nb_iter_save = nb_iter_save + self.current_iter = current_iter + self.f_list = f_list + + @classmethod + def from_save_iterations_info(cls, save_iterations_info): + from .ipopt_output_plot import SaveIterationsInfo + + save_iterations_info: SaveIterationsInfo = save_iterations_info + + if save_iterations_info is None: + return None + + return cls( + path_to_results=save_iterations_info.path_to_results, + result_file_name=save_iterations_info.result_file_name, + nb_iter_save=save_iterations_info.nb_iter_save, + current_iter=save_iterations_info.current_iter, + f_list=save_iterations_info.f_list, + ) + + def serialize(self): + return { + "path_to_results": self.path_to_results, + "result_file_name": self.result_file_name, + "nb_iter_save": self.nb_iter_save, + "current_iter": self.current_iter, + "f_list": self.f_list, + } + + @classmethod + def deserialize(cls, data): + return cls( + path_to_results=data["path_to_results"], + result_file_name=data["result_file_name"], + nb_iter_save=data["nb_iter_save"], + current_iter=data["current_iter"], + f_list=data["f_list"], + ) + + class NlpSerializable: ns: int phase_idx: int @@ -483,8 +478,6 @@ def __init__( def from_nlp(cls, nlp): from ..optimization.non_linear_program import NonLinearProgram - nlp: NonLinearProgram = nlp - return cls( ns=nlp.ns, phase_idx=nlp.phase_idx, @@ -531,59 +524,6 @@ def deserialize(cls, data): ) -class SaveIterationsInfoSerializable: - path_to_results: str - result_file_name: str | list[str] - nb_iter_save: int - current_iter: int - f_list: list[int] - - def __init__( - self, path_to_results: str, result_file_name: str, nb_iter_save: int, current_iter: int, f_list: list[int] - ): - self.path_to_results = path_to_results - self.result_file_name = result_file_name - self.nb_iter_save = nb_iter_save - self.current_iter = current_iter - self.f_list = f_list - - @classmethod - def from_save_iterations_info(cls, save_iterations_info): - from .ipopt_output_plot import SaveIterationsInfo - - save_iterations_info: SaveIterationsInfo = save_iterations_info - - if save_iterations_info is None: - return None - - return cls( - path_to_results=save_iterations_info.path_to_results, - result_file_name=save_iterations_info.result_file_name, - nb_iter_save=save_iterations_info.nb_iter_save, - current_iter=save_iterations_info.current_iter, - f_list=save_iterations_info.f_list, - ) - - def serialize(self): - return { - "path_to_results": self.path_to_results, - "result_file_name": self.result_file_name, - "nb_iter_save": self.nb_iter_save, - "current_iter": self.current_iter, - "f_list": self.f_list, - } - - @classmethod - def deserialize(cls, data): - return cls( - path_to_results=data["path_to_results"], - result_file_name=data["result_file_name"], - nb_iter_save=data["nb_iter_save"], - current_iter=data["current_iter"], - f_list=data["f_list"], - ) - - class OcpSerializable: n_phases: int nlp: list[NlpSerializable] @@ -617,6 +557,7 @@ def from_ocp(cls, ocp): from ..optimization.optimal_control_program import OptimalControlProgram ocp: OptimalControlProgram = ocp + ocp.finalize_plot_phase_mappings() return cls( n_phases=ocp.n_phases, @@ -655,3 +596,16 @@ def deserialize(cls, data): else SaveIterationsInfoSerializable.deserialize(data["save_ipopt_iterations_info"]) ), ) + + def finalize_plot_phase_mappings(self): + """ + This method can't be actually called from the serialized version, but we still can check if the work is done + """ + + for nlp in self.nlp: + if not nlp.plot: + continue + + for key in nlp.plot: + if nlp.plot[key].phase_mappings is None: + raise RuntimeError("The phase mapping should be set on client side") diff --git a/bioptim/optimization/optimal_control_program.py b/bioptim/optimization/optimal_control_program.py index b72acb03d..b1a006d9f 100644 --- a/bioptim/optimization/optimal_control_program.py +++ b/bioptim/optimization/optimal_control_program.py @@ -689,6 +689,73 @@ def _finalize_penalties( self.update_parameter_objectives(parameter_objectives) return + def finalize_plot_phase_mappings(self): + """ + Finalize the plot phase mappings (if not already done) + + Parameters + ---------- + n_phases: int + The number of phases + """ + + for nlp in self.nlp: + if not nlp.plot: + return + + for key in nlp.plot: + if isinstance(nlp.plot[key], tuple): + nlp.plot[key] = nlp.plot[key][0] + + # This is the point where we can safely define node_idx of the plot + if nlp.plot[key].node_idx is None: + nlp.plot[key].node_idx = range(nlp.n_states_nodes) + + # If the number of subplots is not known, compute it + if nlp.plot[key].phase_mappings is None: + node_index = nlp.plot[key].node_idx[0] + nlp.states.node_index = node_index + nlp.states_dot.node_index = node_index + nlp.controls.node_index = node_index + nlp.algebraic_states.node_index = node_index + + # If multi-node penalties = None, stays zero + size_x = nlp.states.shape + size_u = nlp.controls.shape + size_p = nlp.parameters.shape + size_a = nlp.algebraic_states.shape + size_d = nlp.numerical_timeseries.shape + if "penalty" in nlp.plot[key].parameters: + penalty = nlp.plot[key].parameters["penalty"] + + # As stated in penalty_option, the last controller is always supposed to be the right one + casadi_function = ( + penalty.function[0] if penalty.function[0] is not None else penalty.function[-1] + ) + if casadi_function is not None: + size_x = casadi_function.size_in("x")[0] + size_u = casadi_function.size_in("u")[0] + size_p = casadi_function.size_in("p")[0] + size_a = casadi_function.size_in("a")[0] + size_d = casadi_function.size_in("d")[0] + + size = ( + nlp.plot[key] + .function( + 0, # t0 + np.zeros(self.n_phases), # phases_dt + node_index, # node_idx + np.zeros((size_x, 1)), # states + np.zeros((size_u, 1)), # controls + np.zeros((size_p, 1)), # parameters + np.zeros((size_a, 1)), # algebraic_states + np.zeros((size_d, 1)), # numerical_timeseries + **nlp.plot[key].parameters, # parameters + ) + .shape[0] + ) + nlp.plot[key].phase_mappings = BiMapping(to_first=range(size), to_second=range(size)) + @property def variables_vector(self): return OptimizationVectorHelper.vector(self) From 4115740c7d06da5b4daac368d440b0a7876991e4 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Mon, 5 Aug 2024 15:19:38 -0400 Subject: [PATCH 03/13] Fixed default log_level --- bioptim/__init__.py | 2 +- bioptim/gui/online_callback_server.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/bioptim/__init__.py b/bioptim/__init__.py index 921bac294..99d9408cd 100644 --- a/bioptim/__init__.py +++ b/bioptim/__init__.py @@ -187,7 +187,7 @@ from .limits.multinode_constraint import MultinodeConstraintFcn, MultinodeConstraintList, MultinodeConstraint from .limits.multinode_objective import MultinodeObjectiveFcn, MultinodeObjectiveList, MultinodeObjective from .limits.objective_functions import ObjectiveFcn, ObjectiveList, Objective, ParameterObjectiveList -from .limits.path_conditions import BoundsList, InitialGuessList +from .limits.path_conditions import BoundsList, InitialGuessList, Bounds, InitialGuess from .limits.fatigue_path_conditions import FatigueBounds, FatigueInitialGuess from .limits.penalty_controller import PenaltyController from .limits.penalty_helpers import PenaltyHelpers diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 6b52ea727..020cad2a3 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -37,7 +37,7 @@ class _ServerMessages(IntEnum): class PlottingServer: - def __init__(self, host: str = None, port: int = None, log_level: int = logging.INFO): + def __init__(self, host: str = None, port: int = None, log_level: int | None = logging.INFO): """ Initializes the server @@ -51,6 +51,9 @@ def __init__(self, host: str = None, port: int = None, log_level: int = logging. The log level (see logging), by default logging.INFO """ + if log_level is None: + log_level = logging.INFO + self._prepare_logger(log_level) self._get_data_interval = 1.0 self._update_plot_interval = 0.01 From 78c51eefd3f930cb01a143a8c2682d1e74a465c0 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Mon, 5 Aug 2024 16:52:38 -0400 Subject: [PATCH 04/13] Made graphs work on Macos --- bioptim/examples/getting_started/pendulum.py | 3 +- bioptim/gui/online_callback_server.py | 53 +++++++++++++++----- resources/plotting_server.py | 11 ++-- 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/bioptim/examples/getting_started/pendulum.py b/bioptim/examples/getting_started/pendulum.py index 55fa6a782..674287a7e 100644 --- a/bioptim/examples/getting_started/pendulum.py +++ b/bioptim/examples/getting_started/pendulum.py @@ -150,7 +150,8 @@ def main(): # --- Solve the ocp --- # # Default is OnlineOptim.MULTIPROCESS on Linux, OnlineOptim.MULTIPROCESS_SERVER on Windows and None on MacOS - sol = ocp.solve(Solver.IPOPT(show_online_optim=OnlineOptim.DEFAULT)) + # To see the graphs on MacOS, one must run the server manually (see ressources/plotting_server.py) + sol = ocp.solve(Solver.IPOPT(online_optim=OnlineOptim.DEFAULT)) # --- Show the results graph --- # sol.print_cost() diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 020cad2a3..ea7f85ef9 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -1,6 +1,7 @@ from enum import IntEnum, auto import json import logging +import platform import socket import struct import time @@ -56,7 +57,7 @@ def __init__(self, host: str = None, port: int = None, log_level: int | None = l self._prepare_logger(log_level) self._get_data_interval = 1.0 - self._update_plot_interval = 0.01 + self._update_plot_interval = 100 self._is_drawing = False # Define the host and port @@ -65,6 +66,8 @@ def __init__(self, host: str = None, port: int = None, log_level: int | None = l self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._plotter: PlotOcp = None + self._should_send_ok_to_client_on_new_data = False + self._run() def _prepare_logger(self, log_level: int) -> None: @@ -103,7 +106,7 @@ def _run(self) -> None: while True: self._logger.info("Waiting for a new connexion") client_socket, addr = self._socket.accept() - self._logger.info(f"Connection from {addr}") + self._logger.info(f"Connexion from {addr}") self._wait_for_new_connexion(client_socket) except Exception as e: self._logger.error( @@ -177,7 +180,7 @@ def _recv_message_type_and_data_len( if not data: return _ServerMessages.EMPTY, None except: - self._logger.error("Client closed connexion") + self._logger.info("Client closed connexion") client_socket.close() return _ServerMessages.CLOSE_CONNEXION, None @@ -266,6 +269,19 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No try: data_json = json.loads(ocp_raw[0]) + except Exception as e: + self._logger.error("Error while converting data to json format, closing connexion") + client_socket.sendall("NOK".encode()) + raise e + + try: + self._should_send_ok_to_client_on_new_data = data_json["request_confirmation_on_new_data"] + except Exception as e: + self._logger.error("Did not receive if confirmation should be sent, closing connexion") + client_socket.sendall("NOK".encode()) + raise e + + try: dummy_time_vector = [] for phase_times in data_json["dummy_phase_times"]: dummy_time_vector.append([DM(v) for v in phase_times]) @@ -301,7 +317,12 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No # Start the callbacks threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)).start() - threading.Timer(self._update_plot_interval, self._redraw).start() + + # Use the canvas timer for _redraw as threading won't work for updating the graphs on Macos + timer = self._plotter.all_figures[0].canvas.new_timer(self._update_plot_interval) + timer.add_callback(self._redraw) + timer.start() + plt.show() @property @@ -323,16 +344,12 @@ def _redraw(self) -> None: self._logger.debug("Updating plot") self._is_drawing = True - for _, fig in enumerate(self._plotter.all_figures): + for fig in self._plotter.all_figures: fig.canvas.draw() - fig.canvas.flush_events() + if platform.system() != "Darwin": + fig.canvas.flush_events() self._is_drawing = False - if self.has_at_least_one_active_figure: - threading.Timer(self._update_plot_interval, self._redraw).start() - else: - self._logger.info("All figures have been closed, stop updating the plots") - def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ Waits for new data from the client, sends a "READY_FOR_NEXT_DATA" message to the client to signal that the server @@ -347,7 +364,7 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: self._logger.debug(f"Waiting for new data from client") try: - if self._is_drawing: + if self._is_drawing and platform.system() != "Darwin" and not self._wait_for_new_data_to_plot: # Give it some time time.sleep(self._update_plot_interval) client_socket.sendall("READY_FOR_NEXT_DATA".encode()) @@ -358,7 +375,9 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: return should_continue = False - message_type, data = self._recv_data(client_socket=client_socket, send_confirmation=False) + message_type, data = self._recv_data( + client_socket=client_socket, send_confirmation=self._should_send_ok_to_client_on_new_data + ) if message_type == _ServerMessages.NEW_DATA: try: self._update_plot(data) @@ -417,6 +436,8 @@ def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str self._port = port if port else _DEFAULT_PORT self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._should_wait_ok_to_client_on_new_data = platform.system() == "Darwin" + if self.ocp.plot_ipopt_outputs: raise NotImplementedError("The online callback with TCP does not support the plot_ipopt_outputs option") if self.ocp.save_ipopt_iterations_info: @@ -462,6 +483,8 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: ocp_plot["dummy_phase_times"] = [] for phase_times in dummy_phase_times: ocp_plot["dummy_phase_times"].append([np.array(v)[:, 0].tolist() for v in phase_times]) + + ocp_plot["request_confirmation_on_new_data"] = self._should_wait_ok_to_client_on_new_data serialized_ocp = json.dumps(ocp_plot).encode() serialized_show_options = _serialize_show_options(show_options) @@ -533,9 +556,13 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: self._socket.sendall(f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode()) # If send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) + if self._socket.recv(1024).decode() != "OK": + raise RuntimeError("The server did not acknowledge the connexion") self._socket.sendall(header) self._socket.sendall(data_serialized) # Again, if send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) + if self._socket.recv(1024).decode() != "OK": + raise RuntimeError("The server did not acknowledge the connexion") return [0] diff --git a/resources/plotting_server.py b/resources/plotting_server.py index 3ee6eb2da..b709a0cf8 100644 --- a/resources/plotting_server.py +++ b/resources/plotting_server.py @@ -1,11 +1,14 @@ """ -This file is an example of how to run a bioptim Online plotting server. That said, this is usually not the way to run a -bioptim server as it is easier to run it as an automatic multiprocess (default). This is achieved by setting -`show_options={"type": OnlineOptim.SERVER, "as_multiprocess": True}` in the solver options. -If set to False, then the plotting server is mandatory. +This file is an example of how to run a bioptim Online plotting server. Apart on Macos, this is usually not the way to run a +bioptim server as it is easier to run it as an automatic multiprocess. This is achieved by setting +`Solver.IPOPT(online_optim=OnlineOptim.MULTIPROCESS_SERVER)`. +If set to OnlineOptim.SERVER, then the plotting server is mandatory. Since the server runs usings sockets, it is possible to run the server on a different machine than the one running the optimization. This is useful when the optimization is run on a cluster and the plotting server is run on a local machine. + +On Macos, this server is necessary as it won't connect using multiprocess. One can simply run the current script on +another terminal to access the online graphs """ from bioptim import PlottingServer From fa04f9e0732c36ae618dd9ec0f8363639f8e6688 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Mon, 5 Aug 2024 17:01:49 -0400 Subject: [PATCH 05/13] Fixed non-waiting on Windows --- bioptim/gui/online_callback_server.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index ea7f85ef9..f7d79317a 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -555,14 +555,14 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: header, data_serialized = _serialize_xydata(xdata, ydata) self._socket.sendall(f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode()) - # If send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) - if self._socket.recv(1024).decode() != "OK": + if self._should_wait_ok_to_client_on_new_data and self._socket.recv(1024).decode() != "OK": raise RuntimeError("The server did not acknowledge the connexion") + self._socket.sendall(header) self._socket.sendall(data_serialized) - # Again, if send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) - if self._socket.recv(1024).decode() != "OK": + if self._should_wait_ok_to_client_on_new_data and self._socket.recv(1024).decode() != "OK": raise RuntimeError("The server did not acknowledge the connexion") + return [0] From 759e9abb792c82c96b72d9cc177d6d672c84a6e0 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Mon, 5 Aug 2024 17:11:13 -0400 Subject: [PATCH 06/13] Faster update --- bioptim/gui/online_callback_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index f7d79317a..796ecee4a 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -57,7 +57,7 @@ def __init__(self, host: str = None, port: int = None, log_level: int | None = l self._prepare_logger(log_level) self._get_data_interval = 1.0 - self._update_plot_interval = 100 + self._update_plot_interval = 10 self._is_drawing = False # Define the host and port From 3209971b1c47f24327db710c203f8646acae7e00 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Tue, 6 Aug 2024 09:55:47 -0400 Subject: [PATCH 07/13] Fixed not all data being downloaded on Linux server --- bioptim/gui/online_callback_server.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 796ecee4a..ce08576d8 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -58,7 +58,7 @@ def __init__(self, host: str = None, port: int = None, log_level: int | None = l self._prepare_logger(log_level) self._get_data_interval = 1.0 self._update_plot_interval = 10 - self._is_drawing = False + self._force_redraw = False # Define the host and port self._host = host if host else _DEFAULT_HOST @@ -85,7 +85,7 @@ def _prepare_logger(self, log_level: int) -> None: formatter = logging.Formatter( "{asctime} - {name}:{levelname} - {message}", style="{", - datefmt="%Y-%m-%d %H:%M", + datefmt="%Y-%m-%d %H:%M:%S.%03d", ) console_handler.setFormatter(formatter) @@ -238,9 +238,11 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: data_out = [] try: for len_data in len_all_data: - data_out.append(client_socket.recv(len_data)) - if len(data_out[-1]) != len_data: - data_out[-1] += client_socket.recv(len_data - len(data_out[-1])) + self._logger.debug(f"Waiting for {len_data} bytes from client") + data_tp = b"" + while len(data_tp) != len_data: + data_tp += client_socket.recv(len_data - len(data_tp)) + data_out.append(data_tp) except Exception as e: self._logger.error("Unknown message type received") self._logger.debug(f"Error: {e}") @@ -253,6 +255,7 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: if send_confirmation: client_socket.sendall("OK".encode()) + self._logger.debug(f"Received data from client: {[len(d) for d in data_out]} bytes") return data_out def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> None: @@ -343,12 +346,11 @@ def _redraw(self) -> None: """ self._logger.debug("Updating plot") - self._is_drawing = True for fig in self._plotter.all_figures: fig.canvas.draw() if platform.system() != "Darwin": fig.canvas.flush_events() - self._is_drawing = False + self._force_redraw = False def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ @@ -363,10 +365,11 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ self._logger.debug(f"Waiting for new data from client") + + if self._force_redraw and platform.system() != "Darwin": + time.sleep(self._update_plot_interval) + try: - if self._is_drawing and platform.system() != "Darwin" and not self._wait_for_new_data_to_plot: - # Give it some time - time.sleep(self._update_plot_interval) client_socket.sendall("READY_FOR_NEXT_DATA".encode()) except Exception as e: self._logger.error("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") @@ -407,6 +410,8 @@ def _update_plot(self, serialized_raw_data: list) -> None: self._logger.debug(f"Received new data from client") xdata, ydata = _deserialize_xydata(serialized_raw_data) self._plotter.update_data(xdata, ydata) + + self._force_redraw = True class OnlineCallbackServer(OnlineCallbackAbstract): From 0acb7fb3167331421168d47612b6eab2bfe1de7a Mon Sep 17 00:00:00 2001 From: eve-mac Date: Tue, 6 Aug 2024 17:33:35 -0400 Subject: [PATCH 08/13] error message --- bioptim/gui/online_callback_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index ce08576d8..aa35de394 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -111,7 +111,7 @@ def _run(self) -> None: except Exception as e: self._logger.error( f"Fatal error while running the server" - f"{"" if self._logger.level == logging.DEBUG else ", for more information set log_level to DEBUG"}") + f"{''if self._logger.level == logging.DEBUG else ', for more information set log_level to DEBUG'}") self._logger.debug(f"Error: {e}") finally: self._socket.close() From 5d11633c3de2067d1849db7bbf0d8412bf0f4f24 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Wed, 7 Aug 2024 12:03:10 -0400 Subject: [PATCH 09/13] Typos --- bioptim/examples/getting_started/pendulum.py | 2 +- bioptim/gui/online_callback_server.py | 23 ++++++++++---------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/bioptim/examples/getting_started/pendulum.py b/bioptim/examples/getting_started/pendulum.py index 674287a7e..2e619e267 100644 --- a/bioptim/examples/getting_started/pendulum.py +++ b/bioptim/examples/getting_started/pendulum.py @@ -150,7 +150,7 @@ def main(): # --- Solve the ocp --- # # Default is OnlineOptim.MULTIPROCESS on Linux, OnlineOptim.MULTIPROCESS_SERVER on Windows and None on MacOS - # To see the graphs on MacOS, one must run the server manually (see ressources/plotting_server.py) + # To see the graphs on MacOS, one must run the server manually (see resources/plotting_server.py) sol = ocp.solve(Solver.IPOPT(online_optim=OnlineOptim.DEFAULT)) # --- Show the results graph --- # diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index aa35de394..62794d017 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -111,7 +111,8 @@ def _run(self) -> None: except Exception as e: self._logger.error( f"Fatal error while running the server" - f"{''if self._logger.level == logging.DEBUG else ', for more information set log_level to DEBUG'}") + f"{''if self._logger.level == logging.DEBUG else ', for more information set log_level to DEBUG'}" + ) self._logger.debug(f"Error: {e}") finally: self._socket.close() @@ -276,7 +277,7 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No self._logger.error("Error while converting data to json format, closing connexion") client_socket.sendall("NOK".encode()) raise e - + try: self._should_send_ok_to_client_on_new_data = data_json["request_confirmation_on_new_data"] except Exception as e: @@ -320,12 +321,12 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No # Start the callbacks threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)).start() - + # Use the canvas timer for _redraw as threading won't work for updating the graphs on Macos timer = self._plotter.all_figures[0].canvas.new_timer(self._update_plot_interval) timer.add_callback(self._redraw) timer.start() - + plt.show() @property @@ -365,10 +366,10 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ self._logger.debug(f"Waiting for new data from client") - + if self._force_redraw and platform.system() != "Darwin": time.sleep(self._update_plot_interval) - + try: client_socket.sendall("READY_FOR_NEXT_DATA".encode()) except Exception as e: @@ -410,7 +411,7 @@ def _update_plot(self, serialized_raw_data: list) -> None: self._logger.debug(f"Received new data from client") xdata, ydata = _deserialize_xydata(serialized_raw_data) self._plotter.update_data(xdata, ydata) - + self._force_redraw = True @@ -472,7 +473,7 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: # Start the client try: self._socket.connect((self._host, self._port)) - except ConnectionError: + except: if retries > 5: raise RuntimeError( "Could not connect to the plotter server, make sure it is running by calling 'PlottingServer()' on " @@ -488,7 +489,7 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: ocp_plot["dummy_phase_times"] = [] for phase_times in dummy_phase_times: ocp_plot["dummy_phase_times"].append([np.array(v)[:, 0].tolist() for v in phase_times]) - + ocp_plot["request_confirmation_on_new_data"] = self._should_wait_ok_to_client_on_new_data serialized_ocp = json.dumps(ocp_plot).encode() @@ -562,12 +563,12 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: self._socket.sendall(f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode()) if self._should_wait_ok_to_client_on_new_data and self._socket.recv(1024).decode() != "OK": raise RuntimeError("The server did not acknowledge the connexion") - + self._socket.sendall(header) self._socket.sendall(data_serialized) if self._should_wait_ok_to_client_on_new_data and self._socket.recv(1024).decode() != "OK": raise RuntimeError("The server did not acknowledge the connexion") - + return [0] From d2abac50664f5b6d98b7b0d9454418459c68fd54 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Wed, 7 Aug 2024 13:31:33 -0400 Subject: [PATCH 10/13] Simplified OnlineOptim parsing of options by using **kwargs --- bioptim/gui/online_callback_abstract.py | 2 +- bioptim/gui/online_callback_multiprocess.py | 4 +- bioptim/gui/online_callback_server.py | 10 ++-- bioptim/interfaces/interface_utils.py | 52 ++++----------------- bioptim/misc/enums.py | 12 +++++ 5 files changed, 29 insertions(+), 51 deletions(-) diff --git a/bioptim/gui/online_callback_abstract.py b/bioptim/gui/online_callback_abstract.py index 678da5567..f2c40fdf7 100644 --- a/bioptim/gui/online_callback_abstract.py +++ b/bioptim/gui/online_callback_abstract.py @@ -32,7 +32,7 @@ class OnlineCallbackAbstract(Callback, ABC): Send the current data to the plotter """ - def __init__(self, ocp, opts: dict = None, show_options: dict = None): + def __init__(self, ocp, opts: dict = None, **show_options): """ Parameters ---------- diff --git a/bioptim/gui/online_callback_multiprocess.py b/bioptim/gui/online_callback_multiprocess.py index 935538725..5b9ee5ded 100644 --- a/bioptim/gui/online_callback_multiprocess.py +++ b/bioptim/gui/online_callback_multiprocess.py @@ -24,8 +24,8 @@ class OnlineCallbackMultiprocess(OnlineCallbackAbstract): The multiprocessing placeholder """ - def __init__(self, ocp, opts: dict = None, show_options: dict = None): - super(OnlineCallbackMultiprocess, self).__init__(ocp, opts, show_options) + def __init__(self, ocp, opts: dict = None, **show_options): + super(OnlineCallbackMultiprocess, self).__init__(ocp, opts, **show_options) self.queue = mp.Queue() self.plotter = self.ProcessPlotter(self.ocp) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 62794d017..25a1eb996 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -416,7 +416,7 @@ def _update_plot(self, serialized_raw_data: list) -> None: class OnlineCallbackServer(OnlineCallbackAbstract): - def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str = None, port: int = None): + def __init__(self, ocp, opts: dict = None, host: str = None, port: int = None, **show_options): """ Initializes the client. This is not supposed to be called directly by the user, but by the solver. During the initialization, we need to perform some tasks that are not possible to do in server side. Then the results of @@ -436,10 +436,10 @@ def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str The port to connect to, by default 3050 """ - super().__init__(ocp, opts, show_options) - self._host = host if host else _DEFAULT_HOST self._port = port if port else _DEFAULT_PORT + + super().__init__(ocp, opts, **show_options) self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._should_wait_ok_to_client_on_new_data = platform.system() == "Darwin" @@ -477,8 +477,8 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: if retries > 5: raise RuntimeError( "Could not connect to the plotter server, make sure it is running by calling 'PlottingServer()' on " - "another python instance or allowing for automatic start of the server by calling " - "'PlottingServer.as_multiprocess()' in the main script" + "another python instance or allowing for automatic start (Linux or Windows) of the server setting " + "the online_option to 'OnlineOptim.MULTIPROCESS_SERVER' when instantiating your solver" ) else: time.sleep(1) diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index c273e1b83..c37bbcb31 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -32,51 +32,17 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): if show_options is None: show_options = {} - online_optim: OnlineOptim = interface.opts.online_optim - if online_optim == OnlineOptim.DEFAULT: - if platform == "linux": - online_optim = OnlineOptim.MULTIPROCESS - elif platform == "win32": - online_optim = OnlineOptim.MULTIPROCESS_SERVER - else: - online_optim = None - + online_optim = interface.opts.online_optim.get_default() if online_optim == OnlineOptim.MULTIPROCESS: - if platform != "linux": - raise RuntimeError( - "Online OnlineOptim.MULTIPROCESS is not supported on Windows or MacOS. " - "You can use online_optim=OnlineOptim.MULTIPROCESS_SERVER to the Solver declaration on Windows though" - ) - interface.options_common["iteration_callback"] = OnlineCallbackMultiprocess(ocp, show_options=show_options) - elif online_optim in (OnlineOptim.SERVER, OnlineOptim.MULTIPROCESS_SERVER): - host = None - if "host" in show_options: - host = show_options["host"] - del show_options["host"] - - port = None - if "port" in show_options: - port = show_options["port"] - del show_options["port"] - - if online_optim == OnlineOptim.SERVER: - interface.options_common["iteration_callback"] = OnlineCallbackServer( - ocp, show_options=show_options, host=host, port=port - ) - elif online_optim == OnlineOptim.MULTIPROCESS_SERVER: - log_level = None - if "log_level" in show_options: - log_level = show_options["log_level"] - del show_options["log_level"] - - interface.options_common["iteration_callback"] = OnlineCallbackMultiprocessServer( - ocp, show_options=show_options, host=host, port=port, log_level=log_level - ) - else: - raise NotImplementedError(f"show_options['type']={online_optim} is not implemented yet") - + to_call = OnlineCallbackMultiprocess + elif online_optim == OnlineOptim.SERVER: + to_call = OnlineCallbackServer + elif online_optim == OnlineOptim.MULTIPROCESS_SERVER: + to_call = OnlineCallbackMultiprocessServer else: - raise NotImplementedError(f"show_options['type']={online_optim} is not implemented yet") + raise ValueError(f"online_optim {online_optim} is not implemented yet") + + interface.options_common["iteration_callback"] = to_call(ocp, **show_options) def generic_solve(interface, expand_during_shake_tree=False) -> dict: diff --git a/bioptim/misc/enums.py b/bioptim/misc/enums.py index eaa396af7..fd93b6d89 100644 --- a/bioptim/misc/enums.py +++ b/bioptim/misc/enums.py @@ -1,4 +1,5 @@ from enum import Enum, IntEnum, auto +import platform class PhaseDynamics(Enum): @@ -111,6 +112,17 @@ class OnlineOptim(Enum): SERVER = auto() MULTIPROCESS_SERVER = auto() + def get_default(self): + if self != OnlineOptim.DEFAULT: + return self + + if platform.system() == "Linux": + return OnlineOptim.MULTIPROCESS + elif platform.system() == "Windows": + return OnlineOptim.MULTIPROCESS_SERVER + else: + return None + class ControlType(Enum): """ From 590692ce9e795d8655aa32418949b073b369f930 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Wed, 7 Aug 2024 13:37:17 -0400 Subject: [PATCH 11/13] Fixed time format on Windows --- bioptim/gui/online_callback_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 25a1eb996..ee54bb91e 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -85,7 +85,7 @@ def _prepare_logger(self, log_level: int) -> None: formatter = logging.Formatter( "{asctime} - {name}:{levelname} - {message}", style="{", - datefmt="%Y-%m-%d %H:%M:%S.%03d", + datefmt="%Y-%m-%d %H:%M:%S" if platform.system() == "Windows" else "%Y-%m-%d %H:%M:%S.%03d", ) console_handler.setFormatter(formatter) From cae1bcf4e82d2a00df300235b5697787684eef59 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Thu, 8 Aug 2024 10:00:30 -0400 Subject: [PATCH 12/13] Fixed a race condition by sending fixed size packet --- bioptim/gui/online_callback_server.py | 121 ++++++++++++++++++-------- bioptim/gui/utils.py | 14 +++ 2 files changed, 100 insertions(+), 35 deletions(-) create mode 100644 bioptim/gui/utils.py diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index ee54bb91e..08f422b18 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -1,4 +1,4 @@ -from enum import IntEnum, auto +from enum import IntEnum, StrEnum, auto import json import logging import platform @@ -13,6 +13,7 @@ from .online_callback_abstract import OnlineCallbackAbstract from .plot import PlotOcp, OcpSerializable +from .utils import strstaticproperty, intstaticproperty from ..optimization.optimization_vector import OptimizationVectorHelper @@ -37,6 +38,45 @@ class _ServerMessages(IntEnum): UNKNOWN = auto() +class _HeaderMessage(StrEnum): + _OK = "OK" + _NOK = "NOK" + _PLOT_READY = "PLOT_READY" + _READY_FOR_NEXT_DATA = "READY_FOR_NEXT_DATA" + + @strstaticproperty + def OK() -> str: + return _HeaderMessage._to_str(_HeaderMessage._OK) + + @strstaticproperty + def NOK() -> str: + return _HeaderMessage._to_str(_HeaderMessage._NOK) + + @strstaticproperty + def PLOT_READY() -> str: + return _HeaderMessage._to_str(_HeaderMessage._PLOT_READY) + + @strstaticproperty + def READY_FOR_NEXT_DATA() -> str: + return _HeaderMessage._to_str(_HeaderMessage._READY_FOR_NEXT_DATA) + + @intstaticproperty + def longest() -> int: + return max(len(v) for v in _HeaderMessage.__members__.values()) + + @intstaticproperty + def header_len() -> int: + return _HeaderMessage.longest + 1 + + @intstaticproperty + def header_generic_len() -> int: + return 1024 + + @staticmethod + def _to_str(message: str) -> str: + return message.ljust(_HeaderMessage.header_len, "\0") + + class PlottingServer: def __init__(self, host: str = None, port: int = None, log_level: int | None = logging.INFO): """ @@ -141,8 +181,8 @@ def _recv_data(self, client_socket: socket.socket, send_confirmation: bool) -> t client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will - not send anything. This is part of the communication protocol + If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + otherwise it will not send anything. This is part of the communication protocol Returns ------- @@ -167,8 +207,8 @@ def _recv_message_type_and_data_len( client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will - not send anything. This is part of the communication protocol + If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + otherwise it will not send anything. This is part of the communication protocol Returns ------- @@ -177,7 +217,7 @@ def _recv_message_type_and_data_len( # Receive the actual data try: - data = client_socket.recv(1024) + data = client_socket.recv(_HeaderMessage.header_generic_len).decode().strip("\0") if not data: return _ServerMessages.EMPTY, None except: @@ -185,14 +225,14 @@ def _recv_message_type_and_data_len( client_socket.close() return _ServerMessages.CLOSE_CONNEXION, None - data_as_list = data.decode().split("\n") + data_as_list = data.split("\n") try: message_type = _ServerMessages(int(data_as_list[0])) except ValueError: self._logger.error("Unknown message type received") # Sends failure if send_confirmation: - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) return _ServerMessages.UNKNOWN, None if message_type == _ServerMessages.CLOSE_CONNEXION: @@ -207,13 +247,13 @@ def _recv_message_type_and_data_len( self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) return _ServerMessages.UNKNOWN, None # If we are here, everything went well, so send confirmation self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)") if send_confirmation: - client_socket.sendall("OK".encode()) + client_socket.sendall(_HeaderMessage.OK.encode()) return message_type, len_all_data @@ -226,8 +266,8 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will - not send anything. This is part of the communication protocol + If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + otherwise it will not send anything. This is part of the communication protocol len_all_data: list The length of the data to receive @@ -249,12 +289,12 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) return None # If we are here, everything went well, so send confirmation if send_confirmation: - client_socket.sendall("OK".encode()) + client_socket.sendall(_HeaderMessage.OK.encode()) self._logger.debug(f"Received data from client: {[len(d) for d in data_out]} bytes") return data_out @@ -275,14 +315,14 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No data_json = json.loads(ocp_raw[0]) except Exception as e: self._logger.error("Error while converting data to json format, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e try: self._should_send_ok_to_client_on_new_data = data_json["request_confirmation_on_new_data"] except Exception as e: self._logger.error("Did not receive if confirmation should be sent, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e try: @@ -292,32 +332,32 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No del data_json["dummy_phase_times"] except Exception as e: self._logger.error("Error while extracting dummy time vector from OCP data, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e try: self.ocp = OcpSerializable.deserialize(data_json) except Exception as e: self._logger.error("Error while deserializing OCP data from client, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e try: show_options = _deserialize_show_options(ocp_raw[1]) except Exception as e: self._logger.error("Error while extracting show options, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e try: self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) except Exception as e: self._logger.error("Error while initializing the plotter, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e # Send the confirmation to the client - client_socket.sendall("PLOT_READY".encode()) + client_socket.sendall(_HeaderMessage.PLOT_READY.encode()) # Start the callbacks threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)).start() @@ -355,9 +395,9 @@ def _redraw(self) -> None: def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ - Waits for new data from the client, sends a "READY_FOR_NEXT_DATA" message to the client to signal that the server - is ready to receive new data. If the client sends new data, the server will update the plot, if client disconnects - the connexion will be closed + Waits for new data from the client, sends a _HeaderMessage.READY_FOR_NEXT_DATA message to the client to signal + that the server is ready to receive new data. If the client sends new data, the server will update the plot, if + client disconnects the connexion will be closed Parameters ---------- @@ -371,7 +411,7 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: time.sleep(self._update_plot_interval) try: - client_socket.sendall("READY_FOR_NEXT_DATA".encode()) + client_socket.sendall(_HeaderMessage.READY_FOR_NEXT_DATA.encode()) except Exception as e: self._logger.error("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") self._logger.debug(f"Error: {e}") @@ -497,19 +537,20 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: # Sends message type and dimensions self._socket.sendall( - f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".encode() + f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".ljust( + _HeaderMessage.header_generic_len, "\0" + ).encode() ) - if self._socket.recv(1024).decode() != "OK": + if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK: raise RuntimeError("The server did not acknowledge the connexion") self._socket.sendall(serialized_ocp) self._socket.sendall(serialized_show_options) - if self._socket.recv(1024).decode() != "OK": + if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK: raise RuntimeError("The server did not acknowledge the connexion") # Wait for the server to be ready - data = self._socket.recv(1024).decode().split("\n") - if data[0] != "PLOT_READY": + if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.PLOT_READY: raise RuntimeError("The server did not acknowledge the OCP data, this should not happen, please report") self._plotter = PlotOcp( @@ -545,8 +586,8 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: self._socket.setblocking(False) try: - data = self._socket.recv(1024).decode() - if data != "READY_FOR_NEXT_DATA": + data = self._socket.recv(_HeaderMessage.header_len).decode() + if data != _HeaderMessage.READY_FOR_NEXT_DATA: return [0] except BlockingIOError: # This is to prevent the solving to be blocked by the server if it is not ready to update the plots @@ -560,13 +601,23 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: xdata, ydata = self._plotter.parse_data(**args_dict) header, data_serialized = _serialize_xydata(xdata, ydata) - self._socket.sendall(f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode()) - if self._should_wait_ok_to_client_on_new_data and self._socket.recv(1024).decode() != "OK": + self._socket.sendall( + f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".ljust( + _HeaderMessage.header_generic_len, "\0" + ).encode() + ) + if ( + self._should_wait_ok_to_client_on_new_data + and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK + ): raise RuntimeError("The server did not acknowledge the connexion") self._socket.sendall(header) self._socket.sendall(data_serialized) - if self._should_wait_ok_to_client_on_new_data and self._socket.recv(1024).decode() != "OK": + if ( + self._should_wait_ok_to_client_on_new_data + and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK + ): raise RuntimeError("The server did not acknowledge the connexion") return [0] diff --git a/bioptim/gui/utils.py b/bioptim/gui/utils.py new file mode 100644 index 000000000..bcb9460a9 --- /dev/null +++ b/bioptim/gui/utils.py @@ -0,0 +1,14 @@ +class strstaticproperty: + def __init__(self, func): + self.func = func + + def __get__(self, instance, owner) -> str: + return self.func() + + +class intstaticproperty: + def __init__(self, func): + self.func = func + + def __get__(self, instance, owner) -> int: + return self.func() From 7f4b0b04638fbc3481bcf62fb27f0756abf086f7 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Thu, 8 Aug 2024 15:41:08 -0400 Subject: [PATCH 13/13] Simplified ResponseHeader --- bioptim/gui/online_callback_server.py | 129 ++++++++++++-------------- bioptim/gui/utils.py | 14 --- tests/shard5/test_plot_server.py | 24 ++++- 3 files changed, 84 insertions(+), 83 deletions(-) delete mode 100644 bioptim/gui/utils.py diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 08f422b18..43fb96542 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -13,12 +13,12 @@ from .online_callback_abstract import OnlineCallbackAbstract from .plot import PlotOcp, OcpSerializable -from .utils import strstaticproperty, intstaticproperty from ..optimization.optimization_vector import OptimizationVectorHelper _DEFAULT_HOST = "localhost" _DEFAULT_PORT = 3050 +_HEADER_GENERIC_LEN = 1024 def _serialize_show_options(show_options: dict) -> bytes: @@ -38,43 +38,31 @@ class _ServerMessages(IntEnum): UNKNOWN = auto() -class _HeaderMessage(StrEnum): - _OK = "OK" - _NOK = "NOK" - _PLOT_READY = "PLOT_READY" - _READY_FOR_NEXT_DATA = "READY_FOR_NEXT_DATA" +class _ResponseHeader(StrEnum): + OK = "OK" + NOK = "NOK" + PLOT_READY = "PLOT_READY" + READY_FOR_NEXT_DATA = "READY_FOR_NEXT_DATA" - @strstaticproperty - def OK() -> str: - return _HeaderMessage._to_str(_HeaderMessage._OK) - - @strstaticproperty - def NOK() -> str: - return _HeaderMessage._to_str(_HeaderMessage._NOK) - - @strstaticproperty - def PLOT_READY() -> str: - return _HeaderMessage._to_str(_HeaderMessage._PLOT_READY) + @staticmethod + def longest() -> int: + return max(len(v) for v in _ResponseHeader.__members__) - @strstaticproperty - def READY_FOR_NEXT_DATA() -> str: - return _HeaderMessage._to_str(_HeaderMessage._READY_FOR_NEXT_DATA) + def encode(self) -> str: + return self.ljust(len(self), "\0").encode() - @intstaticproperty - def longest() -> int: - return max(len(v) for v in _HeaderMessage.__members__.values()) + @staticmethod + def response_len() -> int: + return _ResponseHeader.longest() + 1 - @intstaticproperty - def header_len() -> int: - return _HeaderMessage.longest + 1 + def __len__(self) -> int: + return _ResponseHeader.response_len() - @intstaticproperty - def header_generic_len() -> int: - return 1024 + def __eq__(self, value: object) -> bool: + return self.split("\0")[0] == value.split("\0")[0] or super().__eq__(value) - @staticmethod - def _to_str(message: str) -> str: - return message.ljust(_HeaderMessage.header_len, "\0") + def __ne__(self, value: object) -> bool: + return not self.__eq__(value) class PlottingServer: @@ -181,7 +169,7 @@ def _recv_data(self, client_socket: socket.socket, send_confirmation: bool) -> t client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data, otherwise it will not send anything. This is part of the communication protocol Returns @@ -207,7 +195,7 @@ def _recv_message_type_and_data_len( client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data, otherwise it will not send anything. This is part of the communication protocol Returns @@ -217,7 +205,7 @@ def _recv_message_type_and_data_len( # Receive the actual data try: - data = client_socket.recv(_HeaderMessage.header_generic_len).decode().strip("\0") + data = client_socket.recv(_HEADER_GENERIC_LEN).decode().strip("\0") if not data: return _ServerMessages.EMPTY, None except: @@ -232,7 +220,7 @@ def _recv_message_type_and_data_len( self._logger.error("Unknown message type received") # Sends failure if send_confirmation: - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) return _ServerMessages.UNKNOWN, None if message_type == _ServerMessages.CLOSE_CONNEXION: @@ -247,13 +235,13 @@ def _recv_message_type_and_data_len( self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) return _ServerMessages.UNKNOWN, None # If we are here, everything went well, so send confirmation self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)") if send_confirmation: - client_socket.sendall(_HeaderMessage.OK.encode()) + client_socket.sendall(_ResponseHeader.OK.encode()) return message_type, len_all_data @@ -266,7 +254,7 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data, otherwise it will not send anything. This is part of the communication protocol len_all_data: list The length of the data to receive @@ -289,12 +277,12 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) return None # If we are here, everything went well, so send confirmation if send_confirmation: - client_socket.sendall(_HeaderMessage.OK.encode()) + client_socket.sendall(_ResponseHeader.OK.encode()) self._logger.debug(f"Received data from client: {[len(d) for d in data_out]} bytes") return data_out @@ -315,14 +303,14 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No data_json = json.loads(ocp_raw[0]) except Exception as e: self._logger.error("Error while converting data to json format, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e try: self._should_send_ok_to_client_on_new_data = data_json["request_confirmation_on_new_data"] except Exception as e: self._logger.error("Did not receive if confirmation should be sent, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e try: @@ -332,32 +320,32 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No del data_json["dummy_phase_times"] except Exception as e: self._logger.error("Error while extracting dummy time vector from OCP data, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e try: self.ocp = OcpSerializable.deserialize(data_json) except Exception as e: self._logger.error("Error while deserializing OCP data from client, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e try: show_options = _deserialize_show_options(ocp_raw[1]) except Exception as e: self._logger.error("Error while extracting show options, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e try: self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) except Exception as e: self._logger.error("Error while initializing the plotter, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e # Send the confirmation to the client - client_socket.sendall(_HeaderMessage.PLOT_READY.encode()) + client_socket.sendall(_ResponseHeader.PLOT_READY.encode()) # Start the callbacks threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)).start() @@ -395,7 +383,7 @@ def _redraw(self) -> None: def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ - Waits for new data from the client, sends a _HeaderMessage.READY_FOR_NEXT_DATA message to the client to signal + Waits for new data from the client, sends a _ResponseHeader.READY_FOR_NEXT_DATA message to the client to signal that the server is ready to receive new data. If the client sends new data, the server will update the plot, if client disconnects the connexion will be closed @@ -411,7 +399,7 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: time.sleep(self._update_plot_interval) try: - client_socket.sendall(_HeaderMessage.READY_FOR_NEXT_DATA.encode()) + client_socket.sendall(_ResponseHeader.READY_FOR_NEXT_DATA.encode()) except Exception as e: self._logger.error("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") self._logger.debug(f"Error: {e}") @@ -476,10 +464,10 @@ def __init__(self, ocp, opts: dict = None, host: str = None, port: int = None, * The port to connect to, by default 3050 """ + super().__init__(ocp, opts, **show_options) + self._host = host if host else _DEFAULT_HOST self._port = port if port else _DEFAULT_PORT - - super().__init__(ocp, opts, **show_options) self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._should_wait_ok_to_client_on_new_data = platform.system() == "Darwin" @@ -518,7 +506,7 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: raise RuntimeError( "Could not connect to the plotter server, make sure it is running by calling 'PlottingServer()' on " "another python instance or allowing for automatic start (Linux or Windows) of the server setting " - "the online_option to 'OnlineOptim.MULTIPROCESS_SERVER' when instantiating your solver" + "the online_option to 'OnlineOptim.MULTIPROCESS_SERVER' when instantiating your solver." ) else: time.sleep(1) @@ -538,25 +526,36 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: # Sends message type and dimensions self._socket.sendall( f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".ljust( - _HeaderMessage.header_generic_len, "\0" + _HEADER_GENERIC_LEN, "\0" ).encode() ) - if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK: + if not self._has_received_ok(): raise RuntimeError("The server did not acknowledge the connexion") self._socket.sendall(serialized_ocp) self._socket.sendall(serialized_show_options) - if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK: + if not self._has_received_ok(): raise RuntimeError("The server did not acknowledge the connexion") # Wait for the server to be ready - if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.PLOT_READY: + if self._socket.recv(_ResponseHeader.response_len()).decode() != _ResponseHeader.PLOT_READY: raise RuntimeError("The server did not acknowledge the OCP data, this should not happen, please report") self._plotter = PlotOcp( self.ocp, only_initialize_variables=True, dummy_phase_times=dummy_phase_times, **show_options ) + def _has_received_ok(self) -> bool: + """ + Checks if the server has sent an OK message + + Returns + ------- + If the server has sent an OK message + """ + + return self._socket.recv(_ResponseHeader.response_len()).decode() == _ResponseHeader.OK + def close(self) -> None: """ Closes the connexion @@ -586,8 +585,8 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: self._socket.setblocking(False) try: - data = self._socket.recv(_HeaderMessage.header_len).decode() - if data != _HeaderMessage.READY_FOR_NEXT_DATA: + data = self._socket.recv(_ResponseHeader.response_len()).decode() + if data != _ResponseHeader.READY_FOR_NEXT_DATA: return [0] except BlockingIOError: # This is to prevent the solving to be blocked by the server if it is not ready to update the plots @@ -603,21 +602,15 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: self._socket.sendall( f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".ljust( - _HeaderMessage.header_generic_len, "\0" + _HEADER_GENERIC_LEN, "\0" ).encode() ) - if ( - self._should_wait_ok_to_client_on_new_data - and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK - ): + if self._should_wait_ok_to_client_on_new_data and not self._has_received_ok(): raise RuntimeError("The server did not acknowledge the connexion") self._socket.sendall(header) self._socket.sendall(data_serialized) - if ( - self._should_wait_ok_to_client_on_new_data - and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK - ): + if self._should_wait_ok_to_client_on_new_data and not self._has_received_ok(): raise RuntimeError("The server did not acknowledge the connexion") return [0] diff --git a/bioptim/gui/utils.py b/bioptim/gui/utils.py deleted file mode 100644 index bcb9460a9..000000000 --- a/bioptim/gui/utils.py +++ /dev/null @@ -1,14 +0,0 @@ -class strstaticproperty: - def __init__(self, func): - self.func = func - - def __get__(self, instance, owner) -> str: - return self.func() - - -class intstaticproperty: - def __init__(self, func): - self.func = func - - def __get__(self, instance, owner) -> int: - return self.func() diff --git a/tests/shard5/test_plot_server.py b/tests/shard5/test_plot_server.py index af774b260..442726ad5 100644 --- a/tests/shard5/test_plot_server.py +++ b/tests/shard5/test_plot_server.py @@ -2,8 +2,9 @@ from bioptim.gui.online_callback_server import _serialize_xydata, _deserialize_xydata from bioptim.gui.plot import PlotOcp +from bioptim.gui.online_callback_server import _ResponseHeader from bioptim.optimization.optimization_vector import OptimizationVectorHelper -from casadi import DM, Function +from casadi import DM import numpy as np @@ -40,3 +41,24 @@ def test_serialize_deserialize(): else: for y_phase, deserialized_y_phase in zip(y_variable, deserialized_y_variable): assert np.allclose(y_phase, deserialized_y_phase) + + +def test_response_header(): + # Make sure all the response have the same length + response_len = _ResponseHeader.response_len() + for response in _ResponseHeader: + assert len(response) == response_len + # Make sure encoding provides a constant length + assert len(response.encode()) == response_len + + # Make sure equality works + assert _ResponseHeader.OK == _ResponseHeader.OK + assert _ResponseHeader.OK.value == _ResponseHeader.OK + assert _ResponseHeader.OK.encode().decode() == _ResponseHeader.OK + assert _ResponseHeader.OK == _ResponseHeader.OK.encode().decode() + assert not (_ResponseHeader.OK != _ResponseHeader.OK) + assert not (_ResponseHeader.OK.encode().decode() != _ResponseHeader.OK) + assert not (_ResponseHeader.OK != _ResponseHeader.OK.encode().decode()) + assert not (_ResponseHeader.OK.value == _ResponseHeader.OK.encode().decode()) + assert _ResponseHeader.OK != _ResponseHeader.NOK + assert _ResponseHeader.NOK == _ResponseHeader.NOK