Skip to content

Commit

Permalink
Fix web ui log (#2924)
Browse files Browse the repository at this point in the history
  • Loading branch information
tastelikefeet authored Jan 15, 2025
1 parent a9764ce commit 5c1f043
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 22 deletions.
2 changes: 1 addition & 1 deletion swift/ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def run(self):
LLMInfer.set_lang(lang)
LLMExport.set_lang(lang)
LLMEval.set_lang(lang)
with gr.Blocks(title='SWIFT WebUI') as app:
with gr.Blocks(title='SWIFT WebUI', theme=gr.themes.Base()) as app:
try:
_version = swift.__version__
except AttributeError:
Expand Down
4 changes: 1 addition & 3 deletions swift/ui/llm_eval/llm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):

base_tab.element('running_tasks').change(
partial(EvalRuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
list(base_tab.valid_elements().values()) + [cls.element('log')],
cancels=EvalRuntime.log_event)
list(base_tab.valid_elements().values()) + [cls.element('log')])
EvalRuntime.element('kill_task').click(
EvalRuntime.kill_task,
[EvalRuntime.element('running_tasks')],
[EvalRuntime.element('running_tasks')] + [EvalRuntime.element('log')],
cancels=[EvalRuntime.log_event],
)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions swift/ui/llm_eval/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
gr.Button(elem_id='refresh_tasks', scale=1, variant='primary')
gr.Button(elem_id='show_log', scale=1, variant='primary')
gr.Button(elem_id='stop_show_log', scale=1)
gr.Button(elem_id='kill_task', scale=1)
gr.Button(elem_id='kill_task', scale=1, size='lg')
with gr.Row():
gr.Textbox(elem_id='log', lines=6, visible=False)

Expand All @@ -99,7 +99,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
cls.log_event = base_tab.element('show_log').click(cls.update_log, [], [cls.element('log')]).then(
cls.wait, [base_tab.element('running_tasks')], [cls.element('log')], **concurrency_limit)

base_tab.element('stop_show_log').click(lambda: None, cancels=cls.log_event)
base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], [])

base_tab.element('refresh_tasks').click(
cls.refresh_tasks,
Expand Down
4 changes: 1 addition & 3 deletions swift/ui/llm_export/llm_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):

base_tab.element('running_tasks').change(
partial(ExportRuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
list(base_tab.valid_elements().values()) + [cls.element('log')],
cancels=ExportRuntime.log_event)
list(base_tab.valid_elements().values()) + [cls.element('log')])
ExportRuntime.element('kill_task').click(
ExportRuntime.kill_task,
[ExportRuntime.element('running_tasks')],
[ExportRuntime.element('running_tasks')] + [ExportRuntime.element('log')],
cancels=[ExportRuntime.log_event],
)

@classmethod
Expand Down
4 changes: 1 addition & 3 deletions swift/ui/llm_infer/llm_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):

base_tab.element('running_tasks').change(
partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
list(cls.valid_elements().values()) + [cls.element('log')],
cancels=Runtime.log_event)
list(cls.valid_elements().values()) + [cls.element('log')])
Runtime.element('kill_task').click(
Runtime.kill_task,
[Runtime.element('running_tasks')],
[Runtime.element('running_tasks')] + [Runtime.element('log')],
cancels=[Runtime.log_event],
)

@classmethod
Expand Down
24 changes: 20 additions & 4 deletions swift/ui/llm_infer/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Runtime(BaseUI):

cmd = 'deploy'

log_event = None
log_event = {}

locale_dict = {
'runtime_tab': {
Expand Down Expand Up @@ -106,17 +106,26 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
concurrency_limit = {}
if version.parse(gr.__version__) >= version.parse('4.0.0'):
concurrency_limit = {'concurrency_limit': 5}
cls.log_event = base_tab.element('show_log').click(cls.update_log, [], [cls.element('log')]).then(
cls.wait, [base_tab.element('running_tasks')], [cls.element('log')], **concurrency_limit)
base_tab.element('show_log').click(cls.update_log, [],
[cls.element('log')]).then(cls.wait,
[base_tab.element('running_tasks')],
[cls.element('log')], **concurrency_limit)

base_tab.element('stop_show_log').click(lambda: None, cancels=cls.log_event)
base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], [])

base_tab.element('refresh_tasks').click(
cls.refresh_tasks,
[base_tab.element('running_tasks')],
[base_tab.element('running_tasks')],
)

@classmethod
def break_log_event(cls, task):
if not task:
return
pid, all_args = cls.parse_info_from_cmdline(task)
cls.log_event[all_args['log_file']] = True

@classmethod
def update_log(cls):
return gr.update(visible=True)
Expand All @@ -127,6 +136,7 @@ def wait(cls, task):
return [None]
_, args = cls.parse_info_from_cmdline(task)
log_file = args['log_file']
cls.log_event[log_file] = False
offset = 0
latest_data = ''
lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
Expand All @@ -145,6 +155,10 @@ def wait(cls, task):
if fail_cnt > 50:
break

if cls.log_event.get(log_file, False):
cls.log_event[log_file] = False
break

if '\n' not in latest_data:
continue
latest_lines = latest_data.split('\n')
Expand Down Expand Up @@ -241,6 +255,7 @@ def kill_task(cls, task):
else:
os.system(f'pkill -9 -f {log_file}')
time.sleep(1)
cls.break_log_event(task)
return [cls.refresh_tasks()] + [gr.update(value=None)]

@classmethod
Expand All @@ -267,4 +282,5 @@ def task_changed(cls, task, base_tab):
ret.append(gr.update(value=arg))
else:
ret.append(gr.update())
cls.break_log_event(task)
return ret + [gr.update(value=None)]
4 changes: 1 addition & 3 deletions swift/ui/llm_train/llm_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):

base_tab.element('running_tasks').change(
partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
list(base_tab.valid_elements().values()) + [cls.element('log')] + Runtime.all_plots,
cancels=Runtime.log_event)
list(base_tab.valid_elements().values()) + [cls.element('log')] + Runtime.all_plots)
Runtime.element('kill_task').click(
Runtime.kill_task,
[Runtime.element('running_tasks')],
[Runtime.element('running_tasks')] + [Runtime.element('log')] + Runtime.all_plots,
cancels=[Runtime.log_event],
).then(Runtime.reset, [], [Runtime.element('logging_dir')] + [Hyper.element('output_dir')])

@classmethod
Expand Down
20 changes: 17 additions & 3 deletions swift/ui/llm_train/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Runtime(BaseUI):

all_plots = None

log_event = None
log_event = {}

sft_plot = [
{
Expand Down Expand Up @@ -253,13 +253,13 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
concurrency_limit = {}
if version.parse(gr.__version__) >= version.parse('4.0.0'):
concurrency_limit = {'concurrency_limit': 5}
cls.log_event = base_tab.element('show_log').click(
base_tab.element('show_log').click(
Runtime.update_log, [base_tab.element('running_tasks')], [cls.element('log')] + cls.all_plots).then(
Runtime.wait, [base_tab.element('logging_dir'),
base_tab.element('running_tasks')], [cls.element('log')] + cls.all_plots,
**concurrency_limit)

base_tab.element('stop_show_log').click(lambda: None, cancels=cls.log_event)
base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], [])

base_tab.element('start_tb').click(
Runtime.start_tb,
Expand Down Expand Up @@ -315,6 +315,7 @@ def wait(cls, logging_dir, task):
if not logging_dir:
return [None] + Runtime.plot(task)
log_file = os.path.join(logging_dir, 'run.log')
cls.log_event[logging_dir] = False
offset = 0
latest_data = ''
lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
Expand All @@ -333,6 +334,10 @@ def wait(cls, logging_dir, task):
if fail_cnt > 50:
break

if cls.log_event.get(logging_dir, False):
cls.log_event[logging_dir] = False
break

if '\n' not in latest_data:
continue
latest_lines = latest_data.split('\n')
Expand All @@ -355,6 +360,13 @@ def wait(cls, logging_dir, task):
except IOError:
pass

@classmethod
def break_log_event(cls, task):
if not task:
return
pid, all_args = Runtime.parse_info_from_cmdline(task)
cls.log_event[all_args['logging_dir']] = True

@classmethod
def show_log(cls, logging_dir):
webbrowser.open('file://' + os.path.join(logging_dir, 'run.log'), new=2)
Expand Down Expand Up @@ -472,6 +484,7 @@ def kill_task(task):
else:
os.system(f'pkill -9 -f {output_dir}')
time.sleep(1)
Runtime.break_log_event(task)
return [Runtime.refresh_tasks()] + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1)

@staticmethod
Expand All @@ -495,6 +508,7 @@ def task_changed(task, base_tab):
ret.append(gr.update(value=arg))
else:
ret.append(gr.update())
Runtime.break_log_event(task)
return ret + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1)

@staticmethod
Expand Down

0 comments on commit 5c1f043

Please sign in to comment.