diff --git a/bioptim/__init__.py b/bioptim/__init__.py index d07d8da3c..99d9408cd 100644 --- a/bioptim/__init__.py +++ b/bioptim/__init__.py @@ -187,7 +187,7 @@ from .limits.multinode_constraint import MultinodeConstraintFcn, MultinodeConstraintList, MultinodeConstraint from .limits.multinode_objective import MultinodeObjectiveFcn, MultinodeObjectiveList, MultinodeObjective from .limits.objective_functions import ObjectiveFcn, ObjectiveList, Objective, ParameterObjectiveList -from .limits.path_conditions import BoundsList, InitialGuessList +from .limits.path_conditions import BoundsList, InitialGuessList, Bounds, InitialGuess from .limits.fatigue_path_conditions import FatigueBounds, FatigueInitialGuess from .limits.penalty_controller import PenaltyController from .limits.penalty_helpers import PenaltyHelpers @@ -231,4 +231,5 @@ from .optimization.problem_type import SocpType from .misc.casadi_expand import lt, le, gt, ge, if_else, if_else_zero +from .gui.plot import CustomPlot from .gui.online_callback_server import PlottingServer diff --git a/bioptim/examples/getting_started/pendulum.py b/bioptim/examples/getting_started/pendulum.py index 55fa6a782..2e619e267 100644 --- a/bioptim/examples/getting_started/pendulum.py +++ b/bioptim/examples/getting_started/pendulum.py @@ -150,7 +150,8 @@ def main(): # --- Solve the ocp --- # # Default is OnlineOptim.MULTIPROCESS on Linux, OnlineOptim.MULTIPROCESS_SERVER on Windows and None on MacOS - sol = ocp.solve(Solver.IPOPT(show_online_optim=OnlineOptim.DEFAULT)) + # To see the graphs on MacOS, one must run the server manually (see resources/plotting_server.py) + sol = ocp.solve(Solver.IPOPT(online_optim=OnlineOptim.DEFAULT)) # --- Show the results graph --- # sol.print_cost() diff --git a/bioptim/gui/online_callback_abstract.py b/bioptim/gui/online_callback_abstract.py index 678da5567..f2c40fdf7 100644 --- a/bioptim/gui/online_callback_abstract.py +++ b/bioptim/gui/online_callback_abstract.py @@ -32,7 +32,7 @@ class OnlineCallbackAbstract(Callback, ABC): Send the current data to the plotter """ - def __init__(self, ocp, opts: dict = None, show_options: dict = None): + def __init__(self, ocp, opts: dict = None, **show_options): """ Parameters ---------- diff --git a/bioptim/gui/online_callback_multiprocess.py b/bioptim/gui/online_callback_multiprocess.py index 935538725..5b9ee5ded 100644 --- a/bioptim/gui/online_callback_multiprocess.py +++ b/bioptim/gui/online_callback_multiprocess.py @@ -24,8 +24,8 @@ class OnlineCallbackMultiprocess(OnlineCallbackAbstract): The multiprocessing placeholder """ - def __init__(self, ocp, opts: dict = None, show_options: dict = None): - super(OnlineCallbackMultiprocess, self).__init__(ocp, opts, show_options) + def __init__(self, ocp, opts: dict = None, **show_options): + super(OnlineCallbackMultiprocess, self).__init__(ocp, opts, **show_options) self.queue = mp.Queue() self.plotter = self.ProcessPlotter(self.ocp) diff --git a/bioptim/gui/online_callback_multiprocess_server.py b/bioptim/gui/online_callback_multiprocess_server.py index 72ff50209..49176b4cf 100644 --- a/bioptim/gui/online_callback_multiprocess_server.py +++ b/bioptim/gui/online_callback_multiprocess_server.py @@ -3,7 +3,7 @@ from .online_callback_server import PlottingServer, OnlineCallbackServer -def _start_as_multiprocess_internal(**kwargs): +def _start_server_internal(**kwargs): """ Starts the server (necessary for multiprocessing), this method should not be called directly, apart from run_as_multiprocess @@ -26,7 +26,12 @@ def __init__(self, *args, 
**kwargs): """ host = kwargs["host"] if "host" in kwargs else None port = kwargs["port"] if "port" in kwargs else None - process = Process(target=_start_as_multiprocess_internal, kwargs={"host": host, "port": port}) + log_level = None + if "log_level" in kwargs: + log_level = kwargs["log_level"] + del kwargs["log_level"] + + process = Process(target=_start_server_internal, kwargs={"host": host, "port": port, "log_level": log_level}) process.start() super(OnlineCallbackMultiprocessServer, self).__init__(*args, **kwargs) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 62804bb7d..43fb96542 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -1,6 +1,7 @@ -from enum import IntEnum, auto +from enum import IntEnum, StrEnum, auto import json import logging +import platform import socket import struct import time @@ -17,6 +18,7 @@ _DEFAULT_HOST = "localhost" _DEFAULT_PORT = 3050 +_HEADER_GENERIC_LEN = 1024 def _serialize_show_options(show_options: dict) -> bytes: @@ -27,18 +29,6 @@ def _deserialize_show_options(show_options: bytes) -> dict: return json.loads(show_options.decode()) -def _start_as_multiprocess_internal(*args, **kwargs): - """ - Starts the server (necessary for multiprocessing), this method should not be called directly, apart from - run_as_multiprocess - - Parameters - ---------- - same as PlottingServer - """ - PlottingServer(*args, **kwargs) - - class _ServerMessages(IntEnum): INITIATE_CONNEXION = auto() NEW_DATA = auto() @@ -48,8 +38,35 @@ class _ServerMessages(IntEnum): UNKNOWN = auto() +class _ResponseHeader(StrEnum): + OK = "OK" + NOK = "NOK" + PLOT_READY = "PLOT_READY" + READY_FOR_NEXT_DATA = "READY_FOR_NEXT_DATA" + + @staticmethod + def longest() -> int: + return max(len(v) for v in _ResponseHeader.__members__) + + def encode(self) -> str: + return self.ljust(len(self), "\0").encode() + + @staticmethod + def response_len() -> int: + return _ResponseHeader.longest() + 1 + + def __len__(self) -> int: + return _ResponseHeader.response_len() + + def __eq__(self, value: object) -> bool: + return self.split("\0")[0] == value.split("\0")[0] or super().__eq__(value) + + def __ne__(self, value: object) -> bool: + return not self.__eq__(value) + + class PlottingServer: - def __init__(self, host: str = None, port: int = None): + def __init__(self, host: str = None, port: int = None, log_level: int | None = logging.INFO): """ Initializes the server @@ -59,12 +76,17 @@ def __init__(self, host: str = None, port: int = None): The host to listen to, by default "localhost" port: int The port to listen to, by default 3050 + log_level: int + The log level (see logging), by default logging.INFO """ - self._prepare_logger() + if log_level is None: + log_level = logging.INFO + + self._prepare_logger(log_level) self._get_data_interval = 1.0 - self._update_plot_interval = 0.01 - self._is_drawing = False + self._update_plot_interval = 10 + self._force_redraw = False # Define the host and port self._host = host if host else _DEFAULT_HOST @@ -72,11 +94,18 @@ def __init__(self, host: str = None, port: int = None): self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._plotter: PlotOcp = None + self._should_send_ok_to_client_on_new_data = False + self._run() - def _prepare_logger(self) -> None: + def _prepare_logger(self, log_level: int) -> None: """ Prepares the logger + + Parameters + ---------- + log_level: int + The log level """ name = "PlottingServer" @@ -84,13 +113,13 @@ def 
_prepare_logger(self) -> None: formatter = logging.Formatter( "{asctime} - {name}:{levelname} - {message}", style="{", - datefmt="%Y-%m-%d %H:%M", + datefmt="%Y-%m-%d %H:%M:%S" if platform.system() == "Windows" else "%Y-%m-%d %H:%M:%S.%03d", ) console_handler.setFormatter(formatter) self._logger = logging.getLogger(name) self._logger.addHandler(console_handler) - self._logger.setLevel(logging.INFO) + self._logger.setLevel(log_level) def _run(self) -> None: """ @@ -105,10 +134,14 @@ def _run(self) -> None: while True: self._logger.info("Waiting for a new connexion") client_socket, addr = self._socket.accept() - self._logger.info(f"Connection from {addr}") + self._logger.info(f"Connexion from {addr}") self._wait_for_new_connexion(client_socket) except Exception as e: - self._logger.error(f"Error while running the server: {e}") + self._logger.error( + f"Fatal error while running the server" + f"{''if self._logger.level == logging.DEBUG else ', for more information set log_level to DEBUG'}" + ) + self._logger.debug(f"Error: {e}") finally: self._socket.close() @@ -136,8 +169,8 @@ def _recv_data(self, client_socket: socket.socket, send_confirmation: bool) -> t client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will - not send anything. This is part of the communication protocol + If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data, + otherwise it will not send anything. This is part of the communication protocol Returns ------- @@ -162,8 +195,8 @@ def _recv_message_type_and_data_len( client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will - not send anything. This is part of the communication protocol + If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data, + otherwise it will not send anything. 
This is part of the communication protocol Returns ------- @@ -172,22 +205,22 @@ def _recv_message_type_and_data_len( # Receive the actual data try: - data = client_socket.recv(1024) + data = client_socket.recv(_HEADER_GENERIC_LEN).decode().strip("\0") if not data: return _ServerMessages.EMPTY, None except: - self._logger.warning("Client closed connexion") + self._logger.info("Client closed connexion") client_socket.close() return _ServerMessages.CLOSE_CONNEXION, None - data_as_list = data.decode().split("\n") + data_as_list = data.split("\n") try: message_type = _ServerMessages(int(data_as_list[0])) except ValueError: - self._logger.warning("Unknown message type received") + self._logger.error("Unknown message type received") # Sends failure if send_confirmation: - client_socket.sendall("NOK".encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) return _ServerMessages.UNKNOWN, None if message_type == _ServerMessages.CLOSE_CONNEXION: @@ -197,17 +230,18 @@ def _recv_message_type_and_data_len( try: len_all_data = [int(len_data) for len_data in data_as_list[1][1:-1].split(",")] - except ValueError: - self._logger.warning("Length of data could not be extracted") + except Exception as e: + self._logger.error("Length of data could not be extracted") + self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: - client_socket.sendall("NOK".encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) return _ServerMessages.UNKNOWN, None # If we are here, everything went well, so send confirmation self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)") if send_confirmation: - client_socket.sendall("OK".encode()) + client_socket.sendall(_ResponseHeader.OK.encode()) return message_type, len_all_data @@ -220,8 +254,8 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will - not send anything. This is part of the communication protocol + If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data, + otherwise it will not send anything. 
This is part of the communication protocol len_all_data: list The length of the data to receive @@ -233,20 +267,24 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: data_out = [] try: for len_data in len_all_data: - data_out.append(client_socket.recv(len_data)) - if len(data_out[-1]) != len_data: - data_out[-1] += client_socket.recv(len_data - len(data_out[-1])) - except: - self._logger.warning("Unknown message type received") + self._logger.debug(f"Waiting for {len_data} bytes from client") + data_tp = b"" + while len(data_tp) != len_data: + data_tp += client_socket.recv(len_data - len(data_tp)) + data_out.append(data_tp) + except Exception as e: + self._logger.error("Unknown message type received") + self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: - client_socket.sendall("NOK".encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) return None # If we are here, everything went well, so send confirmation if send_confirmation: - client_socket.sendall("OK".encode()) + client_socket.sendall(_ResponseHeader.OK.encode()) + self._logger.debug(f"Received data from client: {[len(d) for d in data_out]} bytes") return data_out def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> None: @@ -263,35 +301,60 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No try: data_json = json.loads(ocp_raw[0]) + except Exception as e: + self._logger.error("Error while converting data to json format, closing connexion") + client_socket.sendall(_ResponseHeader.NOK.encode()) + raise e + + try: + self._should_send_ok_to_client_on_new_data = data_json["request_confirmation_on_new_data"] + except Exception as e: + self._logger.error("Did not receive if confirmation should be sent, closing connexion") + client_socket.sendall(_ResponseHeader.NOK.encode()) + raise e + + try: dummy_time_vector = [] for phase_times in data_json["dummy_phase_times"]: dummy_time_vector.append([DM(v) for v in phase_times]) del data_json["dummy_phase_times"] - except: - self._logger.warning("Error while extracting dummy time vector from OCP data, closing connexion") - return + except Exception as e: + self._logger.error("Error while extracting dummy time vector from OCP data, closing connexion") + client_socket.sendall(_ResponseHeader.NOK.encode()) + raise e try: self.ocp = OcpSerializable.deserialize(data_json) - except: - client_socket.sendall("FAILED".encode()) - self._logger.warning("Error while deserializing OCP data from client, closing connexion") - return + except Exception as e: + self._logger.error("Error while deserializing OCP data from client, closing connexion") + client_socket.sendall(_ResponseHeader.NOK.encode()) + raise e try: show_options = _deserialize_show_options(ocp_raw[1]) - except: - self._logger.warning("Error while extracting show options, closing connexion") - return + except Exception as e: + self._logger.error("Error while extracting show options, closing connexion") + client_socket.sendall(_ResponseHeader.NOK.encode()) + raise e - self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) + try: + self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) + except Exception as e: + self._logger.error("Error while initializing the plotter, closing connexion") + client_socket.sendall(_ResponseHeader.NOK.encode()) + raise e # Send the confirmation to the client - client_socket.sendall("PLOT_READY".encode()) + 
client_socket.sendall(_ResponseHeader.PLOT_READY.encode()) # Start the callbacks threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)).start() - threading.Timer(self._update_plot_interval, self._redraw).start() + + # Use the canvas timer for _redraw as threading won't work for updating the graphs on Macos + timer = self._plotter.all_figures[0].canvas.new_timer(self._update_plot_interval) + timer.add_callback(self._redraw) + timer.start() + plt.show() @property @@ -312,22 +375,17 @@ def _redraw(self) -> None: """ self._logger.debug("Updating plot") - self._is_drawing = True - for _, fig in enumerate(self._plotter.all_figures): + for fig in self._plotter.all_figures: fig.canvas.draw() - fig.canvas.flush_events() - self._is_drawing = False - - if self.has_at_least_one_active_figure: - threading.Timer(self._update_plot_interval, self._redraw).start() - else: - self._logger.info("All figures have been closed, stop updating the plots") + if platform.system() != "Darwin": + fig.canvas.flush_events() + self._force_redraw = False def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ - Waits for new data from the client, sends a "READY_FOR_NEXT_DATA" message to the client to signal that the server - is ready to receive new data. If the client sends new data, the server will update the plot, if client disconnects - the connexion will be closed + Waits for new data from the client, sends a _ResponseHeader.READY_FOR_NEXT_DATA message to the client to signal + that the server is ready to receive new data. If the client sends new data, the server will update the plot, if + client disconnects the connexion will be closed Parameters ---------- @@ -336,23 +394,29 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ self._logger.debug(f"Waiting for new data from client") + + if self._force_redraw and platform.system() != "Darwin": + time.sleep(self._update_plot_interval) + try: - if self._is_drawing: - # Give it some time - time.sleep(self._update_plot_interval) - client_socket.sendall("READY_FOR_NEXT_DATA".encode()) - except: - self._logger.warning("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") + client_socket.sendall(_ResponseHeader.READY_FOR_NEXT_DATA.encode()) + except Exception as e: + self._logger.error("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") + self._logger.debug(f"Error: {e}") + client_socket.close() return should_continue = False - message_type, data = self._recv_data(client_socket=client_socket, send_confirmation=False) + message_type, data = self._recv_data( + client_socket=client_socket, send_confirmation=self._should_send_ok_to_client_on_new_data + ) if message_type == _ServerMessages.NEW_DATA: try: self._update_plot(data) should_continue = True - except: - self._logger.warning("Error while updating data from client, closing connexion") + except Exception as e: + self._logger.error("Error while updating data from client, closing connexion") + self._logger.debug(f"Error: {e}") client_socket.close() return @@ -376,9 +440,11 @@ def _update_plot(self, serialized_raw_data: list) -> None: xdata, ydata = _deserialize_xydata(serialized_raw_data) self._plotter.update_data(xdata, ydata) + self._force_redraw = True + class OnlineCallbackServer(OnlineCallbackAbstract): - def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str = None, port: int = None): + def __init__(self, ocp, opts: dict = None, host: str = None, port: int = None, 
**show_options): """ Initializes the client. This is not supposed to be called directly by the user, but by the solver. During the initialization, we need to perform some tasks that are not possible to do in server side. Then the results of @@ -398,12 +464,14 @@ def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str The port to connect to, by default 3050 """ - super().__init__(ocp, opts, show_options) + super().__init__(ocp, opts, **show_options) self._host = host if host else _DEFAULT_HOST self._port = port if port else _DEFAULT_PORT self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._should_wait_ok_to_client_on_new_data = platform.system() == "Darwin" + if self.ocp.plot_ipopt_outputs: raise NotImplementedError("The online callback with TCP does not support the plot_ipopt_outputs option") if self.ocp.save_ipopt_iterations_info: @@ -433,12 +501,12 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: # Start the client try: self._socket.connect((self._host, self._port)) - except ConnectionError: + except: if retries > 5: raise RuntimeError( "Could not connect to the plotter server, make sure it is running by calling 'PlottingServer()' on " - "another python instance or allowing for automatic start of the server by calling " - "'PlottingServer.as_multiprocess()' in the main script" + "another python instance or allowing for automatic start (Linux or Windows) of the server setting " + "the online_option to 'OnlineOptim.MULTIPROCESS_SERVER' when instantiating your solver." ) else: time.sleep(1) @@ -449,31 +517,45 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: ocp_plot["dummy_phase_times"] = [] for phase_times in dummy_phase_times: ocp_plot["dummy_phase_times"].append([np.array(v)[:, 0].tolist() for v in phase_times]) + + ocp_plot["request_confirmation_on_new_data"] = self._should_wait_ok_to_client_on_new_data serialized_ocp = json.dumps(ocp_plot).encode() serialized_show_options = _serialize_show_options(show_options) # Sends message type and dimensions self._socket.sendall( - f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".encode() + f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".ljust( + _HEADER_GENERIC_LEN, "\0" + ).encode() ) - if self._socket.recv(1024).decode() != "OK": + if not self._has_received_ok(): raise RuntimeError("The server did not acknowledge the connexion") self._socket.sendall(serialized_ocp) self._socket.sendall(serialized_show_options) - if self._socket.recv(1024).decode() != "OK": + if not self._has_received_ok(): raise RuntimeError("The server did not acknowledge the connexion") # Wait for the server to be ready - data = self._socket.recv(1024).decode().split("\n") - if data[0] != "PLOT_READY": + if self._socket.recv(_ResponseHeader.response_len()).decode() != _ResponseHeader.PLOT_READY: raise RuntimeError("The server did not acknowledge the OCP data, this should not happen, please report") self._plotter = PlotOcp( self.ocp, only_initialize_variables=True, dummy_phase_times=dummy_phase_times, **show_options ) + def _has_received_ok(self) -> bool: + """ + Checks if the server has sent an OK message + + Returns + ------- + If the server has sent an OK message + """ + + return self._socket.recv(_ResponseHeader.response_len()).decode() == _ResponseHeader.OK + def close(self) -> None: """ Closes the connexion @@ -503,8 +585,8 @@ def eval(self, arg: list | tuple, 
enforce: bool = False) -> list[int]: self._socket.setblocking(False) try: - data = self._socket.recv(1024).decode() - if data != "READY_FOR_NEXT_DATA": + data = self._socket.recv(_ResponseHeader.response_len()).decode() + if data != _ResponseHeader.READY_FOR_NEXT_DATA: return [0] except BlockingIOError: # This is to prevent the solving to be blocked by the server if it is not ready to update the plots @@ -518,11 +600,19 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: xdata, ydata = self._plotter.parse_data(**args_dict) header, data_serialized = _serialize_xydata(xdata, ydata) - self._socket.sendall(f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode()) - # If send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) + self._socket.sendall( + f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".ljust( + _HEADER_GENERIC_LEN, "\0" + ).encode() + ) + if self._should_wait_ok_to_client_on_new_data and not self._has_received_ok(): + raise RuntimeError("The server did not acknowledge the connexion") + self._socket.sendall(header) self._socket.sendall(data_serialized) - # Again, if send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) + if self._should_wait_ok_to_client_on_new_data and not self._has_received_ok(): + raise RuntimeError("The server did not acknowledge the connexion") + return [0] diff --git a/bioptim/gui/plot.py b/bioptim/gui/plot.py index 10bcfc223..55866e82c 100644 --- a/bioptim/gui/plot.py +++ b/bioptim/gui/plot.py @@ -320,7 +320,10 @@ def legend_without_duplicate_labels(ax): ax.legend(*zip(*unique)) variable_sizes = [] + + self.ocp.finalize_plot_phase_mappings() for i, nlp in enumerate(self.ocp.nlp): + variable_sizes.append({}) if nlp.plot: for key in nlp.plot: @@ -331,51 +334,6 @@ def legend_without_duplicate_labels(ax): if nlp.plot[key].node_idx is None: nlp.plot[key].node_idx = range(nlp.n_states_nodes) - # If the number of subplots is not known, compute it - if nlp.plot[key].phase_mappings is None: - node_index = nlp.plot[key].node_idx[0] - nlp.states.node_index = node_index - nlp.states_dot.node_index = node_index - nlp.controls.node_index = node_index - nlp.algebraic_states.node_index = node_index - - # If multi-node penalties = None, stays zero - size_x = nlp.states.shape - size_u = nlp.controls.shape - size_p = nlp.parameters.shape - size_a = nlp.algebraic_states.shape - size_d = nlp.numerical_timeseries.shape - if "penalty" in nlp.plot[key].parameters: - penalty = nlp.plot[key].parameters["penalty"] - - # As stated in penalty_option, the last controller is always supposed to be the right one - casadi_function = ( - penalty.function[0] if penalty.function[0] is not None else penalty.function[-1] - ) - if casadi_function is not None: - size_x = casadi_function.size_in("x")[0] - size_u = casadi_function.size_in("u")[0] - size_p = casadi_function.size_in("p")[0] - size_a = casadi_function.size_in("a")[0] - size_d = casadi_function.size_in("d")[0] - - size = ( - nlp.plot[key] - .function( - 0, # t0 - np.zeros(len(self.ocp.nlp)), # phases_dt - node_index, # node_idx - np.zeros((size_x, 1)), # states - np.zeros((size_u, 1)), # controls - np.zeros((size_p, 1)), # parameters - np.zeros((size_a, 1)), # algebraic_states - np.zeros((size_d, 1)), # numerical_timeseries - **nlp.plot[key].parameters, # parameters - ) - .shape[0] - ) - nlp.plot[key].phase_mappings = BiMapping(to_first=range(size), to_second=range(size)) - 
n_subplots = max(nlp.plot[key].phase_mappings.to_second.map_idx) + 1 if key not in variable_sizes[i]: diff --git a/bioptim/gui/serializable_class.py b/bioptim/gui/serializable_class.py index f53391500..e5f780355 100644 --- a/bioptim/gui/serializable_class.py +++ b/bioptim/gui/serializable_class.py @@ -1,6 +1,6 @@ -from typing import Any, Callable +from typing import Any -from casadi import DM, Function +from casadi import Function import numpy as np from ..dynamics.ode_solver import OdeSolver @@ -96,7 +96,7 @@ def from_mapping(cls, mapping): def serialize(self): return { - "map_idx": self.map_idx, + "map_idx": list(self.map_idx), "oppose": self.oppose, } @@ -194,7 +194,6 @@ def max(self): class CustomPlotSerializable: - _function: Callable type: PlotType phase_mappings: BiMappingSerializable legend: tuple | list @@ -207,12 +206,10 @@ class CustomPlotSerializable: label: list compute_derivative: bool integration_rule: QuadratureRule - parameters: dict[str, Any] all_variables_in_one_subplot: bool def __init__( self, - function: Callable, plot_type: PlotType, phase_mappings: BiMapping, legend: tuple | list, @@ -225,10 +222,8 @@ def __init__( label: list, compute_derivative: bool, integration_rule: QuadratureRule, - parameters: dict[str, Any], all_variables_in_one_subplot: bool, ): - self._function = function self.type = plot_type self.phase_mappings = phase_mappings self.legend = legend @@ -241,7 +236,6 @@ def __init__( self.label = label self.compute_derivative = compute_derivative self.integration_rule = integration_rule - self.parameters = parameters self.all_variables_in_one_subplot = all_variables_in_one_subplot @classmethod @@ -250,46 +244,9 @@ def from_custom_plot(cls, custom_plot): custom_plot: CustomPlot = custom_plot - _function = None - parameters = {} - for key in custom_plot.parameters.keys(): - if key == "penalty": - # This is a hack to emulate what PlotOcp._create_plots needs while not being able to actually serialize - # the function - parameters[key] = PenaltySerializable.from_penalty(custom_plot.parameters[key]) - - penalty = custom_plot.parameters[key] - - casadi_function = penalty.function[0] if penalty.function[0] is not None else penalty.function[-1] - size_x = casadi_function.size_in("x")[0] - size_dt = casadi_function.size_in("dt")[0] - size_u = casadi_function.size_in("u")[0] - size_p = casadi_function.size_in("p")[0] - size_a = casadi_function.size_in("a")[0] - size_d = casadi_function.size_in("d")[0] - _function = custom_plot.function( - 0, # t0 - np.zeros(size_dt), # phases_dt - custom_plot.node_idx[0], # node_idx - np.zeros((size_x, 1)), # states - np.zeros((size_u, 1)), # controls - np.zeros((size_p, 1)), # parameters - np.zeros((size_a, 1)), # algebraic_states - np.zeros((size_d, 1)), # numerical_timeseries - **custom_plot.parameters, # parameters - ) - - else: - raise NotImplementedError(f"Parameter {key} is not implemented in the serialization") - return cls( - function=_function, plot_type=custom_plot.type, - phase_mappings=( - None - if custom_plot.phase_mappings is None - else BiMappingSerializable.from_bimapping(custom_plot.phase_mappings) - ), + phase_mappings=BiMappingSerializable.from_bimapping(custom_plot.phase_mappings), legend=custom_plot.legend, combine_to=custom_plot.combine_to, color=custom_plot.color, @@ -300,45 +257,31 @@ def from_custom_plot(cls, custom_plot): label=custom_plot.label, compute_derivative=custom_plot.compute_derivative, integration_rule=custom_plot.integration_rule, - parameters=parameters, 
all_variables_in_one_subplot=custom_plot.all_variables_in_one_subplot, ) def serialize(self): return { - "function": None if self._function is None else np.array(self._function)[:, 0].tolist(), "type": self.type.value, - "phase_mappings": None if self.phase_mappings is None else self.phase_mappings.serialize(), + "phase_mappings": self.phase_mappings.serialize(), "legend": self.legend, "combine_to": self.combine_to, "color": self.color, "linestyle": self.linestyle, "ylim": self.ylim, "bounds": None if self.bounds is None else self.bounds.serialize(), - "node_idx": self.node_idx, + "node_idx": list(self.node_idx), "label": self.label, "compute_derivative": self.compute_derivative, "integration_rule": self.integration_rule.value, - "parameters": {key: param.serialize() for key, param in self.parameters.items()}, "all_variables_in_one_subplot": self.all_variables_in_one_subplot, } @classmethod def deserialize(cls, data): - - parameters = {} - for key in data["parameters"].keys(): - if key == "penalty": - parameters[key] = PenaltySerializable.deserialize(data["parameters"][key]) - else: - raise NotImplementedError(f"Parameter {key} is not implemented in the serialization") - return cls( - function=None if data["function"] is None else DM(data["function"]), plot_type=PlotType(data["type"]), - phase_mappings=( - None if data["phase_mappings"] is None else BiMappingSerializable.deserialize(data["phase_mappings"]) - ), + phase_mappings=BiMappingSerializable.deserialize(data["phase_mappings"]), legend=data["legend"], combine_to=data["combine_to"], color=data["color"], @@ -349,7 +292,6 @@ def deserialize(cls, data): label=data["label"], compute_derivative=data["compute_derivative"], integration_rule=QuadratureRule(data["integration_rule"]), - parameters=parameters, all_variables_in_one_subplot=data["all_variables_in_one_subplot"], ) @@ -438,6 +380,59 @@ def deserialize(cls, data): ) +class SaveIterationsInfoSerializable: + path_to_results: str + result_file_name: str | list[str] + nb_iter_save: int + current_iter: int + f_list: list[int] + + def __init__( + self, path_to_results: str, result_file_name: str, nb_iter_save: int, current_iter: int, f_list: list[int] + ): + self.path_to_results = path_to_results + self.result_file_name = result_file_name + self.nb_iter_save = nb_iter_save + self.current_iter = current_iter + self.f_list = f_list + + @classmethod + def from_save_iterations_info(cls, save_iterations_info): + from .ipopt_output_plot import SaveIterationsInfo + + save_iterations_info: SaveIterationsInfo = save_iterations_info + + if save_iterations_info is None: + return None + + return cls( + path_to_results=save_iterations_info.path_to_results, + result_file_name=save_iterations_info.result_file_name, + nb_iter_save=save_iterations_info.nb_iter_save, + current_iter=save_iterations_info.current_iter, + f_list=save_iterations_info.f_list, + ) + + def serialize(self): + return { + "path_to_results": self.path_to_results, + "result_file_name": self.result_file_name, + "nb_iter_save": self.nb_iter_save, + "current_iter": self.current_iter, + "f_list": self.f_list, + } + + @classmethod + def deserialize(cls, data): + return cls( + path_to_results=data["path_to_results"], + result_file_name=data["result_file_name"], + nb_iter_save=data["nb_iter_save"], + current_iter=data["current_iter"], + f_list=data["f_list"], + ) + + class NlpSerializable: ns: int phase_idx: int @@ -483,8 +478,6 @@ def __init__( def from_nlp(cls, nlp): from ..optimization.non_linear_program import NonLinearProgram - nlp: 
NonLinearProgram = nlp - return cls( ns=nlp.ns, phase_idx=nlp.phase_idx, @@ -531,59 +524,6 @@ def deserialize(cls, data): ) -class SaveIterationsInfoSerializable: - path_to_results: str - result_file_name: str | list[str] - nb_iter_save: int - current_iter: int - f_list: list[int] - - def __init__( - self, path_to_results: str, result_file_name: str, nb_iter_save: int, current_iter: int, f_list: list[int] - ): - self.path_to_results = path_to_results - self.result_file_name = result_file_name - self.nb_iter_save = nb_iter_save - self.current_iter = current_iter - self.f_list = f_list - - @classmethod - def from_save_iterations_info(cls, save_iterations_info): - from .ipopt_output_plot import SaveIterationsInfo - - save_iterations_info: SaveIterationsInfo = save_iterations_info - - if save_iterations_info is None: - return None - - return cls( - path_to_results=save_iterations_info.path_to_results, - result_file_name=save_iterations_info.result_file_name, - nb_iter_save=save_iterations_info.nb_iter_save, - current_iter=save_iterations_info.current_iter, - f_list=save_iterations_info.f_list, - ) - - def serialize(self): - return { - "path_to_results": self.path_to_results, - "result_file_name": self.result_file_name, - "nb_iter_save": self.nb_iter_save, - "current_iter": self.current_iter, - "f_list": self.f_list, - } - - @classmethod - def deserialize(cls, data): - return cls( - path_to_results=data["path_to_results"], - result_file_name=data["result_file_name"], - nb_iter_save=data["nb_iter_save"], - current_iter=data["current_iter"], - f_list=data["f_list"], - ) - - class OcpSerializable: n_phases: int nlp: list[NlpSerializable] @@ -617,6 +557,7 @@ def from_ocp(cls, ocp): from ..optimization.optimal_control_program import OptimalControlProgram ocp: OptimalControlProgram = ocp + ocp.finalize_plot_phase_mappings() return cls( n_phases=ocp.n_phases, @@ -655,3 +596,16 @@ def deserialize(cls, data): else SaveIterationsInfoSerializable.deserialize(data["save_ipopt_iterations_info"]) ), ) + + def finalize_plot_phase_mappings(self): + """ + This method can't be actually called from the serialized version, but we still can check if the work is done + """ + + for nlp in self.nlp: + if not nlp.plot: + continue + + for key in nlp.plot: + if nlp.plot[key].phase_mappings is None: + raise RuntimeError("The phase mapping should be set on client side") diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index a0aa07afa..c37bbcb31 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -32,45 +32,17 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): if show_options is None: show_options = {} - online_optim: OnlineOptim = interface.opts.online_optim - if online_optim == OnlineOptim.DEFAULT: - if platform == "linux": - online_optim = OnlineOptim.MULTIPROCESS - elif platform == "win32": - online_optim = OnlineOptim.MULTIPROCESS_SERVER - else: - online_optim = None - + online_optim = interface.opts.online_optim.get_default() if online_optim == OnlineOptim.MULTIPROCESS: - if platform != "linux": - raise RuntimeError( - "Online OnlineOptim.MULTIPROCESS is not supported on Windows or MacOS. 
" - "You can use online_optim=OnlineOptim.MULTIPROCESS_SERVER to the Solver declaration on Windows though" - ) - interface.options_common["iteration_callback"] = OnlineCallbackMultiprocess(ocp, show_options=show_options) - elif online_optim in (OnlineOptim.SERVER, OnlineOptim.MULTIPROCESS_SERVER): - host = None - if "host" in show_options: - host = show_options["host"] - del show_options["host"] - - port = None - if "port" in show_options: - port = show_options["port"] - del show_options["port"] - - if online_optim == OnlineOptim.SERVER: - class_to_instantiate = OnlineCallbackServer - elif online_optim == OnlineOptim.MULTIPROCESS_SERVER: - class_to_instantiate = OnlineCallbackMultiprocessServer - else: - raise NotImplementedError(f"show_options['type']={online_optim} is not implemented yet") - - interface.options_common["iteration_callback"] = class_to_instantiate( - ocp, show_options=show_options, host=host, port=port - ) + to_call = OnlineCallbackMultiprocess + elif online_optim == OnlineOptim.SERVER: + to_call = OnlineCallbackServer + elif online_optim == OnlineOptim.MULTIPROCESS_SERVER: + to_call = OnlineCallbackMultiprocessServer else: - raise NotImplementedError(f"show_options['type']={online_optim} is not implemented yet") + raise ValueError(f"online_optim {online_optim} is not implemented yet") + + interface.options_common["iteration_callback"] = to_call(ocp, **show_options) def generic_solve(interface, expand_during_shake_tree=False) -> dict: diff --git a/bioptim/misc/enums.py b/bioptim/misc/enums.py index eaa396af7..fd93b6d89 100644 --- a/bioptim/misc/enums.py +++ b/bioptim/misc/enums.py @@ -1,4 +1,5 @@ from enum import Enum, IntEnum, auto +import platform class PhaseDynamics(Enum): @@ -111,6 +112,17 @@ class OnlineOptim(Enum): SERVER = auto() MULTIPROCESS_SERVER = auto() + def get_default(self): + if self != OnlineOptim.DEFAULT: + return self + + if platform.system() == "Linux": + return OnlineOptim.MULTIPROCESS + elif platform.system() == "Windows": + return OnlineOptim.MULTIPROCESS_SERVER + else: + return None + class ControlType(Enum): """ diff --git a/bioptim/optimization/optimal_control_program.py b/bioptim/optimization/optimal_control_program.py index b72acb03d..b1a006d9f 100644 --- a/bioptim/optimization/optimal_control_program.py +++ b/bioptim/optimization/optimal_control_program.py @@ -689,6 +689,73 @@ def _finalize_penalties( self.update_parameter_objectives(parameter_objectives) return + def finalize_plot_phase_mappings(self): + """ + Finalize the plot phase mappings (if not already done) + + Parameters + ---------- + n_phases: int + The number of phases + """ + + for nlp in self.nlp: + if not nlp.plot: + return + + for key in nlp.plot: + if isinstance(nlp.plot[key], tuple): + nlp.plot[key] = nlp.plot[key][0] + + # This is the point where we can safely define node_idx of the plot + if nlp.plot[key].node_idx is None: + nlp.plot[key].node_idx = range(nlp.n_states_nodes) + + # If the number of subplots is not known, compute it + if nlp.plot[key].phase_mappings is None: + node_index = nlp.plot[key].node_idx[0] + nlp.states.node_index = node_index + nlp.states_dot.node_index = node_index + nlp.controls.node_index = node_index + nlp.algebraic_states.node_index = node_index + + # If multi-node penalties = None, stays zero + size_x = nlp.states.shape + size_u = nlp.controls.shape + size_p = nlp.parameters.shape + size_a = nlp.algebraic_states.shape + size_d = nlp.numerical_timeseries.shape + if "penalty" in nlp.plot[key].parameters: + penalty = 
nlp.plot[key].parameters["penalty"]
+
+                        # As stated in penalty_option, the last controller is always supposed to be the right one
+                        casadi_function = (
+                            penalty.function[0] if penalty.function[0] is not None else penalty.function[-1]
+                        )
+                        if casadi_function is not None:
+                            size_x = casadi_function.size_in("x")[0]
+                            size_u = casadi_function.size_in("u")[0]
+                            size_p = casadi_function.size_in("p")[0]
+                            size_a = casadi_function.size_in("a")[0]
+                            size_d = casadi_function.size_in("d")[0]
+
+                    size = (
+                        nlp.plot[key]
+                        .function(
+                            0,  # t0
+                            np.zeros(self.n_phases),  # phases_dt
+                            node_index,  # node_idx
+                            np.zeros((size_x, 1)),  # states
+                            np.zeros((size_u, 1)),  # controls
+                            np.zeros((size_p, 1)),  # parameters
+                            np.zeros((size_a, 1)),  # algebraic_states
+                            np.zeros((size_d, 1)),  # numerical_timeseries
+                            **nlp.plot[key].parameters,  # parameters
+                        )
+                        .shape[0]
+                    )
+                    nlp.plot[key].phase_mappings = BiMapping(to_first=range(size), to_second=range(size))
+
     @property
     def variables_vector(self):
         return OptimizationVectorHelper.vector(self)
diff --git a/resources/plotting_server.py b/resources/plotting_server.py
index 3ee6eb2da..b709a0cf8 100644
--- a/resources/plotting_server.py
+++ b/resources/plotting_server.py
@@ -1,11 +1,14 @@
 """
-This file is an example of how to run a bioptim Online plotting server. That said, this is usually not the way to run a
-bioptim server as it is easier to run it as an automatic multiprocess (default). This is achieved by setting
-`show_options={"type": OnlineOptim.SERVER, "as_multiprocess": True}` in the solver options.
-If set to False, then the plotting server is mandatory.
+This file is an example of how to run a bioptim online plotting server. Except on macOS, this is usually not the way
+to run a bioptim server, as it is easier to let bioptim start it automatically in a separate process. This is achieved
+by setting `Solver.IPOPT(online_optim=OnlineOptim.MULTIPROCESS_SERVER)`.
+If OnlineOptim.SERVER is used instead, running the plotting server manually is mandatory.
 
 Since the server runs usings sockets, it is possible to run the server on a different machine than the one running the
 optimization. This is useful when the optimization is run on a cluster and the plotting server is run on a local
 machine.
+
+On macOS, this server is required since the automatic multiprocess start is not supported.
+One can simply run this script in another terminal to access the online graphs.
 """
 from bioptim import PlottingServer
diff --git a/tests/shard5/test_plot_server.py b/tests/shard5/test_plot_server.py
index af774b260..442726ad5 100644
--- a/tests/shard5/test_plot_server.py
+++ b/tests/shard5/test_plot_server.py
@@ -2,8 +2,9 @@
 from bioptim.gui.online_callback_server import _serialize_xydata, _deserialize_xydata
 from bioptim.gui.plot import PlotOcp
+from bioptim.gui.online_callback_server import _ResponseHeader
 from bioptim.optimization.optimization_vector import OptimizationVectorHelper
-from casadi import DM, Function
+from casadi import DM
 import numpy as np


@@ -40,3 +41,24 @@ def test_serialize_deserialize():
         else:
             for y_phase, deserialized_y_phase in zip(y_variable, deserialized_y_variable):
                 assert np.allclose(y_phase, deserialized_y_phase)
+
+
+def test_response_header():
+    # Make sure all the responses have the same length
+    response_len = _ResponseHeader.response_len()
+    for response in _ResponseHeader:
+        assert len(response) == response_len
+        # Make sure encoding provides a constant length
+        assert len(response.encode()) == response_len
+
+    # Make sure equality works
+    assert _ResponseHeader.OK == _ResponseHeader.OK
+    assert _ResponseHeader.OK.value == _ResponseHeader.OK
+    assert _ResponseHeader.OK.encode().decode() == _ResponseHeader.OK
+    assert _ResponseHeader.OK == _ResponseHeader.OK.encode().decode()
+    assert not (_ResponseHeader.OK != _ResponseHeader.OK)
+    assert not (_ResponseHeader.OK.encode().decode() != _ResponseHeader.OK)
+    assert not (_ResponseHeader.OK != _ResponseHeader.OK.encode().decode())
+    assert not (_ResponseHeader.OK.value == _ResponseHeader.OK.encode().decode())
+    assert _ResponseHeader.OK != _ResponseHeader.NOK
+    assert _ResponseHeader.NOK == _ResponseHeader.NOK
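
Usage sketch for the workflow introduced above: start the plotting server in one terminal (mandatory on macOS, optional on Linux and Windows where OnlineOptim.MULTIPROCESS_SERVER can spawn it automatically), then solve in another terminal with online_optim pointing to the server. The host and port defaults come from online_callback_server.py; the `prepare_ocp` import path, model path, and shooting settings are assumptions borrowed from the pendulum example, not code from this patch.

# Terminal 1: start the plotting server by hand (this call blocks while it serves the plots)
from bioptim import PlottingServer

PlottingServer()  # defaults to host "localhost" and port 3050

# Terminal 2: build and solve the OCP while streaming data to the server
from bioptim import OnlineOptim, Solver
from bioptim.examples.getting_started.pendulum import prepare_ocp  # assumed example import path

# Signature and model path assumed from the pendulum example shipped with bioptim
ocp = prepare_ocp(biorbd_model_path="models/pendulum.bioMod", final_time=1, n_shooting=40)

# OnlineOptim.SERVER connects to an already running server;
# OnlineOptim.MULTIPROCESS_SERVER starts one automatically (Linux and Windows only)
sol = ocp.solve(Solver.IPOPT(online_optim=OnlineOptim.SERVER))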