Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUTENSOR: Reduce amount of broadcasts compiled during tests. #2527

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions lib/cutensor/test/elementwise_binary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ eltypes = [(Float16, Float16),
opAC = cuTENSOR.OP_ADD
dD = elementwise_binary_execute!(1, dA, indsA, opA, 1, dC, indsC, opC, dD, indsC, opAC)
D = collect(dD)
@test D ≈ permutedims(A, p) .+ C
@test D ≈ permutedims(A, p) + C

# using integers as indices
dD = elementwise_binary_execute!(1, dA, 1:N, opA, 1, dC, p, opC, dD, p, opAC)
D = collect(dD)
@test D ≈ permutedims(A, p) .+ C
@test D ≈ permutedims(A, p) + C

# multiplication as binary operator
opAC = cuTENSOR.OP_MUL
Expand All @@ -57,7 +57,7 @@ eltypes = [(Float16, Float16),
γ = rand(eltyD)
dD = elementwise_binary_execute!(α, dA, indsA, opA, γ, dC, indsC, opC, dD, indsC, opAC)
D = collect(dD)
@test D ≈ α .* conj.(permutedims(A, p)) .+ γ .* C
@test D ≈ α * conj.(permutedims(A, p)) + γ * C

# test in-place, and more complicated unary and binary operations
opA = eltyA <: Complex ? cuTENSOR.OP_IDENTITY : cuTENSOR.OP_SQRT
Expand All @@ -70,12 +70,12 @@ eltypes = [(Float16, Float16),
D = collect(dC)
if eltyD <: Complex
if eltyA <: Complex
@test D ≈ α .* permutedims(A, p) .+ γ .* conj.(C)
@test D ≈ α * permutedims(A, p) + γ * conj.(C)
else
@test D ≈ α .* sqrt.(eltyD.(permutedims(A, p))) .+ γ .* conj.(C)
@test D ≈ α * sqrt.(eltyD.(permutedims(A, p))) + γ * conj.(C)
end
else
@test D ≈ max.(α .* sqrt.(eltyD.(permutedims(A, p))), γ .* C)
@test D ≈ max.(α * sqrt.(eltyD.(permutedims(A, p))), γ * C)
end

# using CuTensor type
Expand All @@ -85,24 +85,24 @@ eltypes = [(Float16, Float16),
ctC = CuTensor(dC, indsC)
ctD = ctA + ctC
hD = collect(ctD.data)
@test hD ≈ permutedims(A, p) .+ C
@test hD ≈ permutedims(A, p) + C
ctD = ctA - ctC
hD = collect(ctD.data)
@test hD ≈ permutedims(A, p) .- C
@test hD ≈ permutedims(A, p) - C

α = rand(eltyD)
ctC_copy = copy(ctC)
ctD = LinearAlgebra.axpy!(α, ctA, ctC_copy)
@test ctD == ctC_copy
hD = collect(ctD.data)
@test hD ≈ α.*permutedims(A, p) .+ C
@test hD ≈ α * permutedims(A, p) + C

γ = rand(eltyD)
ctC_copy = copy(ctC)
ctD = LinearAlgebra.axpby!(α, ctA, γ, ctC_copy)
@test ctD == ctC_copy
hD = collect(ctD.data)
@test hD ≈ α.*permutedims(A, p) .+ γ.*C
@test hD ≈ α * permutedims(A, p) + γ * C
end
end

Expand Down
43 changes: 19 additions & 24 deletions lib/cutensor/test/elementwise_trinary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,36 +46,35 @@ eltypes = [(Float16, Float16, Float16),
dD = elementwise_trinary_execute!(1, dA, indsA, opA, 1, dB, indsB, opB,
1, dC, indsC, opC, dD, indsC, opAB, opABC)
D = collect(dD)
@test D ≈ permutedims(A, pA) .+ permutedims(B, pB) .+ C
@test D ≈ permutedims(A, pA) + permutedims(B, pB) + C

# using integers as indices
dD = elementwise_trinary_execute!(1, dA, ipA, opA, 1, dB, ipB, opB,
1, dC, 1:N, opC, dD, 1:N, opAB, opABC)
D = collect(dD)
@test D ≈ permutedims(A, pA) .+ permutedims(B, pB) .+ C
@test D ≈ permutedims(A, pA) + permutedims(B, pB) + C

# multiplication as binary operator
opAB = cuTENSOR.OP_MUL
opABC = cuTENSOR.OP_ADD
dD = elementwise_trinary_execute!(1, dA, indsA, opA, 1, dB, indsB, opB,
1, dC, indsC, opC, dD, indsC, opAB, opABC)
D = collect(dD)
@test D ≈ (eltyD.(permutedims(A, pA)) .* eltyD.(permutedims(B, pB))) .+ C
@test D ≈ (eltyD.(permutedims(A, pA)) .* eltyD.(permutedims(B, pB))) + C

opAB = cuTENSOR.OP_ADD
opABC = cuTENSOR.OP_MUL
dD = elementwise_trinary_execute!(1, dA, indsA, opA, 1, dB, indsB, opB,
1, dC, indsC, opC, dD, indsC, opAB, opABC)
D = collect(dD)
@test D ≈ (eltyD.(permutedims(A, pA)) .+ eltyD.(permutedims(B, pB))) .* C
@test D ≈ (eltyD.(permutedims(A, pA)) + eltyD.(permutedims(B, pB))) .* C

opAB = cuTENSOR.OP_MUL
opABC = cuTENSOR.OP_MUL
dD = elementwise_trinary_execute!(1, dA, indsA, opA, 1, dB, indsB, opB,
1, dC, indsC, opC, dD, indsC, opAB, opABC)
D = collect(dD)
@test D ≈ eltyD.(permutedims(A, pA)) .*
eltyD.(permutedims(B, pB)) .* C
@test D ≈ eltyD.(permutedims(A, pA)) .* eltyD.(permutedims(B, pB)) .* C

# with non-trivial coefficients and conjugation
α = rand(eltyD)
Expand All @@ -88,24 +87,22 @@ eltypes = [(Float16, Float16, Float16),
dD = elementwise_trinary_execute!(α, dA, indsA, opA, β, dB, indsB, opB,
γ, dC, indsC, opC, dD, indsC, opAB, opABC)
D = collect(dD)
@test D ≈ α .* conj.(permutedims(A, pA)) .+ β .* permutedims(B, pB) .+ γ .* C
@test D ≈ α * conj.(permutedims(A, pA)) + β * permutedims(B, pB) + γ * C

opB = eltyB <: Complex ? cuTENSOR.OP_CONJ : cuTENSOR.OP_IDENTITY
opAB = cuTENSOR.OP_ADD
opABC = cuTENSOR.OP_ADD
dD = elementwise_trinary_execute!(α, dA, indsA, opA, β, dB, indsB, opB,
γ, dC, indsC, opC, dD, indsC, opAB, opABC)
D = collect(dD)
@test D ≈ α .* conj.(permutedims(A, pA)) .+
β .* conj.(permutedims(B, pB)) .+ γ .* C

@test D ≈ α * conj.(permutedims(A, pA)) + β * conj.(permutedims(B, pB)) + γ * C
opA = cuTENSOR.OP_IDENTITY
opAB = cuTENSOR.OP_MUL
opABC = cuTENSOR.OP_ADD
dD = elementwise_trinary_execute!(α, dA, indsA, opA, β, dB, indsB, opB,
γ, dC, indsC, opC, dD, indsC, opAB, opABC)
D = collect(dD)
@test D ≈ α .* permutedims(A, pA) .* β .* conj.(permutedims(B, pB)) .+ γ .* C
@test D ≈ * permutedims(A, pA)) .* * conj.(permutedims(B, pB))) + γ * C

# test in-place, and more complicated unary and binary operations
opA = eltyA <: Complex ? cuTENSOR.OP_IDENTITY : cuTENSOR.OP_SQRT
Expand All @@ -122,24 +119,22 @@ eltypes = [(Float16, Float16, Float16),
D = collect(dD)
if eltyD <: Complex
if eltyA <: Complex && eltyB <: Complex
@test D ≈ α .* permutedims(A, pA) .* β .* permutedims(B, pB) .+
γ .* conj.(C)
@test D ≈ * permutedims(A, pA)) .*
(β * permutedims(B, pB)) + γ * conj.(C)
elseif eltyB <: Complex
@test D ≈ α .* sqrt.(eltyD.(permutedims(A, pA))) .*
β .* permutedims(B, pB) .+ γ .* conj.(C)
@test D ≈ * sqrt.(eltyD.(permutedims(A, pA)))) .*
* permutedims(B, pB)) + γ * conj.(C)
elseif eltyB <: Complex
@test D ≈ α .* permutedims(A, pA) .*
β .* sqrt.(eltyD.(permutedims(B, pB))) .+
γ .* conj.(C)
@test D ≈ (α * permutedims(A, pA)) .*
(β * sqrt.(eltyD.(permutedims(B, pB)))) + γ * conj.(C)
else
@test D ≈ α .* sqrt.(eltyD.(permutedims(A, pA))) .*
β .* sqrt.(eltyD.(permutedims(B, pB))) .+
γ .* conj.(C)
@test D ≈ (α * sqrt.(eltyD.(permutedims(A, pA)))) .*
(β * sqrt.(eltyD.(permutedims(B, pB)))) + γ * conj.(C)
end
else
@test D ≈ max.(min.(α .* sqrt.(eltyD.(permutedims(A, pA))),
β .* sqrt.(eltyD.(permutedims(B, pB)))),
γ .* C)
@test D ≈ max.(min.(α * sqrt.(eltyD.(permutedims(A, pA))),
β * sqrt.(eltyD.(permutedims(B, pB)))),
γ * C)
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion lib/cutensor/test/reductions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ eltypes = [(Float16, Float16),
γ = rand(eltyC)
dC = reduce!(α, dA, indsA, opA, γ, dC, indsC, opC, opReduce)
@test reshape(collect(dC), (dimsC..., ones(Int,NA-NC)...)) ≈
α .* conj.(sum(permutedims(A, p); dims = ((NC+1:NA)...,))) .+ γ .* C
α * conj.(sum(permutedims(A, p); dims = ((NC+1:NA)...,))) + γ * C
end
end

Expand Down