diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 6126c1d8f1..0c9bab9e25 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -166,13 +166,20 @@ Generally, users don't have to worry about these internal details. We provide a common constructor `MvNormal`, which will construct a distribution of appropriate type depending on the input arguments. """ -struct MvNormal{T<:Real,Cov<:AbstractPDMat,Mean<:AbstractVector} <: AbstractMvNormal +struct MvNormal{T<:Real,Cov<:AbstractPDMat{T},Mean<:AbstractVector{T}} <: AbstractMvNormal μ::Mean Σ::Cov + + function MvNormal{T,Cov,Mean}(µ, Σ) where {T<:Real,Cov<:AbstractPDMat{T},Mean<:AbstractVector{T}} + axes(Σ, 1) == eachindex(μ) || throw(DimensionMismatch("The dimensions of µ and Σ are inconsistent.")) + T(Inf) # we require that Inf be in the domain of T, see `insupport` + return new{T,Cov,Mean}(µ, Σ) + end end const MultivariateNormal = MvNormal # for the purpose of backward compatibility +# TODO?: make these IsoNormal{T} etc const IsoNormal = MvNormal{Float64,ScalMat{Float64},Vector{Float64}} const DiagNormal = MvNormal{Float64,PDiagMat{Float64,Vector{Float64}},Vector{Float64}} const FullNormal = MvNormal{Float64,PDMat{Float64,Matrix{Float64}},Vector{Float64}} @@ -182,32 +189,49 @@ const ZeroMeanDiagNormal{Axes} = MvNormal{Float64,PDiagMat{Float64,Vector{Float6 const ZeroMeanFullNormal{Axes} = MvNormal{Float64,PDMat{Float64,Matrix{Float64}},Zeros{Float64,1,Axes}} ### Construction -function MvNormal(μ::AbstractVector{T}, Σ::AbstractPDMat{T}) where {T<:Real} - size(Σ, 1) == length(μ) || throw(DimensionMismatch("The dimensions of mu and Sigma are inconsistent.")) - MvNormal{T,typeof(Σ), typeof(μ)}(μ, Σ) +## Constructor that accepts an `AbstractPDMat` but coerces only T and Cov +function MvNormal{T,Cov}(μ, Σ::AbstractPDMat) where {T<:Real,Cov<:AbstractPDMat{T}} + # General pattern: `convert(Typ, x)::Typ` is used to coerce `x` to type `Typ` + # This guards against broken implementations of `convert` that otherwise risk StackOverflowError + μ = convert(AbstractVector{T}, μ)::AbstractVector{T} + return MvNormal{T,Cov,typeof(μ)}(μ, Σ) end -function MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractPDMat{<:Real}) - R = Base.promote_eltype(μ, Σ) - MvNormal(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, Σ)) +## Constructor that accepts an `AbstractPDMat` but coerces only T +function MvNormal{T}(μ, Σ::AbstractPDMat) where {T<:Real} + Σ = convert(AbstractPDMat{T}, Σ)::AbstractPDMat{T} + return MvNormal{T,typeof(Σ)}(μ, Σ) end +## Constructor that accepts an `AbstractPDMat` without any coercion +function MvNormal(μ, Σ::AbstractPDMat) + T = promote_type(eltype(μ), eltype(Σ)) + return MvNormal{T}(μ, Σ) +end + +## Coercing constructors that accept a general covariance matrix +MvNormal{T,Cov}(μ, Σ::AbstractMatrix) where {T<:Real,Cov<:AbstractPDMat{T}} = + MvNormal{T,Cov}(μ, Cov(Σ)) +MvNormal{T,Cov}(μ, Σ::UniformScaling) where {T<:Real,Cov<:AbstractPDMat{T}} = + MvNormal{T,Cov}(μ, pdmat(T, length(μ), Σ)) +MvNormal{T}(μ, Σ::AbstractMatrix) where {T<:Real} = MvNormal{T}(μ, pdmat(T, Σ)) +MvNormal{T}(μ, Σ::UniformScaling) where {T<:Real} = MvNormal{T}(μ, pdmat(T, length(μ), Σ)) + # constructor with general covariance matrix """ MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real}) Construct a multivariate normal distribution with mean `μ` and covariance matrix `Σ`. """ -MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real}) = MvNormal(μ, PDMat(Σ)) -MvNormal(μ::AbstractVector{<:Real}, Σ::Diagonal{<:Real}) = MvNormal(μ, PDiagMat(Σ.diag)) -MvNormal(μ::AbstractVector{<:Real}, Σ::Union{Symmetric{<:Real,<:Diagonal{<:Real}},Hermitian{<:Real,<:Diagonal{<:Real}}}) = MvNormal(μ, PDiagMat(Σ.data.diag)) -MvNormal(μ::AbstractVector{<:Real}, Σ::UniformScaling{<:Real}) = - MvNormal(μ, ScalMat(length(μ), Σ.λ)) -function MvNormal( - μ::AbstractVector{<:Real}, Σ::Diagonal{<:Real,<:FillArrays.AbstractFill{<:Real,1}} -) - return MvNormal(μ, ScalMat(size(Σ, 1), FillArrays.getindex_value(Σ.diag))) -end +MvNormal(μ, Σ::AbstractMatrix) = MvNormal{promote_type(eltype(μ), eltype(Σ))}(μ, Σ) +MvNormal(μ, Σ::UniformScaling) = MvNormal{promote_type(eltype(μ), eltype(Σ))}(μ, Σ) + +pdmat(::Type{T}, Σ::AbstractMatrix{<:Real}) where {T<:Real} = PDMat{T}(Σ) +pdmat(::Type{T}, Σ::Diagonal{<:Real}) where {T<:Real} = PDiagMat{T}(Σ.diag) +pdmat(::Type{T}, Σ::Union{Symmetric{<:Real,<:Diagonal{<:Real}},Hermitian{<:Real,<:Diagonal{<:Real}}}) where {T<:Real} = PDiagMat{T}(Σ.data.diag) +pdmat(::Type{T}, n::Integer, Σ::UniformScaling{<:Real}) where {T<:Real} = ScalMat{T}(n, Σ.λ) +pdmat(::Type{T}, Σ::Diagonal{<:Real,<:FillArrays.AbstractFill{<:Real,1}}) where {T<:Real} = + ScalMat{T}(size(Σ, 1), FillArrays.getindex_value(Σ.diag)) # constructor without mean vector """