Skip to content

Commit

Permalink
Extend macros with rand to support custom samplers (#1210)
Browse files Browse the repository at this point in the history
* Extend macros with rand to support custom samplers

* Fix tests

* Update Project.toml

Co-authored-by: Yuto Horikawa <[email protected]>

* Update test/arraymath.jl

Co-authored-by: Yuto Horikawa <[email protected]>

* Update src/SVector.jl

Co-authored-by: Yuto Horikawa <[email protected]>

* Update src/SMatrix.jl

Co-authored-by: Yuto Horikawa <[email protected]>

* Update `SArray` macro

* fix reported issue; support rng in SArray and SMatrix

* Code suggestions for #1210 (#1213)

* move `ex.args[2] isa Integer`

* split `if` block

* simplify :zeros and :ones

* refactor :rand

* refactor :randn and :randexp

* update comments

* add _isnonnegvec

* update with `_isnonnegvec`

* add `_isnonnegvec(args, n)` method to check the size of `args`

* fix `@SArray` for `@SArray rand(rng,T,dim)` etc.

* update comments

* update `@SVector` macro

* update `@SMatrix`

* update `@SVector`

* update `@SArray`

* introduce `fargs` variable

* avoid `_isnonnegvec` in `static_matrix_gen`

* avoid `_isnonnegvec` in `static_vector_gen`

* remove unnecessary `_isnonnegvec`

* add `_rng()` function

* update tests on `@SVector` macro

* update tests on `@MVector` macro

* organize test/MMatrix.jl and test/SMatrix.jl

* organize test/MMatrix.jl and test/SMatrix.jl

* update with broken tests

* organize test/MMatrix.jl and test/SMatrix.jl for `rand*` functions

* fix around `broken` key for `@test` macro

* fix zero-length tests

* update `test/SArray.jl` to match `test/MArray.jl`

* update tests for `@SArray ones` etc.

* add supports for `@SArray ones(3-1,2)` etc.

* move block for `fill`

* update macro `@SArray rand(rng,2,3)` to use ordinary dispatches

* update around `@SArray randn` etc.

* remove unnecessary dollars

* simplify `@SArray fill`

* add `@testset "expand_error"`

* update tests for `@SArray rand(...)` etc.

* fix bug in `rand*_with_Val`

* cleanup tests

* update macro `@SMatrix rand(rng,2,3)` to use ordinary dispatches

* update macro `@SVector rand(rng,3)` to use ordinary dispatches

* move block for `fill`

* simplify `_randexp_with_Val`

---------

Co-authored-by: Yuto Horikawa <[email protected]>
  • Loading branch information
mateuszbaran and hyrodium authored Jan 3, 2024
1 parent 3fd8fb9 commit e2d772f
Show file tree
Hide file tree
Showing 13 changed files with 826 additions and 179 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.8.2"
version = "1.9.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
113 changes: 94 additions & 19 deletions src/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,22 +142,65 @@ function parse_cat_ast(ex::Expr)
cat_any(Val(maxdim), Val(catdim), nargs)
end

#=
For example,
* `@SArray rand(2, 3, 4)`
* `@SArray rand(rng, 3, 4)`
will be expanded to the following.
* `_rand_with_Val(SArray, 2, 3, _int2val(2), _int2val(3), Val((4,)))`
* `_rand_with_Val(SArray, 2, 3, _int2val(rng), _int2val(3), Val((4,)))`
The function `_int2val` is required to avoid the following case.
* `_rand_with_Val(SArray, 2, 3, Val(2), Val(3), Val((4,)))`
* `_rand_with_Val(SArray, 2, 3, Val(rng), Val(3), Val((4,)))`
Mutable object such as `rng` cannot be type parameter, and `Val(rng)` throws an error.
=#
_int2val(x::Int) = Val(x)
_int2val(::Any) = nothing
# @SArray zeros(...)
_zeros_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = zeros(SA{Tuple{n1, ns...}})
_zeros_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = zeros(SA{Tuple{ns...}, T})
# @SArray ones(...)
_ones_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = ones(SA{Tuple{n1, ns...}})
_ones_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = ones(SA{Tuple{ns...}, T})
# @SArray rand(...)
_rand_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = rand(SA{Tuple{n1,n2,ns...}})
_rand_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, T, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
_rand_with_Val(::Type{SA}, sampler, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, ns...), SA{Tuple{n1, ns...}, Random.gentype(sampler)})
_rand_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(rng, Float64, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
_rand_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, T, Size(ns...), SA{Tuple{ns...}, T})
_rand_with_Val(::Type{SA}, rng::AbstractRNG, sampler, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, sampler, Size(ns...), SA{Tuple{ns...}, Random.gentype(sampler)})
# @SArray randn(...)
_randn_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randn(SA{Tuple{n1,n2,ns...}})
_randn_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
_randn_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
_randn_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randn(rng, Size(ns...), SA{Tuple{ns...}, T})
# @SArray randexp(...)
_randexp_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randexp(SA{Tuple{n1,n2,ns...}})
_randexp_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
_randexp_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
_randexp_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randexp(rng, Size(ns...), SA{Tuple{ns...}, T})

escall(args) = Iterators.map(esc, args)
function _isnonnegvec(args)
length(args) == 0 && return false
all(isa.(args, Integer)) && return all(args .≥ 0)
return false
end
function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
if !isa(ex, Expr)
error("Bad input for @$SA")
end
head = ex.head
if head === :vect # vector
return :($SA{$Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...))))
return :($SA{Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...))))
elseif head === :ref # typed, vector
return :($SA{$Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...))))
return :($SA{Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...))))
elseif head === :typed_vcat || head === :typed_hcat || head === :typed_ncat # typed, cat
args = parse_cat_ast(ex)
return :($SA{$Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...))))
return :($SA{Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...))))
elseif head === :vcat || head === :hcat || head === :ncat # untyped, cat
args = parse_cat_ast(ex)
return :($SA{$Tuple{$(size(args)...)}}($tuple($(escall(args)...))))
return :($SA{Tuple{$(size(args)...)}}($tuple($(escall(args)...))))
elseif head === :comprehension
if length(ex.args) != 1
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
Expand All @@ -173,7 +216,7 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
return quote
let
f($(escall(rng_args)...)) = $(esc(ex.args[1]))
$SA{$Tuple{$(size(exprs)...)}}($tuple($(exprs...)))
$SA{Tuple{$(size(exprs)...)}}($tuple($(exprs...)))
end
end
elseif head === :typed_comprehension
Expand All @@ -192,26 +235,58 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
return quote
let
f($(escall(rng_args)...)) = $(esc(ex.args[1]))
$SA{$Tuple{$(size(exprs)...)},$T}($tuple($(exprs...)))
$SA{Tuple{$(size(exprs)...)},$T}($tuple($(exprs...)))
end
end
elseif head === :call
f = ex.args[1]
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
if length(ex.args) == 1
f === :zeros || f === :ones || error("@$SA got bad expression: $(ex)")
return :($f($SA{$Tuple{},$Float64}))
end
return quote
if isa($(esc(ex.args[2])), DataType)
$f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))})
else
$f($SA{$Tuple{$(escall(ex.args[2:end])...)}})
end
fargs = ex.args[2:end]
if f === :zeros || f === :ones
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 0
# for calls like `zeros()`
return :($f($SA{Tuple{},$Float64}))
elseif _isnonnegvec(fargs)
# for calls like `zeros(dims...)`
return :($f($SA{Tuple{$(escall(fargs)...)}}))
else
# for calls like `zeros(type)`
# for calls like `zeros(type, dims...)`
return :($_f_with_Val($SA, $(esc(fargs[1])), Val($(esc(fargs[1]))), Val(tuple($(escall(fargs[2:end])...)))))
end
elseif f === :fill
length(ex.args) == 1 && error("@$SA got bad expression: $(ex)")
return :($f($(esc(ex.args[2])), $SA{$Tuple{$(escall(ex.args[3:end])...)}}))
# for calls like `fill(value, dims...)`
return :($f($(esc(fargs[1])), $SA{Tuple{$(escall(fargs[2:end])...)}}))
elseif f === :rand || f === :randn || f === :randexp
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 0
# No support for `@SArray rand()`
error("@$SA got bad expression: $(ex)")
elseif _isnonnegvec(fargs)
# for calls like `rand(dims...)`
return :($f($SA{Tuple{$(escall(fargs)...)}}))
elseif length(fargs) 2
# for calls like `rand(dim1, dim2, dims...)`
# for calls like `rand(type, dim1, dims...)`
# for calls like `rand(sampler, dim1, dims...)`
# for calls like `rand(rng, dim1, dims...)`
# for calls like `rand(rng, type, dims...)`
# for calls like `rand(rng, sampler, dims...)`
# for calls like `randn(dim1, dim2, dims...)`
# for calls like `randn(type, dim1, dims...)`
# for calls like `randn(rng, dim1, dims...)`
# for calls like `randn(rng, type, dims...)`
# for calls like `randexp(dim1, dim2, dims...)`
# for calls like `randexp(type, dim1, dims...)`
# for calls like `randexp(rng, dim1, dims...)`
# for calls like `randexp(rng, type, dims...)`
return :($_f_with_Val($SA, $(esc(fargs[1])), $(esc(fargs[2])), _int2val($(esc(fargs[1]))), _int2val($(esc(fargs[2]))), Val(tuple($(escall(fargs[3:end])...)))))
elseif length(fargs) == 1
# for calls like `rand(dim)`
return :($f($SA{Tuple{$(escall(fargs)...)}}))
else
error("@$SA got bad expression: $(ex)")
end
else
error("@$SA only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
end
Expand Down
64 changes: 54 additions & 10 deletions src/SMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ function check_matrix_size(x::Tuple, T = :S)
x1, x2
end

# @SMatrix rand(...)
_rand_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = rand(rng, SM{n1, n2})
_rand_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(Random.GLOBAL_RNG, T, Size(n1, n2), SM{n1, n2, T})
_rand_with_Val(::Type{SM}, sampler, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, n2), SM{n1, n2, Random.gentype(sampler)})
_rand_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = rand(rng, SM{n1, n2, T})
_rand_with_Val(::Type{SM}, rng::AbstractRNG, sampler, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(rng, sampler, Size(n1, n2), SM{n1, n2, Random.gentype(sampler)})
# @SMatrix randn(...)
_randn_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randn(rng, SM{n1, n2})
_randn_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _randn(Random.GLOBAL_RNG, Size(n1, n2), SM{n1, n2, T})
_randn_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randn(rng, SM{n1, n2, T})
# @SMatrix randexp(...)
_randexp_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randexp(rng, SM{n1, n2})
_randexp_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _randexp(Random.GLOBAL_RNG, Size(n1, n2), SM{n1, n2, T})
_randexp_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randexp(rng, SM{n1, n2, T})

function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM}
if !isa(ex, Expr)
error("Bad input for @$SM")
Expand Down Expand Up @@ -69,22 +84,51 @@ function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM
end
elseif head === :call
f = ex.args[1]
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
if length(ex.args) == 3
return :($f($SM{$(escall(ex.args[2:3])...), Float64})) # default to Float64 like Base
elseif length(ex.args) == 4
return :($f($SM{$(escall(ex.args[[3,4,2]])...)}))
fargs = ex.args[2:end]
if f === :zeros || f === :ones
if length(fargs) == 2
# for calls like `zeros(dim1, dim2)`
return :($f($SM{$(escall(fargs)...)}))
elseif length(fargs[2:end]) == 2
# for calls like `zeros(type, dim1, dim2)`
return :($f($SM{$(escall(fargs[2:end])...), $(esc(fargs[1]))}))
else
error("@$SM expected a 2-dimensional array expression")
error("@$SM got bad expression: $(ex)")
end
elseif ex.args[1] === :fill
if length(ex.args) == 4
return :($f($(esc(ex.args[2])), $SM{$(escall(ex.args[3:4])...)}))
elseif f === :fill
# for calls like `fill(value, dim1, dim2)`
if length(fargs[2:end]) == 2
return :($f($(esc(fargs[1])), $SM{$(escall(fargs[2:end])...)}))
else
error("@$SM expected a 2-dimensional array expression")
end
elseif f === :rand || f === :randn || f === :randexp
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 2
# for calls like `rand(dim1, dim2)`
# for calls like `randn(dim1, dim2)`
# for calls like `randexp(dim1, dim2)`
return :($f($SM{$(escall(fargs)...)}))
elseif length(fargs) == 3
# for calls like `rand(rng, dim1, dim2)`
# for calls like `rand(type, dim1, dim2)`
# for calls like `rand(sampler, dim1, dim2)`
# for calls like `randn(rng, dim1, dim2)`
# for calls like `randn(type, dim1, dim2)`
# for calls like `randexp(rng, dim1, dim2)`
# for calls like `randexp(type, dim1, dim2)`
return :($_f_with_Val($SM, $(esc(fargs[1])), Val($(esc(fargs[2]))), Val($(esc(fargs[3])))))
elseif length(fargs) == 4
# for calls like `rand(rng, type, dim1, dim2)`
# for calls like `rand(rng, sampler, dim1, dim2)`
# for calls like `randn(rng, type, dim1, dim2)`
# for calls like `randexp(rng, type, dim1, dim2)`
return :($_f_with_Val($SM, $(esc(fargs[1])), $(esc(fargs[2])), Val($(esc(fargs[3]))), Val($(esc(fargs[4])))))
else
error("@$SM got bad expression: $(ex)")
end
else
error("@$SM only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
error("@$SM only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
end
else
error("Bad input for @$SM")
Expand Down
64 changes: 54 additions & 10 deletions src/SVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@ function check_vector_length(x::Tuple, T = :S)
length(x) >= 1 ? x[1] : 1
end

# @SVector rand(...)
_rand_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = rand(rng, SV{n})
_rand_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _rand(Random.GLOBAL_RNG, T, Size(n), SV{n, T})
_rand_with_Val(::Type{SV}, sampler, ::Val{n}) where {SV, n} = _rand(Random.GLOBAL_RNG, sampler, Size(n), SV{n, Random.gentype(sampler)})
_rand_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = rand(rng, SV{n, T})
_rand_with_Val(::Type{SV}, rng::AbstractRNG, sampler, ::Val{n}) where {SV, n} = _rand(rng, sampler, Size(n), SV{n, Random.gentype(sampler)})
# @SVector randn(...)
_randn_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = randn(rng, SV{n})
_randn_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _randn(Random.GLOBAL_RNG, Size(n), SV{n, T})
_randn_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = randn(rng, SV{n, T})
# @SVector randexp(...)
_randexp_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = randexp(rng, SV{n})
_randexp_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _randexp(Random.GLOBAL_RNG, Size(n), SV{n, T})
_randexp_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = randexp(rng, SV{n, T})

function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV}
if !isa(ex, Expr)
error("Bad input for @$SV")
Expand Down Expand Up @@ -74,22 +89,51 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV
end
elseif head === :call
f = ex.args[1]
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
if length(ex.args) == 2
return :($f($SV{$(esc(ex.args[2])), Float64})) # default to Float64 like Base
elseif length(ex.args) == 3
return :($f($SV{$(escall(ex.args[3:-1:2])...)}))
fargs = ex.args[2:end]
if f === :zeros || f === :ones
if length(fargs) == 1
# for calls like `zeros(dim)`
return :($f($SV{$(esc(fargs[1]))}))
elseif length(fargs) == 2
# for calls like `zeros(type, dim)`
return :($f($SV{$(esc(fargs[2])), $(esc(fargs[1]))}))
else
error("@$SV expected a 1-dimensional array expression")
error("@$SV got bad expression: $(ex)")
end
elseif ex.args[1] === :fill
if length(ex.args) == 3
return :($f($(esc(ex.args[2])), $SV{$(esc(ex.args[3]))}))
elseif f === :fill
# for calls like `fill(value, dim)`
if length(fargs) == 2
return :($f($(esc(fargs[1])), $SV{$(esc(fargs[2]))}))
else
error("@$SV expected a 1-dimensional array expression")
end
elseif f === :rand || f === :randn || f === :randexp
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 1
# for calls like `rand(dim)`
# for calls like `randn(dim)`
# for calls like `randexp(dim)`
return :($f($SV{$(escall(fargs)...)}))
elseif length(fargs) == 2
# for calls like `rand(rng, dim)`
# for calls like `rand(type, dim)`
# for calls like `rand(sampler, dim)`
# for calls like `randn(rng, dim)`
# for calls like `randn(type, dim)`
# for calls like `randexp(rng, dim)`
# for calls like `randexp(type, dim)`
return :($_f_with_Val($SV, $(esc(fargs[1])), Val($(esc(fargs[2])))))
elseif length(fargs) == 3
# for calls like `rand(rng, type, dim)`
# for calls like `rand(rng, sampler, dim)`
# for calls like `randn(rng, type, dim)`
# for calls like `randexp(rng, type, dim)`
return :($_f_with_Val($SV, $(esc(fargs[1])), $(esc(fargs[2])), Val($(esc(fargs[3])))))
else
error("@$SV got bad expression: $(ex)")
end
else
error("@$SV only supports the zeros(), ones(), rand(), randn() and randexp() functions.")
error("@$SV only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
end
else
error("Use @$SV [a,b,c], @$SV Type[a,b,c] or a comprehension like @$SV [f(i) for i = i_min:i_max]")
Expand Down
4 changes: 2 additions & 2 deletions src/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ end

@inline rand(rng::AbstractRNG, range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(rng, range, Size(SA), SA)
@inline rand(range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(Random.GLOBAL_RNG, range, Size(SA), SA)
@generated function _rand(rng::AbstractRNG, range::AbstractArray, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
v = [:(rand(rng, range)) for i = 1:prod(s)]
@generated function _rand(rng::AbstractRNG, X, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
v = [:(rand(rng, X)) for i = 1:prod(s)]
return quote
@_inline_meta
$SA(tuple($(v...)))
Expand Down
Loading

2 comments on commit e2d772f

@hyrodium
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/98096

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.9.0 -m "<description of version>" e2d772f9767abdcab20ce7ae6927dc25dc38714b
git push origin v1.9.0

Please sign in to comment.