Skip to content

Commit

Permalink
model names typo
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 15, 2024
1 parent af71748 commit 06aabac
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions src/ai_models_aurora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,11 @@ def nan_extend(self, data):
axis=0,
)

def add_model_args(self, parser):
def parse_model_args(self, args):
import argparse

parser = argparse.ArgumentParser("ai-models aurora")

parser.add_argument(
"--lora",
type=lambda x: (str(x).lower() in ["true", "1", "yes"]),
Expand All @@ -171,8 +175,10 @@ def add_model_args(self, parser):
help="Use LoRA model (true/false). Default depends on the model.",
)

return parser.parse_args(args)


class Aurora2p5(AuroraModel):
class Aurora0p25(AuroraModel):
klass = Aurora
download_files = ("aurora-0.25-static.pickle",)
# Input
Expand All @@ -181,13 +187,13 @@ class Aurora2p5(AuroraModel):


# https://microsoft.github.io/aurora/models.html#aurora-0-25-pretrained
class Aurora2p5Pretrained(Aurora2p5):
class Aurora0p25Pretrained(Aurora0p25):
use_lora = False
checkpoint = "aurora-0.25-pretrained.ckpt"


# https://microsoft.github.io/aurora/models.html#aurora-0-25-fine-tuned
class Aurora2p5FineTuned(Aurora2p5):
class Aurora025FineTuned(Aurora0p25):
use_lora = True
checkpoint = "aurora-0.25-finetuned.ckpt"

Expand Down Expand Up @@ -234,8 +240,8 @@ def model(model_version, **kwargs):
# select with --model-version

models = {
"2.5-pretrained": Aurora2p5Pretrained,
"2.5-finetuned": Aurora2p5FineTuned,
"0.25-pretrained": Aurora0p25Pretrained,
"0.25-finetuned": Aurora025FineTuned,
"0.1-finetuned": Aurora0p1FineTuned,
"default": Aurora0p1FineTuned,
"latest": Aurora0p1FineTuned, # Backward compatibility
Expand Down

0 comments on commit 06aabac

Please sign in to comment.