Skip to content

Commit

Permalink
feat: matmul
Browse files Browse the repository at this point in the history
This is pretty close to the NumPy version. We don't handle 1D
arrays yet, but N-D seems to work. The method is inefficient,
with a broadcast/reshape/reshape, but is conceptually
simple.
  • Loading branch information
seanmcl committed Feb 24, 2025
1 parent 4cc8c0b commit 9f31f0b
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 40 deletions.
2 changes: 2 additions & 0 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ instance : ToString Shape where

def empty : Shape := Shape.mk []

/-- Extend `shape` with the extra trailing dimensions `dims`. -/
def append (shape : Shape) (dims : List Nat) : Shape :=
  Shape.mk <| shape.val ++ dims

-- TODO: Put this in the struct?
/--
The number of elements in a tensor with this shape. Only the shape is needed for the
calculation; it delegates to `natProd` over the list of dimensions.
-/
def count (shape : Shape) : Nat := shape.val |> natProd
Expand Down
8 changes: 8 additions & 0 deletions TensorLib/Tensor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,14 @@ def broadcastTo (arr : Tensor) (shape : Shape) : Err Tensor :=
let strides <- broadcastStrides (arr.shape.val.zip arr.strides) shape
.ok $ Tensor.mk arr.dtype shape arr.data arr.startIndex strides

/--
Broadcast two tensors against each other.

The common shape is computed by `Broadcast.broadcast` from the two input shapes; each input is
then passed through `broadcastTo` with that shape. Fails with "Can't broadcast" when no common
shape exists, or with whatever error `broadcastTo` reports.
-/
def broadcast (arr1 : Tensor) (arr2 : Tensor) : Err (Tensor × Tensor) := do
  let some common := Broadcast.broadcast { left := arr1.shape, right := arr2.shape }
    | .error "Can't broadcast"
  let lhs <- arr1.broadcastTo common
  let rhs <- arr2.broadcastTo common
  return (lhs, rhs)

class Element (a : Type) where
dtype : Dtype
itemsize : Nat
Expand Down
211 changes: 171 additions & 40 deletions TensorLib/Ufunc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import TensorLib.Broadcast
import TensorLib.Common
import TensorLib.Index
import TensorLib.Tensor
import TensorLib.Broadcast

/-!
Universal functions: https://numpy.org/doc/stable/reference/ufuncs.html
Expand All @@ -16,6 +17,8 @@ namespace TensorLib
namespace Tensor
namespace Ufunc

-- Debug toggle, off by default.
-- NOTE(review): nothing in the visible part of this file reads this flag — confirm it is still needed.
def DEBUG : Bool := false

private def binop (a : Type) [Element a] (x y : Tensor) (op : a -> a -> Err a) : Err Tensor :=
match Broadcast.broadcast { left := x.shape, right := y.shape } with
| .none => .error s!"Can't broadcast shapes ${x.shape} with {y.shape}"
Expand Down Expand Up @@ -101,55 +104,181 @@ def sum (a : Type) [Add a] [Zero a] [Element a] (arr : Tensor) (axes : Option (L
termination_by axes.length
loop axes res

/--
Test helper: true exactly when `arr` has the empty shape `[]` and reading position 0 as an `a`
succeeds with a value equal to `n`. A failed read counts as `false`.
-/
private def hasTree0 (a : Type) [BEq a] [Element a] (arr : Tensor) (n : a) : Bool :=
  if arr.shape.val == [] then
    match Element.getPosition arr 0 with
    | .error _ => false
    | .ok (v : a) => v == n
  else
    false
/--
Dot product of two 1-D arrays (the 1-D case of `np.dot`).

`np.dot` accepts several other shape combinations, but each of them reduces to another operation
(multiplication by a scalar, matrix multiplication, ...). We want to stay close to NumPy while
steering callers toward the simplest, most natural operation, so only the 1-D case is accepted.

Returns an array scalar holding the sum of the element-wise products. Fails when the dtypes
differ, when either argument is not 1-D, when the two lengths differ, or when an element read
fails.
-/
def dot (a : Type) [Add a] [Mul a] [Zero a] [Element a] (x y : Tensor) : Err Tensor := do
  if x.dtype != y.dtype then .error "Expected same dtype"
  let (xLen, yLen) <- match x.shape.val, y.shape.val with
    | [m], [n] => .ok (m, n)
    | [], _ | _, [] => .error "While allowed in NumPy, please use scalar multiplication for array scalars"
    | _, _ => .error "While allowed in NumPy when the dimensions work out, please use matmul for this use case"
  if xLen != yLen then .error "dot: reduction dimension mismatch"
  -- Accumulate x[i] * y[i] over the shared length.
  let mut total : a := 0
  for i in [0:xLen] do
    let xi <- Element.getDimIndex x [i]
    let yi <- Element.getDimIndex y [i]
    total := total + xi * yi
  return Element.arrayScalar a total

/--
The usual 2-D matrix product: a `[rows, inner]` tensor times an `[inner, cols]` tensor gives a
`[rows, cols]` tensor.

Fails when the dtypes differ, when either argument is not 2-D, when the inner dimensions
disagree, or when an element read/write fails.
-/
private def matmul2 (a : Type) [Add a] [Mul a] [Zero a] [Element a] (x y : Tensor) : Err Tensor := do
  if x.dtype != y.dtype then .error "Expected same dtype"
  let (rows, inner, yRows, cols) <- match x.shape.val, y.shape.val with
    | [r, k], [k', c] => .ok (r, k, k', c)
    | _, _ => .error "Expected 2d arrays"
  if inner != yRows then .error "matmul2: reduction dimension mismatch"
  let mut out := Tensor.zeros x.dtype (Shape.mk [rows, cols])
  for i in [0:rows] do
    for j in [0:cols] do
      -- out[i, j] = sum over k of x[i, k] * y[k, j]
      let mut cell : a := 0
      for k in [0:inner] do
        let lhs <- Element.getDimIndex x [i, k]
        let rhs <- Element.getDimIndex y [k, j]
        cell := cell + lhs * rhs
      out <- Element.setDimIndex out [i, j] cell
  return out

/--
Matrix multiplication in the style of NumPy's `matmul`, which handles many variants of dimensions.
https://numpy.org/doc/2.1/reference/generated/numpy.matmul.html

For now we handle
* 2-D by 2-D
* N-D by M-D where N, M >= 2. The trailing two dimensions of each argument are the matrices;
  the leading dimensions are broadcast against each other, which gives a stack of 2-D matrices
  that we multiply pairwise and put in the correctly-shaped slots of the result.

Array scalars and 1-D arguments are rejected with an error.
TODO: I'm not sure what to do with the axis/axes arguments.
-/
def matmul (a : Type) [Add a] [Mul a] [Zero a] [Element a] (x y : Tensor) : Err Tensor := do
  if x.dtype != y.dtype then .error "Expected same dtype"
  -- The last two dimensions of each array must line up matmul-style. Working on the reversed
  -- shape lets us peel them off the front; the remaining prefix is reversed back.
  let (xRows, xCols, xLead, yRows, yCols, yLead) <- match x.shape.val.reverse, y.shape.val.reverse with
    | [], _ | _, [] => .error "array scalars not allowed"
    | [_], _ | _, [_] => .error "NumPy-like 1D matmul not yet implemented"
    | c1 :: r1 :: rest1, c2 :: r2 :: rest2 => .ok (r1, c1, rest1.reverse, r2, c2, rest2.reverse)
  if xCols != yRows then .error "matmulN: reduction dimension mismatch"
  -- Only the prefixes are broadcast: the final two dimensions would not match under the usual
  -- broadcast rules (e.g. [4,2] vs [2,3] producing a [4,3] matrix).
  match Broadcast.broadcast { left := Shape.mk xLead, right := Shape.mk yLead } with
  | none => .error "can't broadcast prefix"
  | some (Shape.mk []) => matmul2 a x y
  | some lead =>
    -- Broadcast both operands to the common prefix to get the right sizes and strides ...
    let xWide <- x.broadcastTo (lead.append [xRows, xCols])
    let yWide <- y.broadcastTo (lead.append [yRows, yCols])
    -- ... flatten the prefix into a single batch dimension ...
    let batches := lead.count
    let xFlat <- xWide.reshape (Shape.mk [batches, xRows, xCols])
    let yFlat <- yWide.reshape (Shape.mk [batches, yRows, yCols])
    -- ... multiply one 2-D slice at a time ...
    let mut out := Tensor.zeros xFlat.dtype (Shape.mk [batches, xRows, yCols])
    for b in [0:batches] do
      let slot := [Index.NumpyItem.int b]
      let (xMat, _) <- Index.apply slot xFlat
      let (yMat, _) <- Index.apply slot yFlat
      let prod <- matmul2 a xMat yMat
      out <- Index.assign a out slot prod
    -- ... and restore the broadcast prefix on the result.
    out.reshape (lead.append [xRows, yCols])

section Test
open Tensor.Format.Tree

/--
Test helper: true exactly when `arr` has shape `[xs.length]` and `arr.toTree a` succeeds with
the tree `.root xs`. A failed conversion counts as `false`.
-/
private def hasTree1 (a : Type) [Repr a] [BEq a] [Element a] (arr : Tensor) (xs : List a) : Bool :=
  if arr.shape.val == [xs.length] then
    match arr.toTree a with
    | .error _ => false
    | .ok v => v == .root xs
  else
    false

/-
# x = np.arange(10).reshape(2, 5)
# y = np.arange(10).reshape(5, 2)
# np.matmul(x, y)
array([[ 60, 70],
[160, 195]])
-/
-- Checks the 2-D helper `matmul2` directly (not the N-D `matmul` wrapper) against the NumPy
-- result quoted above.
#guard
let tp := BV8
let arr1 := get! $ (Element.arange tp 10).reshape (Shape.mk [2, 5])
let arr2 := get! $ (Element.arange tp 10).reshape (Shape.mk [5, 2])
let arr3 := get! $ matmul2 tp arr1 arr2
arr3.toTree! tp == .node [.root [60, 70], .root [160, 195]]

#guard
let typ := BV8
let arr := get! $ (Element.arange typ 10).reshape (Shape.mk [2, 5])
!(sum typ arr (.some [0, 1, 0])).isOk &&
!(sum typ arr (.some [0, 0, 1])).isOk &&
!(sum typ arr (.some [7])).isOk
let tp := BV8
let arr := get! $ (Element.arange tp 10).reshape (Shape.mk [2, 5])
!(sum tp arr (.some [0, 1, 0])).isOk &&
!(sum tp arr (.some [0, 0, 1])).isOk &&
!(sum tp arr (.some [7])).isOk

-- [[0, 1, 2, 3, 4],
-- [5, 6, 7, 8, 9]]
#guard
let typ := BV8
let arr := get! $ (Element.arange typ 10).reshape (Shape.mk [2, 5])
let x0 := get! $ sum typ arr .none
let x1 := get! $ sum typ arr (.some [])
let x2 := get! $ sum typ arr (.some [0])
let x3 := get! $ sum typ arr (.some [1])
let x4 := get! $ sum typ arr (.some [1, 0])
let x5 := get! $ sum typ arr (.some [0, 1])
let res :=
hasTree0 typ x0 45 &&
hasTree0 typ x1 45 &&
hasTree1 typ x2 [5, 7, 9, 11, 13] &&
hasTree1 typ x3 [10, 35] &&
hasTree0 typ x4 45 &&
hasTree0 typ x5 45
res
let tp := BV8
let arr := get! $ (Element.arange tp 10).reshape (Shape.mk [2, 5])
let x0 := get! $ sum tp arr .none
let x1 := get! $ sum tp arr (.some [])
let x2 := get! $ sum tp arr (.some [0])
let x3 := get! $ sum tp arr (.some [1])
let x4 := get! $ sum tp arr (.some [1, 0])
let x5 := get! $ sum tp arr (.some [0, 1])
x0.toTree! tp == .root [45]
&& x1.toTree! tp == .root [45]
&& x2.toTree! tp == .root [5, 7, 9, 11, 13]
&& x3.toTree! tp == .root [10, 35]
&& x4.toTree! tp == .root [45]
&& x5.toTree! tp == .root [45]

#guard
let typ := BV8
let x := Element.arange typ 10
let arr := get! $ add typ x x
hasTree1 typ arr [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
let tp := BV8
let x := Element.arange tp 10
let arr := get! $ add tp x x
arr.toTree! tp == .root [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]

#guard
let typ := BV8
let x := Element.arange typ 10
let y := Element.arrayScalar typ 7
let arr := get! $ add typ x y
hasTree1 typ arr [7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
let tp := BV8
let x := Element.arange tp 10
let y := Element.arrayScalar tp 7
let arr := get! $ add tp x y
arr.toTree! tp == .root [7, 8, 9, 10, 11, 12, 13, 14, 15, 16]

-- matmul, both operands 2-D: [2, 3] times [3, 2].
-- x = [[0, 1, 2], [3, 4, 5]], y = [[0, 1], [2, 3], [4, 5]]  =>  [[10, 13], [28, 40]]
#guard
let tp := BV8
let x := (Element.arange tp 6).reshape! (Shape.mk [2, 3])
let y := (Element.arange tp 6).reshape! (Shape.mk [3, 2])
let z := get! $ matmul tp x y
z.toTree! tp == .node [.root [10, 13], .root [28, 40]]

-- matmul, both operands 3-D with a leading dimension of size 1: [1, 2, 3] times [1, 3, 2].
-- Same matrices as the 2-D case; the expected tree has one extra level of nesting.
#guard
let tp := BV8
let x := (Element.arange tp 6).reshape! (Shape.mk [1, 2, 3])
let y := (Element.arange tp 6).reshape! (Shape.mk [1, 3, 2])
let z := get! $ matmul tp x y
z.toTree! tp == .node [.node [.root [10, 13], .root [28, 40]]]

-- matmul with prefix broadcasting: [2, 2, 3] times [1, 3, 2].
-- x holds two matrices ([[0..2], [3..5]] and [[6..8], [9..11]]); the single y matrix is
-- multiplied with each of them.
#guard
let tp := BV8
let x := (Element.arange tp 12).reshape! (Shape.mk [2, 2, 3])
let y := (Element.arange tp 6).reshape! (Shape.mk [1, 3, 2])
let z := get! $ matmul tp x y
z.toTree! tp == .node [
.node [.root [10, 13], .root [28, 40]],
.node [.root [46, 67], .root [64, 94]]
]

-- matmul with operands of different rank: 4-D [2, 1, 2, 3] times plain 2-D [3, 2].
-- Same products as the previous case; the expected tree keeps x's [2, 1] prefix.
#guard
let tp := BV8
let x := (Element.arange tp 12).reshape! (Shape.mk [2, 1, 2, 3])
let y := (Element.arange tp 6).reshape! (Shape.mk [3, 2])
let z := get! $ matmul tp x y
z.toTree! tp == .node [
.node [.node [.root [10, 13], .root [28, 40]]],
.node [.node [.root [46, 67], .root [64, 94]]]
]

/-! WIP example NKI kernel
"""
Expand Down Expand Up @@ -183,5 +312,7 @@ private def nki_tensor_add_kernel_ (program_id0 program_id1 : Nat) (a_input b_in
let () <- sorry -- store to c_input
.ok ()
-/

end Test
end Ufunc
end Tensor
end TensorLib

0 comments on commit 9f31f0b

Please sign in to comment.