diff --git a/Project.toml b/Project.toml index 38f758e..92c0132 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ SIMD = "fdea26ae-647d-5447-a871-4b548cad5224" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] ForwardDiff = "0.10" @@ -18,12 +19,12 @@ StaticArrays = "1" LinearAlgebra = "1" Statistics = "1" julia = "1" +Random = "1" [extras] -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Random", "Test", "TimerOutputs", "Unitful"] +test = ["Test", "TimerOutputs", "Unitful"] diff --git a/docs/Manifest.toml b/docs/Manifest.toml index efd3846..921cd57 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.3" +julia_version = "1.10.2" manifest_format = "2.0" -project_hash = "063b18da735845e327fc97d98706d227f804d89d" +project_hash = "7cec80757d87f9c7f22da03f741d383cecfafe9e" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -28,7 +28,7 @@ version = "0.3.0" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+0" +version = "1.1.0+0" [[deps.Dates]] deps = ["Printf"] @@ -68,9 +68,9 @@ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.35" +version = "0.10.36" weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] @@ -78,9 +78,9 @@ weakdeps = ["StaticArrays"] [[deps.IOCapture]] deps = ["Logging", "Random"] -git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" +git-tree-sha1 = "8b72179abc660bfab5e28472e019392b97d0985c" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.3" +version = "0.2.4" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -92,10 +92,10 @@ uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" version = "0.2.2" [[deps.JLLWrappers]] -deps = ["Preferences"] -git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.4.1" +version = "1.5.0" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -106,21 +106,26 @@ version = "0.21.4" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" +version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" +version = "8.4.0+0" [[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" +version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -131,9 +136,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f" +git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.24" +version = "0.3.27" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -150,9 +155,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.MacroTools]] deps = ["Markdown", "Random"] -git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.10" +version = "0.5.13" [[deps.Markdown]] deps = ["Base64"] @@ -161,14 +166,14 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" +version = "2.28.2+1" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" +version = "2023.1.10" [[deps.NaNMath]] deps = ["OpenLibm_jll"] @@ -183,12 +188,12 @@ version = "1.2.0" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+4" +version = "0.3.23+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" +version = "0.8.1+2" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -198,26 +203,26 @@ version = "0.5.5+0" [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.7.2" +version = "2.8.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.2" +version = "1.10.0" [[deps.PrecompileTools]] deps = ["Preferences"] -git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81" +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.1.2" +version = "1.2.1" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.0" +version = "1.4.3" [[deps.Printf]] deps = ["Unicode"] @@ -228,7 +233,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[deps.Random]] -deps = ["SHA", "Serialization"] +deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[deps.SHA]] @@ -237,9 +242,9 @@ version = "0.7.0" [[deps.SIMD]] deps = ["PrecompileTools"] -git-tree-sha1 = "0e270732477b9e551d884e6b07e23bb2ec947790" +git-tree-sha1 = "d8911cc125da009051fb35322415641d02d9e37f" uuid = "fdea26ae-647d-5447-a871-4b548cad5224" -version = "3.4.5" +version = "3.4.6" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -250,12 +255,13 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "7beb031cf8145577fbccacd94b8a8f4ce78428d3" +git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.0" +version = "2.3.1" [deps.SpecialFunctions.extensions] SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -264,15 +270,19 @@ version = "2.3.0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore"] -git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.2" -weakdeps = ["Statistics"] +version = "1.9.3" [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" StaticArraysStatisticsExt = "Statistics" + [deps.StaticArrays.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + [[deps.StaticArraysCore]] git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -281,12 +291,12 @@ version = "1.4.2" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" +version = "1.10.0" [[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+6" +version = "7.2.1+1" [[deps.TOML]] deps = ["Dates"] @@ -299,7 +309,7 @@ uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" [[deps.Tensors]] -deps = ["ForwardDiff", "LinearAlgebra", "PrecompileTools", "SIMD", "StaticArrays", "Statistics"] +deps = ["ForwardDiff", "LinearAlgebra", "PrecompileTools", "Random", "SIMD", "StaticArrays", "Statistics"] path = ".." uuid = "48a634ad-e948-5137-8d70-aa71f2a747f4" version = "1.16.1" @@ -318,19 +328,19 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" +version = "1.2.13+1" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+0" +version = "5.8.0+1" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" +version = "1.52.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" +version = "17.4.0+2" diff --git a/docs/Project.toml b/docs/Project.toml index cb30e2e..ca54c5ce 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,6 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4" + +[compat] +Documenter = "0" \ No newline at end of file diff --git a/src/Tensors.jl b/src/Tensors.jl index 4cd677b..625042b 100644 --- a/src/Tensors.jl +++ b/src/Tensors.jl @@ -2,7 +2,7 @@ module Tensors import Base.@pure -import Statistics +import Statistics, Random using Statistics: mean using LinearAlgebra using StaticArrays diff --git a/src/constructors.jl b/src/constructors.jl index 7a48fcc..db15e4d 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -88,8 +88,10 @@ for TensorType in (SymmetricTensor, Tensor) end end -# zero, one, rand -for (op, el) in ((:zero, :(zero(T))), (:ones, :(one(T))), (:rand, :(()->rand(T))), (:randn,:(()->randn(T)))) +# zero, one, randn. +# rand included here to make rand(::Type{AbstractTensor}) fast on julia 1.6. +# When 1.6 support is dropped, the general implementation below can be used instead. +for (op, el) in ((:zero, :(zero(T))), (:ones, :(one(T))), (:randn,:(()->randn(T))), (:rand,:(()->rand(T)))) for TensorType in (SymmetricTensor, Tensor) @eval begin @inline Base.$op(::Type{$TensorType{order, dim}}) where {order, dim} = $op($TensorType{order, dim, Float64}) @@ -101,6 +103,27 @@ end @eval @inline Base.$op(t::AllTensors) = $op(typeof(t)) end +# Helper to construct a fully specified tensor from at least the specification returned by get_base +function default_concrete_tensor_type(::Type{X}) where {X <: Union{Tensor, SymmetricTensor}} + TB = get_base(X) + T = eltype(X) === Any ? Float64 : eltype(X) + M = n_components(TB) + return TB{T, M} +end + +# For `rand`, hook into Random +function Random.rand(rng::Random.AbstractRNG, ::Random.SamplerType{TT}) where {TT <: Union{Tensor{order, dim}, SymmetricTensor{order, dim}}} where {order, dim} + TC = default_concrete_tensor_type(TT) + return apply_all(get_base(TT), _ -> rand(rng, eltype(TC)))::TC # typeassert needed on julia 1.6, but ok on 1.8 and later. +end +# Always use the `SamplerType` as the value has no influence on the random generation. +Random.Sampler(::Type{<:Random.AbstractRNG}, t::AllTensors, ::Random.Repetition) = Random.SamplerType{typeof(t)}() +# Fix to make `rand([rng], ::Type{AbstractTensor}, d, dims...)` have a concrete eltype +function Random.rand(r::Random.AbstractRNG, ::Type{X}, dims::Dims) where {X <: Union{Tensor, SymmetricTensor}} + TC = default_concrete_tensor_type(X) + return Random.rand!(r, Array{TC}(undef, dims), X) +end + @inline Base.fill(el::Number, S::Type{T}) where {T <: Union{Tensor, SymmetricTensor}} = apply_all(get_base(T), i -> el) @inline Base.fill(f::Function, S::Type{T}) where {T <: Union{Tensor, SymmetricTensor}} = apply_all(get_base(T), i -> f()) diff --git a/test/test_misc.jl b/test/test_misc.jl index 05fa1aa..00904d3 100644 --- a/test/test_misc.jl +++ b/test/test_misc.jl @@ -19,6 +19,20 @@ for T in (Float32, Float64, F64), dim in (1,2,3), order in (1,2,3,4) (@inferred (op)(Vec{dim, T}))::Tensor{order, dim, T} end end + # Random numbers with samplers + for TensorType in (Tensor, SymmetricTensor) + TensorType == SymmetricTensor && isodd(order) && continue + TT = TensorType{order, dim, T} + @test rand(MersenneTwister(1), TT) ≈ rand(MersenneTwister(1), TT) # Check that rng was actually used + @test rand(MersenneTwister(2), TT) ≈ rand(MersenneTwister(2), rand(TT)) # Check same value when given a value + # type stability + (@inferred rand(TT, 2))::Vector{TT{Tensors.n_components(TensorType{order, dim})}} + @inferred rand(MersenneTwister(2), TT) + + if order == 1 + @test rand(MersenneTwister(1), Vec{dim, T}) ≈ rand(MersenneTwister(1), TT) + end + end # Special Vec constructor if order == 1 t = ntuple(i -> T(i), dim)