Skip to content

Commit

Permalink
save model state
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Sep 18, 2024
1 parent 1ee38e0 commit ceb04ce
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
Expand Down
1 change: 1 addition & 0 deletions src/DINCAE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ using Printf
using Statistics
using ThreadsX
using ChainRulesCore
using JLD2

import Base: length
import Base: size
Expand Down
11 changes: 11 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ function reconstruct(Atype,data_all,fnames_rec;
paramfile = nothing,
laplacian_penalty = 0,
laplacian_error_penalty = laplacian_penalty,
modeldir = nothing,
)
DB(Atype,d,batch_size) = (Atype.(tmp) for tmp in DataLoader(d,batch_size))

Expand Down Expand Up @@ -529,6 +530,16 @@ function reconstruct(Atype,data_all,fnames_rec;
flush(stdout)

if e save_epochs
if !isnothing(modeldir)
println("Save model $e")
@show cpu(model).chains[1][1].weight[1,1,1,1]
model_fname = joinpath(modeldir,"model-checkpoint-" * @sprintf("%05d",e) * ".jld2")
model_state = Flux.state(cpu(model));
jldsave(model_fname; model_state, train_mean_data,
ntime_win, is3D, cycle_periods, remove_mean,
direction_obs, e)
end

println("Save output $e")

for (d_iter,ds_) in zip(data_iter[2:end],ds)
Expand Down

0 comments on commit ceb04ce

Please sign in to comment.