Skip to content

Commit

Permalink
fix genmodel
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Dec 6, 2024
1 parent 07a6ee5 commit c92f441
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
12 changes: 6 additions & 6 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ function recmodel4(sz,enc_nfilter,dec_nfilter,skipconnections,l=1; method = :nea
end
end

function genmodel(sz,noutput;
function genmodel(sz,ninput,noutput;
truth_uncertain = false,
enc_nfilter_internal = [16,24,36,54],
skipconnections = 2:(length(enc_nfilter_internal)+1),
Expand All @@ -300,7 +300,7 @@ function genmodel(sz,noutput;
laplacian_error_penalty = laplacian_penalty,
)

nvar = sz[end-1]
nvar = ninput
enc_nfilter = vcat([nvar],enc_nfilter_internal)

if output_ndims == 1
Expand All @@ -323,7 +323,7 @@ function genmodel(sz,noutput;
@info "Number of filters: $enc_nfilter"
if loss_weights_refine == (1.,)
steps = (DINCAE.recmodel4(
sz[1:end-2],
sz,
enc_nfilter,
dec_nfilter,
skipconnections,
Expand All @@ -337,8 +337,8 @@ function genmodel(sz,noutput;
@info "Number of filters in encoder (refinement): $enc_nfilter2"
@info "Number of filters in decoder (refinement): $dec_nfilter2"

steps = (DINCAE.recmodel4(sz[1:end-2],enc_nfilter,dec_nfilter,skipconnections; method = upsampling_method),
DINCAE.recmodel4(sz[1:end-2],enc_nfilter2,dec_nfilter2,skipconnections; method = upsampling_method))
steps = (DINCAE.recmodel4(sz,enc_nfilter,dec_nfilter,skipconnections; method = upsampling_method),
DINCAE.recmodel4(sz,enc_nfilter2,dec_nfilter2,skipconnections; method = upsampling_method))
end

if output_ndims == 1
Expand Down Expand Up @@ -510,7 +510,7 @@ function reconstruct(Atype,data_all,fnames_rec;
@info "Input sum: $(sum(inputs_))"


model = genmodel(sz,noutput;
model = genmodel(sz[1:end-2],sz[end-1],noutput;
enc_nfilter_internal,
upsampling_method,
skipconnections,
Expand Down
3 changes: 2 additions & 1 deletion src/points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,9 +570,10 @@ function reconstruct_points(

nvar = sz[3]
@info "number of variables: $nvar"
@show sz
noutput = 1

model = genmodel(sz,noutput;
model = genmodel(sz[1:end-1],nvar,noutput;
enc_nfilter_internal,
upsampling_method,
skipconnections,
Expand Down

0 comments on commit c92f441

Please sign in to comment.