diff --git a/pynest/nest/server/hl_api_server.py b/pynest/nest/server/hl_api_server.py index ff0bd2f361..eb75763757 100644 --- a/pynest/nest/server/hl_api_server.py +++ b/pynest/nest/server/hl_api_server.py @@ -36,8 +36,7 @@ from flask import Flask, jsonify, request from flask.logging import default_handler from flask_cors import CORS -from werkzeug.exceptions import abort -from werkzeug.wrappers import Response +from nest.lib.hl_api_exceptions import NESTError # This ensures that the logging information shows up in the console running the server, # even when Flask's event loop is running. @@ -189,41 +188,36 @@ def index(): def do_exec(args, kwargs): - try: - source_code = kwargs.get("source", "") - source_cleaned = clean_code(source_code) - - locals_ = dict() - response = dict() - if RESTRICTION_DISABLED: - with Capturing() as stdout: - globals_ = globals().copy() - globals_.update(get_modules_from_env()) - exec(source_cleaned, globals_, locals_) - if len(stdout) > 0: - response["stdout"] = "\n".join(stdout) - else: - code = RestrictedPython.compile_restricted(source_cleaned, "", "exec") # noqa - globals_ = get_restricted_globals() + source_code = kwargs.get("source", "") + source_cleaned = clean_code(source_code) + + locals_ = dict() + response = dict() + if RESTRICTION_DISABLED: + with Capturing() as stdout: + globals_ = globals().copy() globals_.update(get_modules_from_env()) - exec(code, globals_, locals_) - if "_print" in locals_: - response["stdout"] = "".join(locals_["_print"].txt) - - if "return" in kwargs: - if isinstance(kwargs["return"], list): - data = dict() - for variable in kwargs["return"]: - data[variable] = locals_.get(variable, None) - else: - data = locals_.get(kwargs["return"], None) - response["data"] = nest.serialize_data(data) - return response + get_or_error(exec)(source_cleaned, globals_, locals_) + if len(stdout) > 0: + response["stdout"] = "\n".join(stdout) + else: + code = RestrictedPython.compile_restricted(source_cleaned, "", "exec") # noqa + globals_ = get_restricted_globals() + globals_.update(get_modules_from_env()) + get_or_error(exec)(code, globals_, locals_) + if "_print" in locals_: + response["stdout"] = "".join(locals_["_print"].txt) + + if "return" in kwargs: + if isinstance(kwargs["return"], list): + data = dict() + for variable in kwargs["return"]: + data[variable] = locals_.get(variable, None) + else: + data = locals_.get(kwargs["return"], None) - except Exception as e: - for line in traceback.format_exception(*sys.exc_info()): - print(line, flush=True) - flask.abort(EXCEPTION_ERROR_STATUS, str(e)) + response["data"] = get_or_error(nest.serialize_data)(data) + return response def log(call_name, msg): @@ -336,10 +330,43 @@ def __exit__(self, *args): sys.stdout = self._stdout +class ErrorHandler(Exception): + status_code = 400 + lineno = -1 + + def __init__(self, message: str, lineno: int = None, status_code: int = None, payload=None): + super().__init__() + self.message = message + if status_code is not None: + self.status_code = status_code + if lineno is not None: + self.lineno = lineno + self.payload = payload + + def to_dict(self): + rv = dict(self.payload or ()) + rv["message"] = self.message + if self.lineno != -1: + rv["lineNumber"] = self.lineno + return rv + + +# https://flask.palletsprojects.com/en/2.3.x/errorhandling/ +@app.errorhandler(ErrorHandler) +def error_handler(e): + return jsonify(e.to_dict()), e.status_code + + +# It comments lines starting with 'import' or 'from' otherwise the line number of error would be wrong. def clean_code(source): codes = source.split("\n") - code_cleaned = filter(lambda code: not (code.startswith("import") or code.startswith("from")), codes) # noqa - return "\n".join(code_cleaned) + codes_cleaned = [] # noqa + for code in codes: + if code.startswith("import") or code.startswith("from"): + codes_cleaned.append("#" + code) + else: + codes_cleaned.append(code) + return "\n".join(codes_cleaned) def get_arguments(request): @@ -368,6 +395,16 @@ def get_arguments(request): return list(args), kwargs +def get_lineno(err, tb_idx): + lineno = -1 + if hasattr(err, "lineno") and err.lineno is not None: + lineno = err.lineno + else: + tb = sys.exc_info()[2] + lineno = traceback.extract_tb(tb)[tb_idx][1] + return lineno + + def get_modules_from_env(): """Get modules from environment variable NEST_SERVER_MODULES. @@ -397,13 +434,34 @@ def get_modules_from_env(): def get_or_error(func): """Wrapper to get data and status.""" - def func_wrapper(call, args, kwargs): + def func_wrapper(call, *args, **kwargs): try: - return func(call, args, kwargs) - except Exception as e: - for line in traceback.format_exception(*sys.exc_info()): - print(line, flush=True) - flask.abort(EXCEPTION_ERROR_STATUS, str(e)) + return func(call, *args, **kwargs) + + except NESTError as err: + error_class = err.errorname + " (NESTError)" + detail = err.errormessage + lineno = get_lineno(err, 1) + + except (KeyError, SyntaxError, TypeError, ValueError) as err: + error_class = err.__class__.__name__ + detail = err.args[0] + lineno = get_lineno(err, 1) + + except Exception as err: + error_class = err.__class__.__name__ + detail = err.args[0] + lineno = get_lineno(err, -1) + + for line in traceback.format_exception(*sys.exc_info()): + print(line, flush=True) + + if lineno == -1: + message = "%s: %s" % (error_class, detail) + else: + message = "%s at line %d: %s" % (error_class, lineno, detail) + + raise ErrorHandler(message, lineno) return func_wrapper