From 6de7678544c9a9d495a77158d7e5c1e82d3fec54 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 16 Jul 2023 00:56:03 +0800 Subject: [PATCH] update webui --- src/glmtuner/__init__.py | 2 +- src/glmtuner/extras/constants.py | 42 +++------------- src/glmtuner/hparams/model_args.py | 6 +-- src/glmtuner/webui/chat.py | 39 ++++++++------- src/glmtuner/webui/common.py | 45 ++++++++--------- src/glmtuner/webui/components/data.py | 3 +- src/glmtuner/webui/components/eval.py | 30 +++++++---- src/glmtuner/webui/components/infer.py | 25 +++++----- src/glmtuner/webui/components/model.py | 33 ++++++------ src/glmtuner/webui/components/sft.py | 39 +++++++-------- src/glmtuner/webui/interface.py | 11 ++-- src/glmtuner/webui/runner.py | 69 +++++++++++++------------- src/glmtuner/webui/utils.py | 12 ++--- 13 files changed, 168 insertions(+), 188 deletions(-) diff --git a/src/glmtuner/__init__.py b/src/glmtuner/__init__.py index d6960f6..de0171e 100644 --- a/src/glmtuner/__init__.py +++ b/src/glmtuner/__init__.py @@ -4,4 +4,4 @@ from glmtuner.webui import create_ui -__version__ = "0.1.0" +__version__ = "0.1.1" diff --git a/src/glmtuner/extras/constants.py b/src/glmtuner/extras/constants.py index c5d7c02..71d1de1 100644 --- a/src/glmtuner/extras/constants.py +++ b/src/glmtuner/extras/constants.py @@ -6,41 +6,11 @@ LAYERNORM_NAMES = ["layernorm"] -SUPPORTED_MODEL_LIST = [ - { - "name": "chatglm-6b", - "pretrained_model_name": "THUDM/chatglm-6b", - "local_model_path": None, - "provides": "ChatGLMLLMChain" +SUPPORTED_MODELS = { + "ChatGLM-6B": { + "hf_path": "THUDM/chatglm-6b" }, - { - "name": "chatglm2-6b", - "pretrained_model_name": "THUDM/chatglm2-6b", - "local_model_path": None, - "provides": "ChatGLMLLMChain" - }, - { - "name": "chatglm-6b-int8", - "pretrained_model_name": "THUDM/chatglm-6b-int8", - "local_model_path": None, - "provides": "ChatGLMLLMChain" - }, - { - "name": "chatglm2-6b-int8", - "pretrained_model_name": "THUDM/chatglm2-6b-int8", - "local_model_path": None, - "provides": "ChatGLMLLMChain" - }, - { - "name": "chatglm-6b-int4", - "pretrained_model_name": "THUDM/chatglm-6b-int4", - "local_model_path": None, - "provides": "ChatGLMLLMChain" - }, - { - "name": "chatglm2-6b-int4", - "pretrained_model_name": "THUDM/chatglm2-6b-int4", - "local_model_path": None, - "provides": "ChatGLMLLMChain" + "ChatGLM2-6B": { + "hf_path": "THUDM/chatglm2-6b" } -] \ No newline at end of file +} diff --git a/src/glmtuner/hparams/model_args.py b/src/glmtuner/hparams/model_args.py index 5c880ce..950420b 100644 --- a/src/glmtuner/hparams/model_args.py +++ b/src/glmtuner/hparams/model_args.py @@ -70,15 +70,11 @@ class ModelArguments: ) def __post_init__(self): - if self.checkpoint_dir == "": + if not self.checkpoint_dir: self.checkpoint_dir = None - # if base model is already quantization version, ignore quantization_bit config - if self.quantization_bit == "" or "int" in self.model_name_or_path: - self.quantization_bit = None if self.checkpoint_dir is not None: # support merging lora weights self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] if self.quantization_bit is not None: - self.quantization_bit = int(self.quantization_bit) assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." 
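Note on the constants and model_args changes: SUPPORTED_MODELS is now a dict keyed by display name, and ModelArguments.__post_init__ treats an empty checkpoint_dir as None and asserts that quantization_bit, when set, is 4 or 8. The same name-or-path lookup is repeated in webui/chat.py and webui/runner.py below. A minimal standalone sketch of that lookup (resolve_model is an illustrative helper, not part of the patch):

import os

# Mirrors src/glmtuner/extras/constants.py after this patch.
SUPPORTED_MODELS = {
    "ChatGLM-6B": {"hf_path": "THUDM/chatglm-6b"},
    "ChatGLM2-6B": {"hf_path": "THUDM/chatglm2-6b"},
}

def resolve_model(model_name: str, model_path: str) -> str:
    # Illustrative helper: prefer an existing local directory,
    # otherwise fall back to the Hugging Face Hub path of a known model.
    if model_path:
        if not os.path.isdir(model_path):
            raise ValueError("Cannot find model directory in local disk.")
        return model_path
    if model_name in SUPPORTED_MODELS:
        return SUPPORTED_MODELS[model_name]["hf_path"]
    raise ValueError("Invalid model.")

assert resolve_model("ChatGLM2-6B", "") == "THUDM/chatglm2-6b"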
diff --git a/src/glmtuner/webui/chat.py b/src/glmtuner/webui/chat.py index 6a160f9..d7b73af 100644 --- a/src/glmtuner/webui/chat.py +++ b/src/glmtuner/webui/chat.py @@ -2,6 +2,7 @@ from typing import List, Tuple from glmtuner.chat.stream_chat import ChatModel +from glmtuner.extras.constants import SUPPORTED_MODELS from glmtuner.extras.misc import torch_gc from glmtuner.hparams import GeneratingArguments from glmtuner.tuner import get_infer_args @@ -15,30 +16,36 @@ def __init__(self): self.tokenizer = None self.generating_args = GeneratingArguments() - def load_model(self, base_model: str, model_path: str, checkpoints: list, quantization_bit: str): + def load_model(self, model_name: str, model_path: str, checkpoints: list, quantization_bit: str): if self.model is not None: yield "You have loaded a model, please unload it first." return - if not base_model: + if not model_name: yield "Please select a model." return - if get_save_dir(base_model) and checkpoints: - checkpoint_dir = ",".join( - [os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints]) + if model_path: + if not os.path.isdir(model_path): + yield "Cannot find model directory in local disk." + return + model_name_or_path = model_path + elif model_name in SUPPORTED_MODELS: # TODO: use list in gr.State + model_name_or_path = SUPPORTED_MODELS[model_name]["hf_path"] + else: + yield "Invalid model." + return + + if checkpoints: + checkpoint_dir = ",".join([os.path.join(get_save_dir(model_name), checkpoint) for checkpoint in checkpoints]) else: checkpoint_dir = None yield "Loading model..." - if model_path: - model_name_or_path = model_path - else: - model_name_or_path = base_model args = dict( model_name_or_path=model_name_or_path, checkpoint_dir=checkpoint_dir, - quantization_bit=quantization_bit + quantization_bit=int(quantization_bit) if quantization_bit else None ) super().__init__(*get_infer_args(args)) @@ -52,13 +57,13 @@ def unload_model(self): yield "Model unloaded, please load a model first."
def predict( - self, - chatbot: List[Tuple[str, str]], - query: str, - history: List[Tuple[str, str]], - max_length: int, - top_p: float, - temperature: float + self, + chatbot: List[Tuple[str, str]], + query: str, + history: List[Tuple[str, str]], + max_length: int, + top_p: float, + temperature: float ): chatbot.append([query, ""]) response = "" diff --git a/src/glmtuner/webui/common.py b/src/glmtuner/webui/common.py index 54fcf70..4ccacea 100644 --- a/src/glmtuner/webui/common.py +++ b/src/glmtuner/webui/common.py @@ -1,45 +1,42 @@ -import codecs import json import os -from typing import List, Tuple +from typing import Dict, List, Tuple import gradio as gr from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME -CACHE_DIR = "cache" # to save models +CACHE_DIR = "cache" DATA_DIR = "data" SAVE_DIR = "saves" -TEMP_USE_CONFIG = "tmp.use.config" +USER_CONFIG = "user.config" -def get_temp_use_config_path(): - return os.path.join(SAVE_DIR, TEMP_USE_CONFIG) +def get_config_path(): + return os.path.join(CACHE_DIR, USER_CONFIG) -def load_temp_use_config(): - if not os.path.exists(get_temp_use_config_path()): +def load_config() -> Dict[str, str]: + if not os.path.exists(get_config_path()): return {} - with codecs.open(get_temp_use_config_path()) as f: + + with open(get_config_path(), "r", encoding="utf-8") as f: try: user_config = json.load(f) return user_config - except Exception as e: + except: return {} -def save_temp_use_config(user_config: dict): - with codecs.open(get_temp_use_config_path(), "w", encoding="utf-8") as f: - json.dump(f, user_config, ensure_ascii=False) - - -def save_model_config(model_name: str, model_path: str): - with codecs.open(get_temp_use_config_path(), "w", encoding="utf-8") as f: - json.dump({"model_name": model_name, "model_path": model_path}, f, ensure_ascii=False) +def save_config(model_name: str, model_path: str) -> None: + os.makedirs(CACHE_DIR, exist_ok=True) + user_config = dict(model_name=model_name, model_path=model_path) + with open(get_config_path(), "w", encoding="utf-8") as f: + json.dump(user_config, f, ensure_ascii=False) def get_save_dir(model_name: str) -> str: - return os.path.join(SAVE_DIR, model_name.split("/")[-1]) + return os.path.join(SAVE_DIR, os.path.split(model_name)[-1]) def add_model(model_list: list, model_name: str, model_path: str) -> Tuple[list, str, str]: @@ -62,11 +59,11 @@ def list_checkpoints(model_name: str) -> dict: if save_dir and os.path.isdir(save_dir): for checkpoint in os.listdir(save_dir): if ( - os.path.isdir(os.path.join(save_dir, checkpoint)) - and any([ - os.path.isfile(os.path.join(save_dir, checkpoint, name)) - for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME) - ]) + os.path.isdir(os.path.join(save_dir, checkpoint)) + and any([ + os.path.isfile(os.path.join(save_dir, checkpoint, name)) + for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME) + ]) ): checkpoints.append(checkpoint) return gr.update(value=[], choices=checkpoints) diff --git a/src/glmtuner/webui/components/data.py b/src/glmtuner/webui/components/data.py index e6cc8fe..d95c1a4 100644 --- a/src/glmtuner/webui/components/data.py +++ b/src/glmtuner/webui/components/data.py @@ -1,7 +1,6 @@ -from typing import Tuple - import gradio as gr from gradio.components import Component +from typing import Tuple def create_preview_box() -> Tuple[Component, Component, Component]: diff --git a/src/glmtuner/webui/components/eval.py b/src/glmtuner/webui/components/eval.py index 
ebaa837..7c82278 100644 --- a/src/glmtuner/webui/components/eval.py +++ b/src/glmtuner/webui/components/eval.py @@ -2,14 +2,24 @@ from gradio.components import Component from glmtuner.webui.common import list_datasets +from glmtuner.webui.components.data import create_preview_box from glmtuner.webui.runner import Runner +from glmtuner.webui.utils import can_preview, get_preview -def create_eval_tab(base_model: Component, model_path: Component, checkpoints: Component, runner: Runner) -> None: +def create_eval_tab(model_name: Component, model_path: Component, checkpoints: Component, runner: Runner) -> None: with gr.Row(): - dataset = gr.Dropdown( - label="Dataset", info="The name of dataset(s).", choices=list_datasets(), multiselect=True, interactive=True - ) + dataset = gr.Dropdown(label="Dataset", choices=list_datasets(), multiselect=True, interactive=True, scale=4) + preview_btn = gr.Button("Preview", interactive=False, scale=1) + + preview_box, preview_count, preview_samples = create_preview_box() + + dataset.change(can_preview, [dataset], [preview_btn]) + preview_btn.click( + get_preview, [dataset], [preview_count, preview_samples] + ).then( + lambda: gr.update(visible=True), outputs=[preview_box] + ) with gr.Row(): max_samples = gr.Textbox( @@ -18,17 +28,17 @@ def create_eval_tab(base_model: Component, model_path: Component, checkpoints: C per_device_eval_batch_size = gr.Slider( label="Batch size", value=8, minimum=1, maximum=128, step=1, info="Eval batch size.", interactive=True ) - quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Only support 4 bit or 8 bit") + quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Quantize model to 4/8-bit mode.") with gr.Row(): - start = gr.Button("Start evaluation") - stop = gr.Button("Abort") + start_btn = gr.Button("Start evaluation") + stop_btn = gr.Button("Abort") output = gr.Markdown(value="Ready") - start.click( + start_btn.click( runner.run_eval, - [base_model, model_path, checkpoints, dataset, max_samples, per_device_eval_batch_size, quantization_bit], + [model_name, model_path, checkpoints, dataset, max_samples, per_device_eval_batch_size, quantization_bit], [output] ) - stop.click(runner.set_abort, queue=False) + stop_btn.click(runner.set_abort, queue=False) diff --git a/src/glmtuner/webui/components/infer.py b/src/glmtuner/webui/components/infer.py index 47bfb9f..eb9a85b 100644 --- a/src/glmtuner/webui/components/infer.py +++ b/src/glmtuner/webui/components/infer.py @@ -16,25 +16,26 @@ def create_chat_box(chat_model: WebChatModel) -> Tuple[Component, Component, Com query = gr.Textbox(show_label=False, placeholder="Input...", lines=10) with gr.Column(min_width=32, scale=1): - submit = gr.Button("Submit", variant="primary") + submit_btn = gr.Button("Submit", variant="primary") with gr.Column(scale=1): - clear = gr.Button("Clear History") + clear_btn = gr.Button("Clear History") max_length = gr.Slider( - 10, 2048, value=chat_model.generating_args.max_length, step=1.0, label="Maximum length", - interactive=True + 10, 2048, value=chat_model.generating_args.max_length, step=1.0, + label="Maximum length", interactive=True ) top_p = gr.Slider( - 0, 1, value=chat_model.generating_args.top_p, step=0.01, label="Top P", interactive=True + 0, 1, value=chat_model.generating_args.top_p, step=0.01, + label="Top P", interactive=True ) temperature = gr.Slider( - 0, 1.5, value=chat_model.generating_args.temperature, step=0.01, label="Temperature", - interactive=True + 0, 1.5, 
value=chat_model.generating_args.temperature, step=0.01, + label="Temperature", interactive=True ) history = gr.State([]) - submit.click( + submit_btn.click( chat_model.predict, [chatbot, query, history, max_length, top_p, temperature], [chatbot, history], @@ -43,12 +44,12 @@ def create_chat_box(chat_model: WebChatModel) -> Tuple[Component, Component, Com lambda: gr.update(value=""), outputs=[query] ) - clear.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) + clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) return chat_box, chatbot, history -def create_infer_tab(base_model: Component, model_path: Component, checkpoints: Component) -> None: +def create_infer_tab(model_name: Component, model_path: Component, checkpoints: Component) -> None: info_box = gr.Markdown(value="Model unloaded, please load a model first.") chat_model = WebChatModel() @@ -57,10 +58,10 @@ def create_infer_tab(base_model: Component, model_path: Component, checkpoints: with gr.Row(): load_btn = gr.Button("Load model") unload_btn = gr.Button("Unload model") - quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Only support 4 bit or 8 bit") + quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Quantize model to 4/8-bit mode.") load_btn.click( - chat_model.load_model, [base_model, model_path, checkpoints, quantization_bit], [info_box] + chat_model.load_model, [model_name, model_path, checkpoints, quantization_bit], [info_box] ).then( lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box] ) diff --git a/src/glmtuner/webui/components/model.py b/src/glmtuner/webui/components/model.py index bcc9654..6ba2dc2 100644 --- a/src/glmtuner/webui/components/model.py +++ b/src/glmtuner/webui/components/model.py @@ -3,28 +3,31 @@ import gradio as gr from gradio.components import Component -from glmtuner.extras.constants import SUPPORTED_MODEL_LIST -from glmtuner.webui.common import list_checkpoints, load_temp_use_config, save_model_config +from glmtuner.extras.constants import SUPPORTED_MODELS +from glmtuner.webui.common import list_checkpoints, load_config, save_config def create_model_tab() -> Tuple[Component, Component, Component]: - user_config = load_temp_use_config() - gr_state = gr.State([]) # gr.State does not accept a dict + user_config = load_config() + available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] with gr.Row(): - model_name = gr.Dropdown([model["pretrained_model_name"] for model in SUPPORTED_MODEL_LIST] + ["custom"], - label="Base Model", info="Model Version of ChatGLM", - value=user_config.get("model_name")) - model_path = gr.Textbox(lines=1, label="Local model path(Optional)", - info="The absolute path of the directory where the local model file is located", - value=user_config.get("model_path")) + model_name = gr.Dropdown(choices=available_models, label="Model", value=user_config.get("model_name", None)) + model_path = gr.Textbox( + label="Local path (Optional)", value=user_config.get("model_path", None), + info="The absolute path of the directory where the local model file is located." 
+ ) with gr.Row(): checkpoints = gr.Dropdown(label="Checkpoints", multiselect=True, interactive=True, scale=5) - refresh = gr.Button("Refresh checkpoints", scale=1) - - model_name.change(list_checkpoints, [model_name], [checkpoints]) - model_path.change(save_model_config, [model_name, model_path]) - refresh.click(list_checkpoints, [model_name], [checkpoints]) + refresh_btn = gr.Button("Refresh checkpoints", scale=1) + + model_name.change( + list_checkpoints, [model_name], [checkpoints] + ).then( # TODO: save list + lambda: gr.update(value=""), outputs=[model_path] + ) + model_path.change(save_config, [model_name, model_path]) + refresh_btn.click(list_checkpoints, [model_name], [checkpoints]) return model_name, model_path, checkpoints diff --git a/src/glmtuner/webui/components/sft.py b/src/glmtuner/webui/components/sft.py index 79b6362..bd1009d 100644 --- a/src/glmtuner/webui/components/sft.py +++ b/src/glmtuner/webui/components/sft.py @@ -8,20 +8,19 @@ from glmtuner.webui.utils import can_preview, get_preview, get_time, gen_plot -def create_sft_tab(base_model: Component, model_path: Component, checkpoints: Component, runner: Runner) -> None: +def create_sft_tab(model_name: Component, model_path: Component, checkpoints: Component, runner: Runner) -> None: with gr.Row(): finetuning_type = gr.Dropdown( - label="Finetuning method", value="lora", choices=["full", "freeze", "p_tuning", "lora"], interactive=True + label="Finetuning method", value="lora", + choices=["full", "freeze", "p_tuning", "lora"], interactive=True, scale=2 ) - - with gr.Row(): - dataset = gr.Dropdown(label="Dataset", choices=list_datasets(), multiselect=True, interactive=True, scale=4) - preview = gr.Button("Preview", visible=False, scale=1) + dataset = gr.Dropdown(label="Dataset", choices=list_datasets(), multiselect=True, interactive=True, scale=2) + preview_btn = gr.Button("Preview", interactive=False, scale=1) preview_box, preview_count, preview_samples = create_preview_box() - dataset.change(can_preview, [dataset], [preview]) - preview.click( + dataset.change(can_preview, [dataset], [preview_btn]) + preview_btn.click( get_preview, [dataset], [preview_count, preview_samples] ).then( lambda: gr.update(visible=True), outputs=[preview_box] @@ -37,16 +36,16 @@ def create_sft_tab(base_model: Component, model_path: Component, checkpoints: Co max_samples = gr.Textbox( label="Max samples", value="100000", info="Number of samples for training.", interactive=True ) - quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Only support 4 bit or 8 bit", - interactive=True) + quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Quantize model to 4/8-bit mode.") with gr.Row(): per_device_train_batch_size = gr.Slider( - label="Batch size", value=4, minimum=1, maximum=128, step=1, info="Train batch size.", interactive=True + label="Batch size", value=4, minimum=1, maximum=128, step=1, + info="Train batch size.", interactive=True ) gradient_accumulation_steps = gr.Slider( - label="Gradient accumulation", value=4, minimum=1, maximum=16, step=1, info='Accumulation steps.', - interactive=True + label="Gradient accumulation", value=4, minimum=1, maximum=32, step=1, + info="Accumulation steps.", interactive=True ) lr_scheduler_type = gr.Dropdown( label="LR Scheduler", value="cosine", info="Scheduler type.", @@ -56,7 +55,7 @@ def create_sft_tab(base_model: Component, model_path: Component, checkpoints: Co with gr.Row(): logging_steps = gr.Slider( - label="Logging steps", value=1, minimum=1, maximum=1000, 
step=10, + label="Logging steps", value=5, minimum=5, maximum=1000, step=5, info="Number of update steps between two logs.", interactive=True ) save_steps = gr.Slider( @@ -65,8 +64,8 @@ def create_sft_tab(base_model: Component, model_path: Component, checkpoints: Co ) with gr.Row(): - start = gr.Button("Start training") - stop = gr.Button("Abort") + start_btn = gr.Button("Start training") + stop_btn = gr.Button("Abort") with gr.Row(): with gr.Column(scale=4): @@ -76,16 +75,16 @@ def create_sft_tab(base_model: Component, model_path: Component, checkpoints: Co with gr.Column(scale=1): loss_viewer = gr.Plot(label="Loss") - start.click( + start_btn.click( runner.run_train, [ - base_model, model_path, checkpoints, output_dir, finetuning_type, + model_name, model_path, checkpoints, output_dir, finetuning_type, dataset, learning_rate, num_train_epochs, max_samples, fp16, quantization_bit, per_device_train_batch_size, gradient_accumulation_steps, lr_scheduler_type, logging_steps, save_steps ], output_info ) - stop.click(runner.set_abort, queue=False) + stop_btn.click(runner.set_abort, queue=False) - output_info.change(gen_plot, [base_model, output_dir], loss_viewer, queue=False) + output_info.change(gen_plot, [model_name, output_dir], loss_viewer, queue=False) diff --git a/src/glmtuner/webui/interface.py b/src/glmtuner/webui/interface.py index c83560e..f5ece4a 100644 --- a/src/glmtuner/webui/interface.py +++ b/src/glmtuner/webui/interface.py @@ -10,6 +10,7 @@ from glmtuner.webui.css import CSS from glmtuner.webui.runner import Runner + require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0") @@ -17,16 +18,16 @@ def create_ui() -> gr.Blocks: runner = Runner() with gr.Blocks(title="Web Tuner", css=CSS) as demo: - base_model, model_list, checkpoints = create_model_tab() + model_name, model_path, checkpoints = create_model_tab() with gr.Tab("SFT"): - create_sft_tab(base_model, model_list, checkpoints, runner) + create_sft_tab(model_name, model_path, checkpoints, runner) with gr.Tab("Evaluate"): - create_eval_tab(base_model, model_list, checkpoints, runner) + create_eval_tab(model_name, model_path, checkpoints, runner) with gr.Tab("Inference"): - create_infer_tab(base_model, model_list, checkpoints) + create_infer_tab(model_name, model_path, checkpoints) return demo @@ -34,4 +35,4 @@ def create_ui() -> gr.Blocks: if __name__ == "__main__": demo = create_ui() demo.queue() - demo.launch(server_name="0.0.0.0", share=True, inbrowser=True) + demo.launch(server_name="0.0.0.0", share=False, inbrowser=True) diff --git a/src/glmtuner/webui/runner.py b/src/glmtuner/webui/runner.py index 4824151..0f85bac 100644 --- a/src/glmtuner/webui/runner.py +++ b/src/glmtuner/webui/runner.py @@ -2,11 +2,11 @@ import os import threading import time -from typing import Optional, Tuple - import transformers +from typing import Optional, Tuple from glmtuner.extras.callbacks import LogCallback +from glmtuner.extras.constants import SUPPORTED_MODELS from glmtuner.extras.logging import LoggerHandler from glmtuner.extras.misc import torch_gc from glmtuner.tuner import get_train_args, run_sft @@ -24,15 +24,24 @@ def set_abort(self): self.aborted = True self.running = False - def initialize(self, base_model: str, model_path: str, dataset: list) -> Tuple[str, LoggerHandler, LogCallback]: + def initialize(self, model_name: str, model_path: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]: if self.running: - return "A process is in running, please abort it firstly.", None, None + return None, "A process 
is in running, please abort it firstly.", None, None + + if not model_name: + return None, "Please select a model.", None, None - if not base_model: - return "Please select a model.", None, None + if model_path: + if not os.path.isdir(model_path): + return None, "Cannot find model directory in local disk.", None, None + model_name_or_path = model_path + elif model_name in SUPPORTED_MODELS: # TODO: use list in gr.State + model_name_or_path = SUPPORTED_MODELS[model_name]["hf_path"] + else: + return None, "Invalid model.", None, None if len(dataset) == 0: - return "Please choose datasets.", None, None + return None, "Please choose datasets.", None, None self.aborted = False self.running = True @@ -43,7 +52,7 @@ def initialize(self, base_model: str, model_path: str, dataset: list) -> Tuple[s transformers.logging.add_handler(logger_handler) trainer_callback = LogCallback(self) - return "", logger_handler, trainer_callback + return model_name_or_path, "", logger_handler, trainer_callback def finalize(self, finish_info: Optional[str] = None) -> str: self.running = False @@ -54,26 +63,21 @@ def finalize(self, finish_info: Optional[str] = None) -> str: return finish_info if finish_info is not None else "Finished" def run_train( - self, base_model, model_path, checkpoints, output_dir, finetuning_type, - dataset, learning_rate, num_train_epochs, max_samples, - fp16, quantization_bit, per_device_train_batch_size, gradient_accumulation_steps, - lr_scheduler_type, logging_steps, save_steps + self, model_name, model_path, checkpoints, output_dir, finetuning_type, + dataset, learning_rate, num_train_epochs, max_samples, + fp16, quantization_bit, per_device_train_batch_size, gradient_accumulation_steps, + lr_scheduler_type, logging_steps, save_steps ): - error, logger_handler, trainer_callback = self.initialize(base_model, model_path, dataset) + model_name_or_path, error, logger_handler, trainer_callback = self.initialize(model_name, model_path, dataset) if error: yield error return - if get_save_dir(base_model) and checkpoints: - checkpoint_dir = ",".join( - [os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints]) + if checkpoints: + checkpoint_dir = ",".join([os.path.join(get_save_dir(model_name), checkpoint) for checkpoint in checkpoints]) else: checkpoint_dir = None - if model_path: - model_name_or_path = model_path - else: - model_name_or_path = base_model args = dict( model_name_or_path=model_name_or_path, do_train=True, @@ -81,7 +85,7 @@ def run_train( dataset=",".join(dataset), dataset_dir=DATA_DIR, max_samples=int(max_samples), - output_dir=os.path.join(get_save_dir(base_model), output_dir), + output_dir=os.path.join(get_save_dir(model_name), output_dir), checkpoint_dir=checkpoint_dir, overwrite_cache=True, per_device_train_batch_size=per_device_train_batch_size, @@ -92,7 +96,7 @@ def run_train( learning_rate=float(learning_rate), num_train_epochs=float(num_train_epochs), fp16=fp16, - quantization_bit=quantization_bit + quantization_bit=int(quantization_bit) if quantization_bit else None ) model_args, data_args, training_args, finetuning_args, _ = get_train_args(args) @@ -116,26 +120,21 @@ def run_train( yield self.finalize() def run_eval( - self, base_model, model_path, checkpoints, dataset, max_samples, per_device_eval_batch_size, - quantization_bit + self, model_name, model_path, checkpoints, dataset, max_samples, per_device_eval_batch_size, + quantization_bit ): - error, logger_handler, trainer_callback = self.initialize(base_model, model_path, dataset) + 
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(model_name, model_path, dataset) if error: yield error return - if get_save_dir(base_model) and checkpoints: - checkpoint_dir = ",".join( - [os.path.join(get_save_dir(base_model), checkpoint) for checkpoint in checkpoints]) - output_dir = os.path.join(get_save_dir(base_model), "eval_" + "_".join(checkpoints)) + if checkpoints: + checkpoint_dir = ",".join([os.path.join(get_save_dir(model_name), checkpoint) for checkpoint in checkpoints]) + output_dir = os.path.join(get_save_dir(model_name), "eval_" + "_".join(checkpoints)) else: checkpoint_dir = None - output_dir = os.path.join(get_save_dir(base_model), "eval_base") + output_dir = os.path.join(get_save_dir(model_name), "eval_base") - if model_path: - model_name_or_path = model_path - else: - model_name_or_path = base_model args = dict( model_name_or_path=model_name_or_path, do_eval=True, @@ -147,7 +146,7 @@ def run_eval( overwrite_cache=True, predict_with_generate=True, per_device_eval_batch_size=per_device_eval_batch_size, - quantization_bit=quantization_bit + quantization_bit=int(quantization_bit) if quantization_bit else None ) model_args, data_args, training_args, finetuning_args, _ = get_train_args(args) diff --git a/src/glmtuner/webui/utils.py b/src/glmtuner/webui/utils.py index 5621b84..cf840c7 100644 --- a/src/glmtuner/webui/utils.py +++ b/src/glmtuner/webui/utils.py @@ -28,13 +28,13 @@ def can_preview(dataset: list) -> dict: with open(os.path.join(DATA_DIR, "dataset_info.json"), "r", encoding="utf-8") as f: dataset_info = json.load(f) if ( - len(dataset) > 0 - and "file_name" in dataset_info[dataset[0]] - and os.path.isfile(os.path.join(DATA_DIR, dataset_info[dataset[0]]["file_name"])) + len(dataset) > 0 + and "file_name" in dataset_info[dataset[0]] + and os.path.isfile(os.path.join(DATA_DIR, dataset_info[dataset[0]]["file_name"])) ): - return gr.update(visible=True) + return gr.update(interactive=True) else: - return gr.update(visible=False) + return gr.update(interactive=False) def get_preview(dataset: list) -> Tuple[int, list]: @@ -64,7 +64,7 @@ def gen_plot(base_model: str, output_dir: str) -> matplotlib.figure.Figure: with open(log_file, "r", encoding="utf-8") as f: for line in f: log_info = json.loads(line) - if log_info["loss"]: + if log_info.get("loss", None): steps.append(log_info["current_steps"]) losses.append(log_info["loss"]) ax.plot(steps, losses, alpha=0.4, label="original")
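Note on the utils.py change: gen_plot now reads the loss with log_info.get("loss", None), so log records without a loss field no longer raise KeyError while plotting. A standalone sketch of that parsing loop, assuming each line of the runner's log file is one JSON object carrying current_steps and an optional loss key (read_loss_curve is an illustrative helper, not part of the patch):

import json

def read_loss_curve(log_file: str):
    # Collect (step, loss) pairs; records without a loss value are skipped.
    # The truthiness check mirrors the patch, so an exact 0.0 loss is also skipped.
    steps, losses = [], []
    with open(log_file, "r", encoding="utf-8") as f:
        for line in f:
            log_info = json.loads(line)
            if log_info.get("loss", None):
                steps.append(log_info["current_steps"])
                losses.append(log_info["loss"])
    return steps, losses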