Skip to content

Commit

Permalink
read local checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 14, 2024
1 parent 97962f7 commit 72dd739
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ _version.py
*.pickle
hub/
.vscode/settings.json
*.swp
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
> [!CAUTION]
> Please do not install from Pypi, the package there with the same name is not this one.
> To install, run:
>
>
> `pip install git+https://github.com/ecmwf-lab/ai-models-aurora.git`
# ai-models-aurora
Expand All @@ -22,5 +22,5 @@ The model weights are made available for use under the terms of the Creative Com
> [!CAUTION]
> Please do not install from Pypi, the package there with the same name is not this one.
> To install, run:
>
>
> `pip install git+https://github.com/ecmwf-lab/ai-models-aurora.git`
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ dynamic = [
"version",
]
dependencies = [
"microsoft-aurora",
"microsoft-aurora>=1.1.0",
"ai-models>=0.6.4"
]

Expand Down
15 changes: 13 additions & 2 deletions src/ai_models_aurora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,19 @@ def run(self):

model = self.klass(use_lora=self.use_lora)
model = model.to(self.device)
LOG.info("Downloading Aurora model %s", self.checkpoint)
model.load_checkpoint("microsoft/aurora", self.checkpoint, strict=False)

path = os.path.join(self.assets, os.path.basename(self.checkpoint))
if os.path.exists(path):
LOG.info("Loading Aurora model from %s", path)
model.load_checkpoint_local(path, strict=False)
else:
LOG.info("Downloading Aurora model %s", self.checkpoint)
try:
model.load_checkpoint("microsoft/aurora", self.checkpoint, strict=False)
except Exception:
LOG.error("Did not find a local copy at %s", path)
raise

LOG.info("Loading Aurora model to device %s", self.device)

model = model.to(self.device)
Expand Down

0 comments on commit 72dd739

Please sign in to comment.