From 49c15bf06bb61b8ea538166804ae2bbf829f7b3d Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Fri, 15 Mar 2024 21:35:30 +0100 Subject: [PATCH 1/9] Initial commit --- Project.toml | 7 ++++--- src/constructors.jl | 18 +++++++++++++++++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 38f758ea..d86f8b78 100644 --- a/Project.toml +++ b/Project.toml @@ -5,17 +5,18 @@ version = "1.16.1" [deps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -SIMD = "fdea26ae-647d-5447-a871-4b548cad5224" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SIMD = "fdea26ae-647d-5447-a871-4b548cad5224" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ForwardDiff = "0.10" -SIMD = "2, 3" +LinearAlgebra = "1" PrecompileTools = "1" +SIMD = "2, 3" StaticArrays = "1" -LinearAlgebra = "1" Statistics = "1" julia = "1" diff --git a/src/constructors.jl b/src/constructors.jl index 7a48fcca..6112ec34 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -89,7 +89,7 @@ for TensorType in (SymmetricTensor, Tensor) end # zero, one, rand -for (op, el) in ((:zero, :(zero(T))), (:ones, :(one(T))), (:rand, :(()->rand(T))), (:randn,:(()->randn(T)))) +for (op, el) in ((:zero, :(zero(T))), (:ones, :(one(T)))) #, (:rand, :(()->rand(T))), (:randn,:(()->randn(T)))) for TensorType in (SymmetricTensor, Tensor) @eval begin @inline Base.$op(::Type{$TensorType{order, dim}}) where {order, dim} = $op($TensorType{order, dim, Float64}) @@ -101,6 +101,22 @@ end @eval @inline Base.$op(t::AllTensors) = $op(typeof(t)) end +# Random sampling +import Random +_default_eltype(::Type{Tensor{order, dim, T, M} where T}) where {order, dim, M} = Float64 +_default_eltype(::Type{Tensor{order, dim}}) where {order, dim} = Float64 +_default_eltype(::Type{SymmetricTensor{order, dim}}) where {order, dim} = Float64 +_default_eltype(TT::Type{<:AbstractTensor}) = eltype(TT) + +function Random.rand(rng::Random.AbstractRNG, ::Random.SamplerType{TT}) where {TT <: AllTensors} + T = _default_eltype(TT) + return apply_all(get_base(TT), _ -> randn(rng, T)) +end +function Random.randn(rng::Random.AbstractRNG, ::Random.SamplerType{<:TT}) where {TT <: AbstractTensor} + T = _default_eltype(TT) + return apply_all(get_base(TT), _ -> randn(rng, T)) +end + @inline Base.fill(el::Number, S::Type{T}) where {T <: Union{Tensor, SymmetricTensor}} = apply_all(get_base(T), i -> el) @inline Base.fill(f::Function, S::Type{T}) where {T <: Union{Tensor, SymmetricTensor}} = apply_all(get_base(T), i -> f()) From e68937177decc907028db4a77938b4c25abcdb34 Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Fri, 15 Mar 2024 22:36:09 +0100 Subject: [PATCH 2/9] Add tests and make it work generally --- src/Tensors.jl | 2 +- src/constructors.jl | 15 ++++++--------- test/test_misc.jl | 11 +++++++++++ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/Tensors.jl b/src/Tensors.jl index 4cd677b2..625042b9 100644 --- a/src/Tensors.jl +++ b/src/Tensors.jl @@ -2,7 +2,7 @@ module Tensors import Base.@pure -import Statistics +import Statistics, Random using Statistics: mean using LinearAlgebra using StaticArrays diff --git a/src/constructors.jl b/src/constructors.jl index 6112ec34..c1024aa7 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -88,8 +88,8 @@ for TensorType in (SymmetricTensor, Tensor) end end -# zero, one, rand -for (op, el) in ((:zero, :(zero(T))), (:ones, :(one(T)))) #, (:rand, :(()->rand(T))), (:randn,:(()->randn(T)))) +# zero, one, randn +for (op, el) in ((:zero, :(zero(T))), (:ones, :(one(T))), (:randn,:(()->randn(T)))) for TensorType in (SymmetricTensor, Tensor) @eval begin @inline Base.$op(::Type{$TensorType{order, dim}}) where {order, dim} = $op($TensorType{order, dim, Float64}) @@ -101,21 +101,18 @@ end @eval @inline Base.$op(t::AllTensors) = $op(typeof(t)) end -# Random sampling -import Random +# For `rand`, hook into Random _default_eltype(::Type{Tensor{order, dim, T, M} where T}) where {order, dim, M} = Float64 _default_eltype(::Type{Tensor{order, dim}}) where {order, dim} = Float64 _default_eltype(::Type{SymmetricTensor{order, dim}}) where {order, dim} = Float64 _default_eltype(TT::Type{<:AbstractTensor}) = eltype(TT) -function Random.rand(rng::Random.AbstractRNG, ::Random.SamplerType{TT}) where {TT <: AllTensors} - T = _default_eltype(TT) - return apply_all(get_base(TT), _ -> randn(rng, T)) -end -function Random.randn(rng::Random.AbstractRNG, ::Random.SamplerType{<:TT}) where {TT <: AbstractTensor} +function Random.rand(rng::Random.AbstractRNG, ::Random.SamplerType{TT}) where {TT <: Union{Tensor{order, dim}, SymmetricTensor{order, dim}}} where {order, dim} T = _default_eltype(TT) return apply_all(get_base(TT), _ -> randn(rng, T)) end +# Always use the `SamplerType` as the value has no influence on the random generation. +Random.Sampler(::Type{<:Random.AbstractRNG}, t::AllTensors, ::Random.Repetition) = Random.SamplerType{typeof(t)}() @inline Base.fill(el::Number, S::Type{T}) where {T <: Union{Tensor, SymmetricTensor}} = apply_all(get_base(T), i -> el) @inline Base.fill(f::Function, S::Type{T}) where {T <: Union{Tensor, SymmetricTensor}} = apply_all(get_base(T), i -> f()) diff --git a/test/test_misc.jl b/test/test_misc.jl index 05fa1aab..d310f980 100644 --- a/test/test_misc.jl +++ b/test/test_misc.jl @@ -19,6 +19,17 @@ for T in (Float32, Float64, F64), dim in (1,2,3), order in (1,2,3,4) (@inferred (op)(Vec{dim, T}))::Tensor{order, dim, T} end end + # Random numbers with samplers + for TensorType in (Tensor, SymmetricTensor) + TensorType == SymmetricTensor && isodd(order) && continue + TT = TensorType{order, dim, T} + @test rand(MersenneTwister(1), TT) ≈ rand(MersenneTwister(1), TT) # Check that rng was actually used + @test rand(MersenneTwister(2), TT) ≈ rand(MersenneTwister(2), rand(TT)) # Check same value when given a value + @inferred Vector{<:TT} rand(TT, 2) # Construct a Vector of random tensors + if order == 1 + @test rand(MersenneTwister(1), Vec{dim}) ≈ rand(MersenneTwister(1), TT) + end + end # Special Vec constructor if order == 1 t = ntuple(i -> T(i), dim) From 38aee8ee9a1dd46a45ec874f692e3ac15448a0e0 Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Fri, 15 Mar 2024 22:46:38 +0100 Subject: [PATCH 3/9] Clean Project.toml --- Project.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index d86f8b78..c9427e04 100644 --- a/Project.toml +++ b/Project.toml @@ -13,18 +13,18 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ForwardDiff = "0.10" -LinearAlgebra = "1" -PrecompileTools = "1" SIMD = "2, 3" +PrecompileTools = "1" StaticArrays = "1" +LinearAlgebra = "1" Statistics = "1" julia = "1" +Random = "1" [extras] -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Random", "Test", "TimerOutputs", "Unitful"] +test = ["Test", "TimerOutputs", "Unitful"] From 3902f33584020728a5968d1cb3d349e9dc0deab5 Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Fri, 15 Mar 2024 22:47:54 +0100 Subject: [PATCH 4/9] Order deps as they where, add Random at the end --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index c9427e04..92c0132d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,11 +5,11 @@ version = "1.16.1" [deps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SIMD = "fdea26ae-647d-5447-a871-4b548cad5224" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] ForwardDiff = "0.10" From d3743ef1ac49f37d14034cfb58108548cb8c9e9a Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Fri, 15 Mar 2024 22:54:18 +0100 Subject: [PATCH 5/9] Update docs Manifest with new Random dep to Tensors --- docs/Manifest.toml | 178 +++++++++++++++++++++++++++++++++------------ 1 file changed, 130 insertions(+), 48 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index efd38464..1e134f60 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.3" +julia_version = "1.10.2" manifest_format = "2.0" project_hash = "063b18da735845e327fc97d98706d227f804d89d" @@ -9,6 +9,11 @@ git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" version = "0.0.1" +[[deps.AbstractTrees]] +git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.5" + [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" @@ -19,6 +24,12 @@ uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.4" + [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" @@ -28,7 +39,7 @@ version = "0.3.0" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+0" +version = "1.1.0+0" [[deps.Dates]] deps = ["Printf"] @@ -53,34 +64,52 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" version = "0.9.3" [[deps.Documenter]] -deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "39fd748a73dce4c05a9655475e437170d8fb1b67" +deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"] +git-tree-sha1 = "4a40af50e8b24333b9ec6892546d9ca5724228eb" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.25" +version = "1.3.0" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" +[[deps.Expat_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "4558ab818dcceaab612d1bb8c19cee87eda2b83c" +uuid = "2e619515-83b5-522b-bb60-26c02a35a201" +version = "2.5.0+0" + [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.35" +version = "0.10.36" weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] ForwardDiffStaticArraysExt = "StaticArrays" +[[deps.Git]] +deps = ["Git_jll"] +git-tree-sha1 = "04eff47b1354d702c3a85e8ab23d539bb7d5957e" +uuid = "d7ba0133-e1db-5d97-8f8c-041e4b3a1eb2" +version = "1.3.1" + +[[deps.Git_jll]] +deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"] +git-tree-sha1 = "12945451c5d0e2d0dca0724c3a8d6448b46bbdf9" +uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" +version = "2.44.0+1" + [[deps.IOCapture]] deps = ["Logging", "Random"] -git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" +git-tree-sha1 = "8b72179abc660bfab5e28472e019392b97d0985c" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.3" +version = "0.2.4" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -92,10 +121,10 @@ uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" version = "0.2.2" [[deps.JLLWrappers]] -deps = ["Preferences"] -git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.4.1" +version = "1.5.0" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -103,37 +132,53 @@ git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.4" +[[deps.LazilyInitializedFields]] +git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612" +uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf" +version = "1.2.2" + [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" +version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" +version = "8.4.0+0" [[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" +version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.17.0+0" + [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f" +git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.24" +version = "0.3.27" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -150,25 +195,31 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.MacroTools]] deps = ["Markdown", "Random"] -git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.10" +version = "0.5.13" [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +[[deps.MarkdownAST]] +deps = ["AbstractTrees", "Markdown"] +git-tree-sha1 = "465a70f0fc7d443a00dcdc3267a497397b8a3899" +uuid = "d0879d2d-cac2-40c8-9cee-1863dc0c7391" +version = "0.1.2" + [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" +version = "2.28.2+1" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" +version = "2023.1.10" [[deps.NaNMath]] deps = ["OpenLibm_jll"] @@ -183,12 +234,18 @@ version = "1.2.0" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+4" +version = "0.3.23+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" +version = "0.8.1+2" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "60e3045590bd104a16fefb12836c00c0ef8c7f8c" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.13+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -196,28 +253,33 @@ git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" version = "0.5.5+0" +[[deps.PCRE2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" +version = "10.42.0+1" + [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.7.2" +version = "2.8.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.2" +version = "1.10.0" [[deps.PrecompileTools]] deps = ["Preferences"] -git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81" +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.1.2" +version = "1.2.1" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.0" +version = "1.4.3" [[deps.Printf]] deps = ["Unicode"] @@ -228,18 +290,24 @@ deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[deps.Random]] -deps = ["SHA", "Serialization"] +deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[deps.RegistryInstances]] +deps = ["LazilyInitializedFields", "Pkg", "TOML", "Tar"] +git-tree-sha1 = "ffd19052caf598b8653b99404058fce14828be51" +uuid = "2792f1a3-b283-48e8-9a74-f99dce5104f3" +version = "0.1.0" + [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" [[deps.SIMD]] deps = ["PrecompileTools"] -git-tree-sha1 = "0e270732477b9e551d884e6b07e23bb2ec947790" +git-tree-sha1 = "d8911cc125da009051fb35322415641d02d9e37f" uuid = "fdea26ae-647d-5447-a871-4b548cad5224" -version = "3.4.5" +version = "3.4.6" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -250,12 +318,13 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "7beb031cf8145577fbccacd94b8a8f4ce78428d3" +git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.0" +version = "2.3.1" [deps.SpecialFunctions.extensions] SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -264,15 +333,19 @@ version = "2.3.0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore"] -git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.2" -weakdeps = ["Statistics"] +version = "1.9.3" [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" StaticArraysStatisticsExt = "Statistics" + [deps.StaticArrays.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + [[deps.StaticArraysCore]] git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -281,12 +354,12 @@ version = "1.4.2" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" +version = "1.10.0" [[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+6" +version = "7.2.1+1" [[deps.TOML]] deps = ["Dates"] @@ -299,7 +372,7 @@ uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" [[deps.Tensors]] -deps = ["ForwardDiff", "LinearAlgebra", "PrecompileTools", "SIMD", "StaticArrays", "Statistics"] +deps = ["ForwardDiff", "LinearAlgebra", "PrecompileTools", "Random", "SIMD", "StaticArrays", "Statistics"] path = ".." uuid = "48a634ad-e948-5137-8d70-aa71f2a747f4" version = "1.16.1" @@ -308,6 +381,15 @@ version = "1.16.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[deps.TranscodingStreams]] +git-tree-sha1 = "3caa21522e7efac1ba21834a03734c57b4611c7e" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.10.4" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] + [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -318,19 +400,19 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" +version = "1.2.13+1" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+0" +version = "5.8.0+1" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" +version = "1.52.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" +version = "17.4.0+2" From 00bd57612b7585f16c8eccb0e9e8e171682361e3 Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Fri, 15 Mar 2024 23:16:39 +0100 Subject: [PATCH 6/9] Fix oopsi --- src/constructors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/constructors.jl b/src/constructors.jl index c1024aa7..e168317d 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -109,7 +109,7 @@ _default_eltype(TT::Type{<:AbstractTensor}) = eltype(TT) function Random.rand(rng::Random.AbstractRNG, ::Random.SamplerType{TT}) where {TT <: Union{Tensor{order, dim}, SymmetricTensor{order, dim}}} where {order, dim} T = _default_eltype(TT) - return apply_all(get_base(TT), _ -> randn(rng, T)) + return apply_all(get_base(TT), _ -> rand(rng, T)) end # Always use the `SamplerType` as the value has no influence on the random generation. Random.Sampler(::Type{<:Random.AbstractRNG}, t::AllTensors, ::Random.Repetition) = Random.SamplerType{typeof(t)}() From 734a4b8de70b434ddf888109f5fc37ca976ca740 Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Sat, 16 Mar 2024 14:45:25 +0100 Subject: [PATCH 7/9] Fix test --- test/test_misc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_misc.jl b/test/test_misc.jl index d310f980..2e5e37d9 100644 --- a/test/test_misc.jl +++ b/test/test_misc.jl @@ -27,7 +27,7 @@ for T in (Float32, Float64, F64), dim in (1,2,3), order in (1,2,3,4) @test rand(MersenneTwister(2), TT) ≈ rand(MersenneTwister(2), rand(TT)) # Check same value when given a value @inferred Vector{<:TT} rand(TT, 2) # Construct a Vector of random tensors if order == 1 - @test rand(MersenneTwister(1), Vec{dim}) ≈ rand(MersenneTwister(1), TT) + @test rand(MersenneTwister(1), Vec{dim, T}) ≈ rand(MersenneTwister(1), TT) end end # Special Vec constructor From 557be57fb9bc14ad47eb681b132b737d503d423e Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Sat, 16 Mar 2024 14:45:53 +0100 Subject: [PATCH 8/9] Add compat on Documenter in docs, updating to v1 to be done in separate pr --- docs/Manifest.toml | 80 +++------------------------------------------- docs/Project.toml | 3 ++ 2 files changed, 7 insertions(+), 76 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 1e134f60..921cd571 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -2,18 +2,13 @@ julia_version = "1.10.2" manifest_format = "2.0" -project_hash = "063b18da735845e327fc97d98706d227f804d89d" +project_hash = "7cec80757d87f9c7f22da03f741d383cecfafe9e" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" version = "0.0.1" -[[deps.AbstractTrees]] -git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.4.5" - [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" @@ -24,12 +19,6 @@ uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.4" - [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" @@ -64,22 +53,16 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" version = "0.9.3" [[deps.Documenter]] -deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"] -git-tree-sha1 = "4a40af50e8b24333b9ec6892546d9ca5724228eb" +deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] +git-tree-sha1 = "39fd748a73dce4c05a9655475e437170d8fb1b67" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "1.3.0" +version = "0.27.25" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" -[[deps.Expat_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "4558ab818dcceaab612d1bb8c19cee87eda2b83c" -uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.5.0+0" - [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" @@ -93,18 +76,6 @@ weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] ForwardDiffStaticArraysExt = "StaticArrays" -[[deps.Git]] -deps = ["Git_jll"] -git-tree-sha1 = "04eff47b1354d702c3a85e8ab23d539bb7d5957e" -uuid = "d7ba0133-e1db-5d97-8f8c-041e4b3a1eb2" -version = "1.3.1" - -[[deps.Git_jll]] -deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "12945451c5d0e2d0dca0724c3a8d6448b46bbdf9" -uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" -version = "2.44.0+1" - [[deps.IOCapture]] deps = ["Logging", "Random"] git-tree-sha1 = "8b72179abc660bfab5e28472e019392b97d0985c" @@ -132,11 +103,6 @@ git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.4" -[[deps.LazilyInitializedFields]] -git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612" -uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf" -version = "1.2.2" - [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" @@ -164,12 +130,6 @@ version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+0" - [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -203,12 +163,6 @@ version = "0.5.13" deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[deps.MarkdownAST]] -deps = ["AbstractTrees", "Markdown"] -git-tree-sha1 = "465a70f0fc7d443a00dcdc3267a497397b8a3899" -uuid = "d0879d2d-cac2-40c8-9cee-1863dc0c7391" -version = "0.1.2" - [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" @@ -241,23 +195,12 @@ deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" version = "0.8.1+2" -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "60e3045590bd104a16fefb12836c00c0ef8c7f8c" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.13+0" - [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" version = "0.5.5+0" -[[deps.PCRE2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" -version = "10.42.0+1" - [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" @@ -293,12 +236,6 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[[deps.RegistryInstances]] -deps = ["LazilyInitializedFields", "Pkg", "TOML", "Tar"] -git-tree-sha1 = "ffd19052caf598b8653b99404058fce14828be51" -uuid = "2792f1a3-b283-48e8-9a74-f99dce5104f3" -version = "0.1.0" - [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" @@ -381,15 +318,6 @@ version = "1.16.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[deps.TranscodingStreams]] -git-tree-sha1 = "3caa21522e7efac1ba21834a03734c57b4611c7e" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.10.4" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] - [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/docs/Project.toml b/docs/Project.toml index cb30e2e6..ca54c5ce 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,6 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4" + +[compat] +Documenter = "0" \ No newline at end of file From 7b4e1ae61ec88a9ade92cac65bb38fe35bcec677 Mon Sep 17 00:00:00 2001 From: Knut Andreas Meyer Date: Sat, 16 Mar 2024 17:22:27 +0100 Subject: [PATCH 9/9] Fix type instability on 1.6 --- src/constructors.jl | 28 +++++++++++++++++++--------- test/test_misc.jl | 5 ++++- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/constructors.jl b/src/constructors.jl index e168317d..db15e4dd 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -88,8 +88,10 @@ for TensorType in (SymmetricTensor, Tensor) end end -# zero, one, randn -for (op, el) in ((:zero, :(zero(T))), (:ones, :(one(T))), (:randn,:(()->randn(T)))) +# zero, one, randn. +# rand included here to make rand(::Type{AbstractTensor}) fast on julia 1.6. +# When 1.6 support is dropped, the general implementation below can be used instead. +for (op, el) in ((:zero, :(zero(T))), (:ones, :(one(T))), (:randn,:(()->randn(T))), (:rand,:(()->rand(T)))) for TensorType in (SymmetricTensor, Tensor) @eval begin @inline Base.$op(::Type{$TensorType{order, dim}}) where {order, dim} = $op($TensorType{order, dim, Float64}) @@ -101,18 +103,26 @@ end @eval @inline Base.$op(t::AllTensors) = $op(typeof(t)) end -# For `rand`, hook into Random -_default_eltype(::Type{Tensor{order, dim, T, M} where T}) where {order, dim, M} = Float64 -_default_eltype(::Type{Tensor{order, dim}}) where {order, dim} = Float64 -_default_eltype(::Type{SymmetricTensor{order, dim}}) where {order, dim} = Float64 -_default_eltype(TT::Type{<:AbstractTensor}) = eltype(TT) +# Helper to construct a fully specified tensor from at least the specification returned by get_base +function default_concrete_tensor_type(::Type{X}) where {X <: Union{Tensor, SymmetricTensor}} + TB = get_base(X) + T = eltype(X) === Any ? Float64 : eltype(X) + M = n_components(TB) + return TB{T, M} +end +# For `rand`, hook into Random function Random.rand(rng::Random.AbstractRNG, ::Random.SamplerType{TT}) where {TT <: Union{Tensor{order, dim}, SymmetricTensor{order, dim}}} where {order, dim} - T = _default_eltype(TT) - return apply_all(get_base(TT), _ -> rand(rng, T)) + TC = default_concrete_tensor_type(TT) + return apply_all(get_base(TT), _ -> rand(rng, eltype(TC)))::TC # typeassert needed on julia 1.6, but ok on 1.8 and later. end # Always use the `SamplerType` as the value has no influence on the random generation. Random.Sampler(::Type{<:Random.AbstractRNG}, t::AllTensors, ::Random.Repetition) = Random.SamplerType{typeof(t)}() +# Fix to make `rand([rng], ::Type{AbstractTensor}, d, dims...)` have a concrete eltype +function Random.rand(r::Random.AbstractRNG, ::Type{X}, dims::Dims) where {X <: Union{Tensor, SymmetricTensor}} + TC = default_concrete_tensor_type(X) + return Random.rand!(r, Array{TC}(undef, dims), X) +end @inline Base.fill(el::Number, S::Type{T}) where {T <: Union{Tensor, SymmetricTensor}} = apply_all(get_base(T), i -> el) @inline Base.fill(f::Function, S::Type{T}) where {T <: Union{Tensor, SymmetricTensor}} = apply_all(get_base(T), i -> f()) diff --git a/test/test_misc.jl b/test/test_misc.jl index 2e5e37d9..00904d33 100644 --- a/test/test_misc.jl +++ b/test/test_misc.jl @@ -25,7 +25,10 @@ for T in (Float32, Float64, F64), dim in (1,2,3), order in (1,2,3,4) TT = TensorType{order, dim, T} @test rand(MersenneTwister(1), TT) ≈ rand(MersenneTwister(1), TT) # Check that rng was actually used @test rand(MersenneTwister(2), TT) ≈ rand(MersenneTwister(2), rand(TT)) # Check same value when given a value - @inferred Vector{<:TT} rand(TT, 2) # Construct a Vector of random tensors + # type stability + (@inferred rand(TT, 2))::Vector{TT{Tensors.n_components(TensorType{order, dim})}} + @inferred rand(MersenneTwister(2), TT) + if order == 1 @test rand(MersenneTwister(1), Vec{dim, T}) ≈ rand(MersenneTwister(1), TT) end