diff --git a/.gitignore b/.gitignore index 5165aa2..80a92ac 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,4 @@ _version.py *.pickle hub/ .vscode/settings.json +*.swp diff --git a/README.md b/README.md index 3c29952..ef80cd7 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ > [!CAUTION] > Please do not install from Pypi, the package there with the same name is not this one. > To install, run: -> +> > `pip install git+https://github.com/ecmwf-lab/ai-models-aurora.git` # ai-models-aurora @@ -22,5 +22,5 @@ The model weights are made available for use under the terms of the Creative Com > [!CAUTION] > Please do not install from Pypi, the package there with the same name is not this one. > To install, run: -> +> > `pip install git+https://github.com/ecmwf-lab/ai-models-aurora.git` diff --git a/pyproject.toml b/pyproject.toml index 102eb09..8ac0b21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ dynamic = [ "version", ] dependencies = [ - "microsoft-aurora", + "microsoft-aurora>=1.1.0", "ai-models>=0.6.4" ] diff --git a/src/ai_models_aurora/model.py b/src/ai_models_aurora/model.py index 5e11198..f5cc039 100644 --- a/src/ai_models_aurora/model.py +++ b/src/ai_models_aurora/model.py @@ -51,8 +51,19 @@ def run(self): model = self.klass(use_lora=self.use_lora) model = model.to(self.device) - LOG.info("Downloading Aurora model %s", self.checkpoint) - model.load_checkpoint("microsoft/aurora", self.checkpoint, strict=False) + + path = os.path.join(self.assets, os.path.basename(self.checkpoint)) + if os.path.exists(path): + LOG.info("Loading Aurora model from %s", path) + model.load_checkpoint_local(path, strict=False) + else: + LOG.info("Downloading Aurora model %s", self.checkpoint) + try: + model.load_checkpoint("microsoft/aurora", self.checkpoint, strict=False) + except Exception: + LOG.error("Did not find a local copy at %s", path) + raise + LOG.info("Loading Aurora model to device %s", self.device) model = model.to(self.device)