diff --git a/src/ai_models_aurora/model.py b/src/ai_models_aurora/model.py
index 9a47282..8f47e2e 100644
--- a/src/ai_models_aurora/model.py
+++ b/src/ai_models_aurora/model.py
@@ -161,7 +161,11 @@ def nan_extend(self, data):
             axis=0,
         )
 
-    def add_model_args(self, parser):
+    def parse_model_args(self, args):
+        import argparse
+
+        parser = argparse.ArgumentParser("ai-models aurora")
+
         parser.add_argument(
             "--lora",
             type=lambda x: (str(x).lower() in ["true", "1", "yes"]),
@@ -171,8 +175,10 @@ def add_model_args(self, parser):
             help="Use LoRA model (true/false). Default depends on the model.",
         )
 
+        return parser.parse_args(args)
+
 
-class Aurora2p5(AuroraModel):
+class Aurora0p25(AuroraModel):
     klass = Aurora
     download_files = ("aurora-0.25-static.pickle",)
     # Input
@@ -181,13 +187,13 @@ class Aurora2p5(AuroraModel):
 
 
 # https://microsoft.github.io/aurora/models.html#aurora-0-25-pretrained
-class Aurora2p5Pretrained(Aurora2p5):
+class Aurora0p25Pretrained(Aurora0p25):
     use_lora = False
     checkpoint = "aurora-0.25-pretrained.ckpt"
 
 
 # https://microsoft.github.io/aurora/models.html#aurora-0-25-fine-tuned
-class Aurora2p5FineTuned(Aurora2p5):
+class Aurora0p25FineTuned(Aurora0p25):
     use_lora = True
     checkpoint = "aurora-0.25-finetuned.ckpt"
 
@@ -234,8 +240,8 @@ def model(model_version, **kwargs):
 
     # select with --model-version
     models = {
-        "2.5-pretrained": Aurora2p5Pretrained,
-        "2.5-finetuned": Aurora2p5FineTuned,
+        "0.25-pretrained": Aurora0p25Pretrained,
+        "0.25-finetuned": Aurora0p25FineTuned,
         "0.1-finetuned": Aurora0p1FineTuned,
         "default": Aurora0p1FineTuned,
         "latest": Aurora0p1FineTuned,  # Backward compatibility
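
Not part of the patch above: a minimal, self-contained sketch of the argument-parsing pattern the patch introduces, showing how the --lora value is coerced to a boolean. The standalone parse_model_args helper and the example argument lists are illustrative assumptions, not code from the repository; the real method lives on AuroraModel and is presumably handed the model-specific command-line arguments by the ai-models runner.

import argparse


def parse_model_args(args):
    # Standalone stand-in for AuroraModel.parse_model_args(): build a parser
    # for the model-specific options and parse only the forwarded arguments.
    parser = argparse.ArgumentParser("ai-models aurora")
    parser.add_argument(
        "--lora",
        type=lambda x: (str(x).lower() in ["true", "1", "yes"]),
        help="Use LoRA model (true/false). Default depends on the model.",
    )
    return parser.parse_args(args)


# The lambda maps "true"/"1"/"yes" (any case) to True and everything else to
# False; omitting --lora leaves the value at None (argparse's default in this
# sketch), consistent with the help text saying the default depends on the model.
print(parse_model_args(["--lora", "true"]))  # Namespace(lora=True)
print(parse_model_args(["--lora", "no"]))    # Namespace(lora=False)
print(parse_model_args([]))                  # Namespace(lora=None)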