
different behavior of argmax for Vector and Matrix leads to error #91

Closed
ArrogantGao opened this issue Mar 8, 2024 · 1 comment
ArrogantGao commented Mar 8, 2024

In the function

function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector}
    expected_mars = [[l] for l in get_vars(tn)]
    @assert tn.mars[1:length(expected_mars)] == expected_mars "To get the most probable configuration, the leading elements of `tn.mars` must be `$expected_mars`"
    vars = get_vars(tn)
    tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
    logp, grads = cost_and_gradient(tn.code, tensors)
    # use Array to convert CuArray to CPU arrays
    return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
end

the function `argmax` is used, but its behavior differs between `Vector` and `Matrix` inputs:

julia> argmax(randn(4))
2

julia> argmax(randn(4, 4))
CartesianIndex(4, 2)

so `argmax(grads[k]) - 1` raises an error when `grads[k]` is a `Matrix`, because `argmax` then returns a `CartesianIndex`, and subtracting an `Int` from a `CartesianIndex` is not defined.
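One way to make the call shape-agnostic (a sketch of a possible workaround, not the fix adopted in the package) is to convert the result to a linear index, or equivalently to flatten the array before taking `argmax`:

```julia
# Sketch: make argmax return an Int regardless of array rank.
A = [0.1 0.9; 0.3 0.2]

idx = argmax(A)              # CartesianIndex(1, 2) for a Matrix
lin = LinearIndices(A)[idx]  # convert to a linear (column-major) index

# Equivalently, flatten first so argmax always returns an Int:
lin2 = argmax(vec(A))

@assert lin == lin2 == 3
```

With either form, `lin - 1` is well-defined for both `Vector` and `Matrix` inputs.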

ArrogantGao commented Mar 8, 2024

Sorry, it was my mistake: `grads[k]` should not be a matrix.
When constructing the tensor network for MPE, there must be an all-ones vector for each variable, for example:

julia> factors
10-element Vector{TensorInference.Factor{Float64, 2}}:
 TensorInference.Factor{Float64, 2}((1, 2), [0.1853742993051539 0.24266609993460198; 0.3139558820831363 0.258003718677108])
 TensorInference.Factor{Float64, 2}((1, 3), [0.19490302071525448 0.2331373785245014; 0.37083643432205143 0.2011231664381927])
 TensorInference.Factor{Float64, 2}((1, 4), [0.19159768973795085 0.236442709501805; 0.2699053176129627 0.30205428314728144])
 TensorInference.Factor{Float64, 2}((1, 5), [0.16153413648476111 0.2665062627549947; 0.29757932897291595 0.2743802717873282])
 TensorInference.Factor{Float64, 2}((2, 3), [0.31429995086808404 0.18503023052020612; 0.25143950416922195 0.24923031444248797])
 TensorInference.Factor{Float64, 2}((2, 4), [0.23558618140615134 0.2637439999821388; 0.22591682594476228 0.27475299266694764])
 TensorInference.Factor{Float64, 2}((2, 5), [0.2561260234949363 0.24320415789335384; 0.2029874419627408 0.29768237664896907])
 TensorInference.Factor{Float64, 2}((3, 4), [0.2893861660460715 0.2763532889912345; 0.17211684130484214 0.2621437036578519])
 TensorInference.Factor{Float64, 2}((3, 5), [0.27993586794877484 0.28580358708853115; 0.17917759750890222 0.2550829474537918])
 TensorInference.Factor{Float64, 2}((4, 5), [0.24961217921386344 0.21189082813705018; 0.20950128624381362 0.32899570640527276])

julia> tn.tensors
15-element Vector{Array{Float64}}:
 [1.0, 1.0]
 [1.0, 1.0]
 [1.0, 1.0]
 [1.0, 1.0]
 [1.0, 1.0]
 [0.1853742993051539 0.24266609993460198; 0.3139558820831363 0.258003718677108]
 [0.19490302071525448 0.2331373785245014; 0.37083643432205143 0.2011231664381927]
 [0.19159768973795085 0.236442709501805; 0.2699053176129627 0.30205428314728144]
 [0.16153413648476111 0.2665062627549947; 0.29757932897291595 0.2743802717873282]
 [0.31429995086808404 0.18503023052020612; 0.25143950416922195 0.24923031444248797]
 [0.23558618140615134 0.2637439999821388; 0.22591682594476228 0.27475299266694764]
 [0.2561260234949363 0.24320415789335384; 0.2029874419627408 0.29768237664896907]
 [0.2893861660460715 0.2763532889912345; 0.17211684130484214 0.2621437036578519]
 [0.27993586794877484 0.28580358708853115; 0.17917759750890222 0.2550829474537918]
 [0.24961217921386344 0.21189082813705018; 0.20950128624381362 0.32899570640527276]
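For illustration (a sketch with hypothetical values, not the package's internals): a gradient has the same shape as the tensor it differentiates, so because each variable carries its own length-2 all-ones vector tensor, the corresponding entry of `grads` is a `Vector`, `argmax` returns a plain `Int`, and `argmax(grads[k]) - 1` gives the 0-based most probable state:

```julia
# Sketch: a per-variable vector tensor yields a vector-shaped gradient,
# so argmax returns an Int, not a CartesianIndex.
g = [0.3, 0.7]           # hypothetical gradient for one binary variable
config = argmax(g) - 1   # 0-based most probable state
@assert config == 1
```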
