From 5c1f043b24ad57a236d22e8a7ced5ff035794ea8 Mon Sep 17 00:00:00 2001 From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> Date: Wed, 15 Jan 2025 17:36:51 +0800 Subject: [PATCH] Fix web ui log (#2924) --- swift/ui/app.py | 2 +- swift/ui/llm_eval/llm_eval.py | 4 +--- swift/ui/llm_eval/runtime.py | 4 ++-- swift/ui/llm_export/llm_export.py | 4 +--- swift/ui/llm_infer/llm_infer.py | 4 +--- swift/ui/llm_infer/runtime.py | 24 ++++++++++++++++++++---- swift/ui/llm_train/llm_train.py | 4 +--- swift/ui/llm_train/runtime.py | 20 +++++++++++++++++--- 8 files changed, 44 insertions(+), 22 deletions(-) diff --git a/swift/ui/app.py b/swift/ui/app.py index 4fc31e8a69..1b09cbb1b8 100644 --- a/swift/ui/app.py +++ b/swift/ui/app.py @@ -55,7 +55,7 @@ def run(self): LLMInfer.set_lang(lang) LLMExport.set_lang(lang) LLMEval.set_lang(lang) - with gr.Blocks(title='SWIFT WebUI') as app: + with gr.Blocks(title='SWIFT WebUI', theme=gr.themes.Base()) as app: try: _version = swift.__version__ except AttributeError: diff --git a/swift/ui/llm_eval/llm_eval.py b/swift/ui/llm_eval/llm_eval.py index 760da5cf1f..88c703a1a5 100644 --- a/swift/ui/llm_eval/llm_eval.py +++ b/swift/ui/llm_eval/llm_eval.py @@ -93,13 +93,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): base_tab.element('running_tasks').change( partial(EvalRuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')], - list(base_tab.valid_elements().values()) + [cls.element('log')], - cancels=EvalRuntime.log_event) + list(base_tab.valid_elements().values()) + [cls.element('log')]) EvalRuntime.element('kill_task').click( EvalRuntime.kill_task, [EvalRuntime.element('running_tasks')], [EvalRuntime.element('running_tasks')] + [EvalRuntime.element('log')], - cancels=[EvalRuntime.log_event], ) @classmethod diff --git a/swift/ui/llm_eval/runtime.py b/swift/ui/llm_eval/runtime.py index 7b91033b91..03c90b81b0 100644 --- a/swift/ui/llm_eval/runtime.py +++ b/swift/ui/llm_eval/runtime.py @@ -89,7 +89,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): gr.Button(elem_id='refresh_tasks', scale=1, variant='primary') gr.Button(elem_id='show_log', scale=1, variant='primary') gr.Button(elem_id='stop_show_log', scale=1) - gr.Button(elem_id='kill_task', scale=1) + gr.Button(elem_id='kill_task', scale=1, size='lg') with gr.Row(): gr.Textbox(elem_id='log', lines=6, visible=False) @@ -99,7 +99,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): cls.log_event = base_tab.element('show_log').click(cls.update_log, [], [cls.element('log')]).then( cls.wait, [base_tab.element('running_tasks')], [cls.element('log')], **concurrency_limit) - base_tab.element('stop_show_log').click(lambda: None, cancels=cls.log_event) + base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], []) base_tab.element('refresh_tasks').click( cls.refresh_tasks, diff --git a/swift/ui/llm_export/llm_export.py b/swift/ui/llm_export/llm_export.py index 4d6cbfe9fd..4a22627dd6 100644 --- a/swift/ui/llm_export/llm_export.py +++ b/swift/ui/llm_export/llm_export.py @@ -91,13 +91,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): base_tab.element('running_tasks').change( partial(ExportRuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')], - list(base_tab.valid_elements().values()) + [cls.element('log')], - cancels=ExportRuntime.log_event) + list(base_tab.valid_elements().values()) + [cls.element('log')]) ExportRuntime.element('kill_task').click( ExportRuntime.kill_task, [ExportRuntime.element('running_tasks')], [ExportRuntime.element('running_tasks')] + [ExportRuntime.element('log')], - cancels=[ExportRuntime.log_event], ) @classmethod diff --git a/swift/ui/llm_infer/llm_infer.py b/swift/ui/llm_infer/llm_infer.py index 1a5e9548ca..141533424c 100644 --- a/swift/ui/llm_infer/llm_infer.py +++ b/swift/ui/llm_infer/llm_infer.py @@ -179,13 +179,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): base_tab.element('running_tasks').change( partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')], - list(cls.valid_elements().values()) + [cls.element('log')], - cancels=Runtime.log_event) + list(cls.valid_elements().values()) + [cls.element('log')]) Runtime.element('kill_task').click( Runtime.kill_task, [Runtime.element('running_tasks')], [Runtime.element('running_tasks')] + [Runtime.element('log')], - cancels=[Runtime.log_event], ) @classmethod diff --git a/swift/ui/llm_infer/runtime.py b/swift/ui/llm_infer/runtime.py index e8b9fcfefe..29459c8d46 100644 --- a/swift/ui/llm_infer/runtime.py +++ b/swift/ui/llm_infer/runtime.py @@ -25,7 +25,7 @@ class Runtime(BaseUI): cmd = 'deploy' - log_event = None + log_event = {} locale_dict = { 'runtime_tab': { @@ -106,10 +106,12 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): concurrency_limit = {} if version.parse(gr.__version__) >= version.parse('4.0.0'): concurrency_limit = {'concurrency_limit': 5} - cls.log_event = base_tab.element('show_log').click(cls.update_log, [], [cls.element('log')]).then( - cls.wait, [base_tab.element('running_tasks')], [cls.element('log')], **concurrency_limit) + base_tab.element('show_log').click(cls.update_log, [], + [cls.element('log')]).then(cls.wait, + [base_tab.element('running_tasks')], + [cls.element('log')], **concurrency_limit) - base_tab.element('stop_show_log').click(lambda: None, cancels=cls.log_event) + base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], []) base_tab.element('refresh_tasks').click( cls.refresh_tasks, @@ -117,6 +119,13 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): [base_tab.element('running_tasks')], ) + @classmethod + def break_log_event(cls, task): + if not task: + return + pid, all_args = cls.parse_info_from_cmdline(task) + cls.log_event[all_args['log_file']] = True + @classmethod def update_log(cls): return gr.update(visible=True) @@ -127,6 +136,7 @@ def wait(cls, task): return [None] _, args = cls.parse_info_from_cmdline(task) log_file = args['log_file'] + cls.log_event[log_file] = False offset = 0 latest_data = '' lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50))) @@ -145,6 +155,10 @@ def wait(cls, task): if fail_cnt > 50: break + if cls.log_event.get(log_file, False): + cls.log_event[log_file] = False + break + if '\n' not in latest_data: continue latest_lines = latest_data.split('\n') @@ -241,6 +255,7 @@ def kill_task(cls, task): else: os.system(f'pkill -9 -f {log_file}') time.sleep(1) + cls.break_log_event(task) return [cls.refresh_tasks()] + [gr.update(value=None)] @classmethod @@ -267,4 +282,5 @@ def task_changed(cls, task, base_tab): ret.append(gr.update(value=arg)) else: ret.append(gr.update()) + cls.break_log_event(task) return ret + [gr.update(value=None)] diff --git a/swift/ui/llm_train/llm_train.py b/swift/ui/llm_train/llm_train.py index 706fe15766..34ec4e8b32 100644 --- a/swift/ui/llm_train/llm_train.py +++ b/swift/ui/llm_train/llm_train.py @@ -264,13 +264,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): base_tab.element('running_tasks').change( partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')], - list(base_tab.valid_elements().values()) + [cls.element('log')] + Runtime.all_plots, - cancels=Runtime.log_event) + list(base_tab.valid_elements().values()) + [cls.element('log')] + Runtime.all_plots) Runtime.element('kill_task').click( Runtime.kill_task, [Runtime.element('running_tasks')], [Runtime.element('running_tasks')] + [Runtime.element('log')] + Runtime.all_plots, - cancels=[Runtime.log_event], ).then(Runtime.reset, [], [Runtime.element('logging_dir')] + [Hyper.element('output_dir')]) @classmethod diff --git a/swift/ui/llm_train/runtime.py b/swift/ui/llm_train/runtime.py index 3e186d6b9d..71cf828fb1 100644 --- a/swift/ui/llm_train/runtime.py +++ b/swift/ui/llm_train/runtime.py @@ -30,7 +30,7 @@ class Runtime(BaseUI): all_plots = None - log_event = None + log_event = {} sft_plot = [ { @@ -253,13 +253,13 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): concurrency_limit = {} if version.parse(gr.__version__) >= version.parse('4.0.0'): concurrency_limit = {'concurrency_limit': 5} - cls.log_event = base_tab.element('show_log').click( + base_tab.element('show_log').click( Runtime.update_log, [base_tab.element('running_tasks')], [cls.element('log')] + cls.all_plots).then( Runtime.wait, [base_tab.element('logging_dir'), base_tab.element('running_tasks')], [cls.element('log')] + cls.all_plots, **concurrency_limit) - base_tab.element('stop_show_log').click(lambda: None, cancels=cls.log_event) + base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], []) base_tab.element('start_tb').click( Runtime.start_tb, @@ -315,6 +315,7 @@ def wait(cls, logging_dir, task): if not logging_dir: return [None] + Runtime.plot(task) log_file = os.path.join(logging_dir, 'run.log') + cls.log_event[logging_dir] = False offset = 0 latest_data = '' lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50))) @@ -333,6 +334,10 @@ def wait(cls, logging_dir, task): if fail_cnt > 50: break + if cls.log_event.get(logging_dir, False): + cls.log_event[logging_dir] = False + break + if '\n' not in latest_data: continue latest_lines = latest_data.split('\n') @@ -355,6 +360,13 @@ def wait(cls, logging_dir, task): except IOError: pass + @classmethod + def break_log_event(cls, task): + if not task: + return + pid, all_args = Runtime.parse_info_from_cmdline(task) + cls.log_event[all_args['logging_dir']] = True + @classmethod def show_log(cls, logging_dir): webbrowser.open('file://' + os.path.join(logging_dir, 'run.log'), new=2) @@ -472,6 +484,7 @@ def kill_task(task): else: os.system(f'pkill -9 -f {output_dir}') time.sleep(1) + Runtime.break_log_event(task) return [Runtime.refresh_tasks()] + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1) @staticmethod @@ -495,6 +508,7 @@ def task_changed(task, base_tab): ret.append(gr.update(value=arg)) else: ret.append(gr.update()) + Runtime.break_log_event(task) return ret + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1) @staticmethod