Skip to content

Commit

Permalink
tensors: Perform faster substitution on triangular matrix R of QR
Browse files Browse the repository at this point in the history
  • Loading branch information
onox committed May 26, 2024
1 parent 7955fda commit daa2916
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 17 deletions.
43 changes: 34 additions & 9 deletions orka_numerics/src/orka-numerics-tensors-operations.adb
Original file line number Diff line number Diff line change
Expand Up @@ -409,44 +409,69 @@ package body Orka.Numerics.Tensors.Operations is
end QR_For_Least_Squares;

function QR_Solve (R, Y : Tensor_Type; Determinancy : Matrix_Determinancy) return Tensor_Type
with Post => Is_Equal (QR_Solve'Result.Shape, Y.Shape, 1);
with Pre => Is_Equal (R.Shape, Y.Shape, 2),
Post => Is_Equal (QR_Solve'Result.Shape, Y.Shape, 1);

-- Perform back or forward substitution on a single row of a triangular matrix A and replace one element in X
--
-- This can be used to solve A * x = B for x. Parameters Start and Stop
-- are the column indices after (for upper triangular) or before (for lower triangular) the pivot position.
procedure Substitute_Row (A, B : Tensor_Type; X : in out Tensor_Type; Row_Index, Column_Index : Index_Type; Start, Stop : Natural) is
Pivot : constant Element := A.Get ([Row_Index, Column_Index]);

Xi : Tensor_Type := B.Get (Row_Index).Reshape([1, B.Columns]);
begin
if Stop >= Start then
declare
Indices_After_Pivot : constant Range_Type := (Start, Stop);

Row_From_A : constant Tensor_Type := A.Get (Tensor_Range'[(Row_Index, Row_Index), Indices_After_Pivot]);
Rows_From_B : constant Tensor_Type := X.Get (Indices_After_Pivot).Reshape ([Range_Length (Indices_After_Pivot), B.Columns]);
pragma Assert (Row_From_A.Shape = [Rows_From_B.Rows]);

Sum : constant Tensor_Type := Row_From_A * Rows_From_B;
pragma Assert (Sum.Shape = [1, Rows_From_B.Columns]);
begin
Xi := @ - Sum;
end;
end if;

X.Set (Row_Index, Xi.Flatten / Pivot);
end Substitute_Row;

function QR_Solve
(R, Y : Tensor_Type;
Determinancy : Matrix_Determinancy) return Tensor_Type
is
Ry : Tensor_Type := Concatenate (R, Y, Axis => 2);

Columns : constant Positive := R.Columns;
Size : constant Positive := Natural'Min (R.Rows, R.Columns);
-- Use the smallest Axis of R for the size of the reduced (square) version R1
-- without needing to extract it
--
-- R = [R1] (A is overdetermined) or R = [R1 0] (A is underdetermined)
-- [ 0]

Columns_Ry : constant Positive := Ry.Columns;
-- Empty is sufficient, but might fail for GPU backend which tries to materialize it in procedure Set
X : Tensor_Type := Zeros (Y.Shape);
begin
case Determinancy is
when Overdetermined =>
-- Backward phase: row reduce augmented matrix of R * x = (Q^T * b = y) to
-- reduced echelon form by performing back-substitution on Ry
-- (since R is upper triangular no forward phase is needed)
for Index in reverse 1 .. Size loop
Back_Substitute (Ry, Index, Index);
Substitute_Row (R, Y, X, Index, Index, Index + 1, R.Columns);
end loop;
when Underdetermined =>
-- Forward phase: row reduce augmented matrix of R^T * y = b
-- to reduced echelon form by performing forward-substitution on Ry
-- (R is actually R^T) (reduced because R^T is lower triangular)
for Index in 1 .. Size loop
Scale_Row (Ry, Index, 1.0 / Ry.Get ([Index, Index]));
Forward_Substitute (Ry, Index, Index);
Substitute_Row (R, Y, X, Index, Index, 1, Index - 1);
end loop;
when Unknown => raise Program_Error;
end case;

return Ry.Get (Tensor_Range'((1, Size), (Columns + 1, Columns_Ry)));
return X;
end QR_Solve;

overriding
Expand Down
1 change: 0 additions & 1 deletion orka_numerics/src/orka-numerics-tensors-operations.ads
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ private generic

with procedure Make_Upper_Triangular (Object : in out Tensor_Type; Offset : Integer := 0);

with procedure Scale_Row (Ab : in out Tensor_Type; I : Index_Type; Scale : Element);
with procedure Swap_Rows (Ab : in out Tensor_Type; I, J : Index_Type);

with procedure Forward_Substitute (Ab : in out Tensor_Type; Index, Pivot_Index : Index_Type);
Expand Down
4 changes: 3 additions & 1 deletion orka_numerics/src/orka-numerics-tensors.ads
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ package Orka.Numerics.Tensors is
or else raise Constraint_Error with
"Range start (" & Trim (Range_Type.Start) & ") > stop (" & Trim (Range_Type.Stop) & ")";

function Range_Length (Index : Range_Type) return Positive is (Index.Stop - Index.Start + 1);

type Tensor_Range is array (Tensor_Axis range <>) of Range_Type;

function Shape (Index : Tensor_Range) return Tensor_Shape
Expand Down Expand Up @@ -114,7 +116,7 @@ package Orka.Numerics.Tensors is
-- Return the value of a boolean matrix

function Get (Object : Tensor; Index : Range_Type) return Tensor is abstract
with Post'Class => Object.Axes = Get'Result.Axes;
with Post'Class => (if Index.Start = Index.Stop then Object.Axes in Get'Result.Axes | Get'Result.Axes + 1 else Object.Axes = Get'Result.Axes);

function Get (Object : Tensor; Index : Tensor_Range) return Tensor is abstract
with Pre'Class => Index'Length <= Object.Axes,
Expand Down
2 changes: 1 addition & 1 deletion orka_tensors_cpu/src/orka-numerics-tensors-simd_cpu.adb
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ package body Orka.Numerics.Tensors.SIMD_CPU is
end Make_Upper_Triangular;

package Operations is new Orka.Numerics.Tensors.Operations
(CPU_Tensor, Make_Upper_Triangular, Scale_Row, Swap_Rows, Forward_Substitute, Back_Substitute,
(CPU_Tensor, Make_Upper_Triangular, Swap_Rows, Forward_Substitute, Back_Substitute,
Expression_Type, CPU_QR_Factorization, Create_QR, Q, R);

----------------------------------------------------------------------------
Expand Down
11 changes: 6 additions & 5 deletions orka_tensors_gpu/src/orka-numerics-tensors-cs_gpu.adb
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ package body Orka.Numerics.Tensors.CS_GPU is
end Make_Upper_Triangular;

package Operations is new Orka.Numerics.Tensors.Operations
(GPU_Tensor, Make_Upper_Triangular, Scale_Row, Swap_Rows, Forward_Substitute, Back_Substitute,
(GPU_Tensor, Make_Upper_Triangular, Swap_Rows, Forward_Substitute, Back_Substitute,
Expression_Type, GPU_QR_Factorization, Create_QR, Q, R);

----------------------------------------------------------------------------
Expand Down Expand Up @@ -961,6 +961,8 @@ package body Orka.Numerics.Tensors.CS_GPU is
GPU_Tensor (Tensor'Class'(Copy.Operation.Matrix_Operation.Left.Reference));
Source_Right : GPU_Tensor :=
GPU_Tensor (Tensor'Class'(Copy.Operation.Matrix_Operation.Right.Reference));

Source_Left_Columns : constant Natural := (if Source_Left.Axes > 1 then Source_Left.Columns else Source_Left.Rows);
begin
Buffers.Append (Materialize_Tensor (Source_Left));
Buffers.Append (Materialize_Tensor (Source_Right));
Expand All @@ -971,12 +973,11 @@ package body Orka.Numerics.Tensors.CS_GPU is
when 2 => Set_Shape (Kernel, Object.Shape);
when others => raise Not_Implemented_Yet; -- FIXME
end case;
Kernel.Uniform ("size").Set_UInt (Unsigned_32 (Source_Left.Columns));
pragma Assert (Source_Left.Columns = Source_Right.Rows);
Kernel.Uniform ("size").Set_UInt (Unsigned_32 (Source_Left_Columns));
pragma Assert (Source_Left_Columns = Source_Right.Rows);
pragma Assert (Source_Right.Axes = Object.Axes);

-- TODO Shouldn't Source_Left also be able to be a row vector, e.g. Axes <= 2?
pragma Assert (Source_Left.Axes = 2);
pragma Assert (Source_Left.Axes <= 2);
pragma Assert (Source_Right.Axes <= 2);
pragma Assert (Object.Axes <= 2);
end Initialize_Matrix_Matrix;
Expand Down

0 comments on commit daa2916

Please sign in to comment.