add i_mean and i_sum, tweak i_mean_sum performance #52

Open
wants to merge 3 commits into
base: master

johnnychen94 (Contributor) commented Apr 29, 2021

I don't know whether @inbounds and @simd are allowed in NiLang.

x = rand(64, 64)
@btime i_mean_sum(0.0, 0.0, $x)
# 228.673 ns (0 allocations: 0 bytes) # PR
# 3.557 μs (0 allocations: 0 bytes) # master

Interestingly, i_mean_sum is more performant than i_mean. Is this because the additional uncompute process is not optimized?

julia> @btime i_mean(0.0, $x);
  476.871 ns (0 allocations: 0 bytes)

Edit: b7cc221 makes i_mean as performant as i_mean_sum. (I don't quite understand what's happening there.)
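For readers unfamiliar with the two macros mentioned above, here is a minimal plain-Julia sketch (no NiLang involved) of the kind of reduction loop they are usually applied to. The names `plain_sum` and `simd_sum` are hypothetical and not part of this PR. `@inbounds` removes bounds checks and `@simd` lets the compiler reassociate the floating-point additions so the loop can be vectorized, which is probably where most of the speedup in a loop like this comes from.

```julia
# Hypothetical helpers for illustration; not part of NiLang or this PR.

# Baseline: bounds-checked, strictly sequential accumulation.
function plain_sum(x::AbstractArray{T}) where T
    s = zero(T)
    for i in eachindex(x)
        s += x[i]
    end
    return s
end

# Same loop, with bounds checks removed and SIMD reassociation permitted.
function simd_sum(x::AbstractArray{T}) where T
    s = zero(T)
    @inbounds @simd for i in eachindex(x)
        s += x[i]
    end
    return s
end

x = rand(64, 64)
# The results agree only up to floating-point rounding, because
# @simd allows the additions to be reordered.
@assert plain_sum(x) ≈ simd_sum(x)
```

Timing both with `@btime plain_sum($x)` and `@btime simd_sum($x)` from BenchmarkTools shows the difference on a given machine. Whether the reversible `@i` versions benefit in the same way depends on the code NiLang generates.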

GiggleLiu (Owner) commented Apr 29, 2021

Thanks for the PR. I think the changes are very good in general, but I don't think an i_mean function is necessary. It introduces too much overhead just to uncompute a single scalar.

If we really want to handle the scalar garbage efficiently, you can use the design pattern that I used to implement the reflector! function:

struct Reflector{T,RT,VT<:AbstractVector{T}}
    ξ::T
    normu::RT
    sqnormu::RT
    r::T
    y::VT
end

@i function reflector!(R::Reflector{T,RT}, x::AbstractVector{T}) where {T,RT}
    n ← length(x)
    @inbounds @invcheckoff if n != 0
        @zeros T ξ1
        @zeros RT normu sqnormu
        ξ1 += x[1]
        sqnormu += abs2(ξ1)
        for i = 2:n
            sqnormu += abs2(x[i])
        end
        if !iszero(sqnormu)
            normu += sqrt(sqnormu)
            if real(ξ1) < 0
                NEG(normu)
            end
            ξ1 += normu
            R.y[1] -= normu
            for i = 2:n
                R.y[i] += x[i] / ξ1
            end
            R.r += ξ1/normu
        end
        SWAP(R.ξ, ξ1)
        SWAP(R.normu, normu)
        SWAP(R.sqnormu, sqnormu)
    end
end

function alloc(::typeof(reflector!), x::AbstractVector{T}) where T
    RT = real(T)
    Reflector(zero(T), zero(RT), zero(RT), zero(T), zero(x))
end

NiLang.value(r::Reflector) = r.y

I use a structure to store the outputs, alloc to allocate an output object for a function call, and NiLang.value to access the value field.
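As a rough illustration of how the same three pieces might look for a mean, here is a plain-Julia sketch. Everything in it is an assumption made for illustration: `MeanRes`, `mean_sum`, and the local `alloc` and `value` functions are hypothetical names, and `mean_sum` is an ordinary non-reversible stand-in for what would be an `@i` function in NiLang. The point is only that the intermediate sum is kept in the result struct next to the mean instead of being uncomputed.

```julia
# Hypothetical result type: `mean` is the wanted output, `sum` is the
# scalar "garbage" stored alongside it rather than uncomputed.
struct MeanRes{T}
    mean::T
    sum::T
end

# Plain (non-reversible) stand-in for an `@i` function that fills a MeanRes.
# Restricted to floating-point element types so both fields share one type.
function mean_sum(res::MeanRes{T}, x::AbstractArray{T}) where T<:AbstractFloat
    s = res.sum
    for xi in x
        s += xi
    end
    return MeanRes(res.mean + s / length(x), s)
end

# Allocate a zero-cleared output object for a call to `mean_sum`.
alloc(::typeof(mean_sum), x::AbstractArray{T}) where T = MeanRes(zero(T), zero(T))

# Accessor for the value field; with NiLang one would extend `NiLang.value`
# instead, as is done for Reflector above.
value(r::MeanRes) = r.mean

x = [1.0, 2.0, 3.0, 4.0]
res = mean_sum(alloc(mean_sum, x), x)
@assert value(res) == 2.5
@assert res.sum == 10.0
```

A caller that only needs the mean reads it through `value`, while the stored sum stays available for running the computation backwards.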

To use reflector! inside another function, allocate its output with R ← alloc(reflector!, args...):

@i function qr_pivoted!(res::QRPivotedRes, A::StridedMatrix{T}) where T
    m, n ← size(A)
    @invcheckoff @inbounds for j = 1:min(m,n)
        # Find column with maximum norm in trailing submatrix
        jm ← LinearAlgebra.indmaxcolumn(NiLang.value.(view(A, j:m, j:n))) + j - 1

        if jm != j
            # Flip elements in pivoting vector
            SWAP(res.jpvt[jm], res.jpvt[j])

            # Update matrix with
            for i = 1:m
                SWAP(A[i, jm], A[i, j])
            end
        end

        # Compute reflector of columns j
        R ← alloc(reflector!, A |> subarray(j:m, j))  # can be automatically done
        vA ← zeros(T, n-j)
        reflector!(R, A |> subarray(j:m, j))
        # Update trailing submatrix with reflector
        reflectorApply!(vA, R.y, R.r, A |> subarray(j:m, j+1:n))
        for i=1:length(R.y)
            SWAP(R.y[i], A[j+i-1, j])
        end
        PUSH!(res.reflectors, R)
        PUSH!(res.vAs, vA)
        PUSH!(res.jms, jm)
        R → _zero(Reflector{T,real(T),Vector{T}})
        vA → zeros(T, 0)
        jm → 0
    end
    @inbounds for i=1:length(res.reflectors)
        res.τ[i] += res.reflectors[i].r
    end
    res.factors += A
end

Here, _zero is a magic function that creates a zero-cleared object of any type. You can find its definition in

Original reflector

https://github.com/JuliaLang/julia/blob/b692d9f444ba6e82f60623230b901fd622d9e5d6/stdlib/LinearAlgebra/src/generic.jl#L1491
