Skip to content

Commit

Permalink
chore: Remove Element type class. Use Dtype instead.
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmcl committed Feb 26, 2025
1 parent e918294 commit ef0fca5
Show file tree
Hide file tree
Showing 8 changed files with 506 additions and 485 deletions.
13 changes: 8 additions & 5 deletions Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
import Init.System.IO
import Cli
import TensorLib
import TensorLib.Test

open Cli
open TensorLib

def format (p : Parsed) : IO UInt32 := do
let shape : Shape := Shape.mk (p.variableArgsAs! Nat).toList
IO.println s!"Got shape {shape}"
let range := Tensor.Element.arange BV16 shape.count
let range := Tensor.arange! Dtype.uint16 shape.count
let v := range.reshape! shape
let s := v.format BV16
IO.println s
IO.println v.toNatTree!.format!
return 0

def formatCmd := `[Cli|
Expand Down Expand Up @@ -52,8 +52,11 @@ def parseNpyCmd := `[Cli|
]

def runTests (_ : Parsed) : IO UInt32 := do
-- Just pytest for now, but add Lean tests here as well
-- pytest will exit nonzero on it's own, so we don't need to check exit code
IO.println "Running Lean tests..."
let t0 <- Test.runAllTests
if !t0 then do
IO.println "Lean tests failed"
return 1
IO.println "Running PyTest..."
let output <- IO.Process.output { cmd := "pytest" }
IO.println s!"stdout: {output.stdout}"
Expand Down
3 changes: 3 additions & 0 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ instance [BEq a] : BEq (Err a) where
| .error x, .error y => x == y
| _, _ => false

instance : BEq ByteArray where
beq x y := x.data == y.data

def get! [Inhabited a] (x : Err a) : a := match x with
| .error msg => impossible msg
| .ok x => x
Expand Down
5 changes: 5 additions & 0 deletions TensorLib/Dtype.lean
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ def itemsize (x : Name) : Nat := match x with
end Name
end Dtype

/-
We have a fixed number of Dtype values, defined at the bottom of the namespace.
Constructor is private so we can't make any more by accident.
-/
structure Dtype where
private mk ::
name : Dtype.Name
order : ByteOrder
deriving BEq, Repr, Inhabited
Expand Down
236 changes: 111 additions & 125 deletions TensorLib/Index.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin

import TensorLib.Broadcast
import TensorLib.Common
import TensorLib.Dtype
import TensorLib.Npy
import TensorLib.Slice
import TensorLib.Tensor
Expand Down Expand Up @@ -272,7 +273,7 @@ We currently handle
1. Non-slice indices (e.g. arr[1][7][-2])
2. TODO
-/
def apply (index : NumpyBasic) (arr : Tensor) : Err (Tensor × Bool) := do
def apply (index : NumpyBasic) (arr : Tensor) : Err Tensor := do
let oldShape := arr.shape
let mut startIndex := arr.startIndex
let mut needsCopy := false
Expand All @@ -286,15 +287,16 @@ def apply (index : NumpyBasic) (arr : Tensor) : Err (Tensor × Bool) := do
needsCopy := true
break
if needsCopy then do
let res <- applyWithCopy index arr
return (res, true)
applyWithCopy index arr
else
let res := { arr with
shape := newShape,
startIndex,
unitStrides := arr.unitStrides.drop basic.length
}
return (res, false)
return res

def apply! (index : NumpyBasic) (arr : Tensor) : Tensor := (get! (apply index arr))

/-
`arr[index] = v`
Expand All @@ -314,8 +316,7 @@ is ok but
wouldn't make any sense, even though the shapes (1, 3) and (2, 3) are broadcastable.
-/
def assign (a : Type) [typ : Tensor.Element a] (arr : Tensor) (index : NumpyBasic) (v : Tensor) : Err Tensor := do
if arr.dtype != v.dtype || typ.dtype != arr.dtype then .error "Type mismatch" else
def assign (arr : Tensor) (index : NumpyBasic) (v : Tensor) : Err Tensor := do
let (basic, shape) <- toBasic index arr.shape
let v <- v.broadcastTo shape
let aIter := BasicIter.make basic
Expand All @@ -324,10 +325,12 @@ def assign (a : Type) [typ : Tensor.Element a] (arr : Tensor) (index : NumpyBasi
for aIndex in aIter do
let (vIndex, vIter') := vIter.next
vIter := vIter'
let vVal <- typ.getDimIndex v vIndex
res <- Tensor.Element.setDimIndex res aIndex vVal
let vVal <- v.getDimIndex vIndex
res <- res.setDimIndex aIndex vVal
return res

def assign! (arr : Tensor) (index : NumpyBasic) (v : Tensor) : Tensor := get! $ assign arr index v

/-
For advanced indexing, the all-multidimensional-array case is relatively easy;
broadcast all arguments to the same shape, then select the elements of the original
Expand Down Expand Up @@ -389,6 +392,13 @@ def apply (indexTensors : List Tensor) (arr : Tensor) : Err Tensor := do
res <- res.setByteArrayAtDimIndex outDimIndex bytes
return res

def apply! (indexTensors : List Tensor) (arr : Tensor) : Tensor := get! $ apply indexTensors arr

end Advanced

section Test
open Tensor.Format.Tree

/-
0 1 2
3 4 5
Expand All @@ -403,120 +413,111 @@ def apply (indexTensors : List Tensor) (arr : Tensor) : Err Tensor := do
24 25 26
-/
#guard
let ind0 := (Tensor.Element.ofList Int8 [1, 2, 0, 0]).reshape! (Shape.mk [2, 2])
let ind1 := (Tensor.Element.ofList Int8 [2, -2, 0, 1]).reshape! (Shape.mk [2, 2])
let ind2 := (Tensor.Element.ofList Int8 [1, 1, -1, -1]).reshape! (Shape.mk [2, 2])
let typ := BV16
let arr := (Tensor.Element.arange typ 27).reshape! (Shape.mk [3, 3, 3])
let res := get! $ apply [ind0, ind1, ind2] arr
let tree := get! $ res.toTree typ
let tp := Dtype.int8
let ind0 := (Tensor.ofIntList! tp [1, 2, 0, 0]).reshape! (Shape.mk [2, 2])
let ind1 := (Tensor.ofIntList! tp [2, -2, 0, 1]).reshape! (Shape.mk [2, 2])
let ind2 := (Tensor.ofIntList! tp [1, 1, -1, -1]).reshape! (Shape.mk [2, 2])
let arr := (Tensor.arange! Dtype.uint8 27).reshape! (Shape.mk [3, 3, 3])
let res := Advanced.apply! [ind0, ind1, ind2] arr
let tree := res.toNatTree!
tree == Tensor.Format.Tree.node [.root [16, 22], .root [2, 5]]

end Advanced

section Test
open Tensor
open Tensor.Format.Tree

#guard
let tp := BV8
let tensor := Element.arange tp 10
let tensor := tensor.reshape! $ Shape.mk [2, 5]
let tp := Dtype.int8
let tensor := (Tensor.arange! tp 10).reshape! (Shape.mk [2, 5])
let index := [.int 1]
let res := get! $ applyWithCopy index tensor
let tree := get! $ res.toTree tp
let tree := res.toNatTree!
let tree' := .root [5, 6, 7, 8, 9]
tree == tree'

#guard
let tp := BV8
let tensor := Element.arange tp 10
let tensor := tensor.reshape! $ Shape.mk [2, 5]
let tp := Dtype.int8
let tensor := (Tensor.arange! tp 10).reshape! $ Shape.mk [2, 5]
let index := [.int 1]
-- Bug in #guard keeps me from using `let (arr, copied) := ...` here
let ac := get! $ apply index tensor
let arr := ac.fst
let copied := ac.snd
let arr := apply! index tensor
let tree' := .root [5, 6, 7, 8, 9]
!copied && (get! $ arr.toTree tp) == tree'

#guard let tp := BV8
let tensor := Element.arange tp 20
let tensor := tensor.reshape! $ Shape.mk [2, 2, 5]
let index := [.int 1, .int 1, .int 4]
-- Bug in #guard keeps me from using `let (arr, copied) := ...` here
let ac := get! $ apply index tensor
let arr := ac.fst
let copied := ac.snd
let tree := get! $ arr.toTree tp
let tree' := .root [19]
!copied && tree == tree'

#guard let tp := BV8
let tensor1 := (Element.arange tp 20).reshape! (Shape.mk [2, 2, 5])
let index := [.int 1, .int 1, .int 4]
let tensor2 := Element.arrayScalar tp 255
let res := get! $ assign tp tensor1 index tensor2
let tree := get! $ res.toTree tp
let tree' := node [
node [ root [0, 1, 2, 3, 4], root [5, 6, 7, 8, 9] ],
node [ root [10, 11, 12, 13, 14], root [15, 16, 17, 18, 255] ],
]
tree == tree'

#guard let tp := BV8
let tensor1 := (Element.arange tp 20).reshape! (Shape.mk [2, 2, 5])
let index := [.int 1, .int 1, .newaxis]
let tensor2 := Element.ofList tp [50, 60, 70, 80, 90]
let res := get! $ assign tp tensor1 index tensor2
let tree := get! $ res.toTree tp
let tree' := node [
node [ root [0, 1, 2, 3, 4], root [5, 6, 7, 8, 9] ],
node [ root [10, 11, 12, 13, 14], root [50, 60, 70, 80, 90] ],
]
tree == tree'

#guard let tp := BV8
let tensor1 := (Element.arange tp 20).reshape! (Shape.mk [2, 2, 5])
let index := [.int 1, .int 1, .slice (Slice.ofStartStop 1 4)]
let tensor2 := Element.ofList tp [50, 60, 70]
let res := get! $ assign tp tensor1 index tensor2
let tree := get! $ res.toTree tp
let tree' := node [
node [ root [0, 1, 2, 3, 4], root [5, 6, 7, 8, 9] ],
node [ root [10, 11, 12, 13, 14], root [15, 50, 60, 70, 19] ],
]
tree == tree'

#guard let tp := BV8
let tensor1 := (Element.arange tp 20).reshape! (Shape.mk [4, 5])
let index := [.slice (Slice.ofStartStop 1 3), .slice (Slice.ofStartStop 1 4)]
let tensor2 := (Element.ofList tp [40, 50, 60, 70, 80, 90]).reshape! (Shape.mk [2, 3])
let res := get! $ assign tp tensor1 index tensor2
let tree := get! $ res.toTree tp
let tree' := node [
root [0, 1, 2, 3, 4],
root [5, 40, 50, 60, 9],
root [10, 70, 80, 90, 14],
root [15, 16, 17, 18, 19]
]
tree == tree'

#guard let tp := BV8
let tensor1 := (Element.arange tp 20).reshape! (Shape.mk [4, 5])
let index := [NumpyItem.slice (Slice.ofStartStop 1 3), .slice (Slice.ofStartStop 1 4)]
let tensor2 := Element.ofList tp [40, 50, 60] -- tensor2 should be broadcast to (2, 3)
let res := get! $ assign tp tensor1 index tensor2
let tree := get! $ res.toTree tp
let tree' := node [
root [0, 1, 2, 3, 4],
root [5, 40, 50, 60, 9],
root [10, 40, 50, 60, 14],
root [15, 16, 17, 18, 19]
]
tree == tree'

-- Testing
arr.toNatTree! == tree'

#guard
let tp := Dtype.int8
let tensor := (Tensor.arange! tp 20).reshape! $ Shape.mk [2, 2, 5]
let index := [.int 1, .int 1, .int 4]
let arr := apply! index tensor
let tree := arr.toIntTree!
let tree' := .root [19]
tree == tree'

#guard
let tp := Dtype.uint8
let tensor1 := (Tensor.arange! tp 20).reshape! (Shape.mk [2, 2, 5])
let index := [.int 1, .int 1, .int 4]
let tensor2 := Tensor.arrayScalarNat! tp 255
let res := assign! tensor1 index tensor2
let tree := res.toNatTree!
let tree' := node [
node [ root [0, 1, 2, 3, 4], root [5, 6, 7, 8, 9] ],
node [ root [10, 11, 12, 13, 14], root [15, 16, 17, 18, 255] ],
]
tree == tree'

#guard
let tp := Dtype.uint8
let tensor1 := (Tensor.arange! tp 20).reshape! (Shape.mk [2, 2, 5])
let index := [.int 1, .int 1, .newaxis]
let tensor2 := Tensor.ofNatList! tp [50, 60, 70, 80, 90]
let res := assign! tensor1 index tensor2
let tree := res.toNatTree!
let tree' := node [
node [ root [0, 1, 2, 3, 4], root [5, 6, 7, 8, 9] ],
node [ root [10, 11, 12, 13, 14], root [50, 60, 70, 80, 90] ],
]
tree == tree'

#guard
let tp := Dtype.uint8
let tensor1 := (Tensor.arange! tp 20).reshape! (Shape.mk [2, 2, 5])
let index := [.int 1, .int 1, .slice (Slice.ofStartStop 1 4)]
let tensor2 := Tensor.ofNatList! tp [50, 60, 70]
let res := assign! tensor1 index tensor2
let tree := res.toNatTree!
let tree' := node [
node [ root [0, 1, 2, 3, 4], root [5, 6, 7, 8, 9] ],
node [ root [10, 11, 12, 13, 14], root [15, 50, 60, 70, 19] ],
]
tree == tree'

#guard
let tp := Dtype.uint8
let tensor1 := (Tensor.arange! tp 20).reshape! (Shape.mk [4, 5])
let index := [.slice (Slice.ofStartStop 1 3), .slice (Slice.ofStartStop 1 4)]
let tensor2 := (Tensor.ofNatList! tp [40, 50, 60, 70, 80, 90]).reshape! (Shape.mk [2, 3])
let res := get! $ assign tensor1 index tensor2
let tree := res.toNatTree!
let tree' := node [
root [0, 1, 2, 3, 4],
root [5, 40, 50, 60, 9],
root [10, 70, 80, 90, 14],
root [15, 16, 17, 18, 19]
]
tree == tree'

#guard
let tp := Dtype.uint8
let tensor1 := (Tensor.arange! tp 20).reshape! (Shape.mk [4, 5])
let index := [NumpyItem.slice (Slice.ofStartStop 1 3), .slice (Slice.ofStartStop 1 4)]
let tensor2 := Tensor.ofNatList! tp [40, 50, 60] -- tensor2 should be broadcast to (2, 3)
let res := assign! tensor1 index tensor2
let tree := res.toNatTree!
let tree' := node [
root [0, 1, 2, 3, 4],
root [5, 40, 50, 60, 9],
root [10, 40, 50, 60, 14],
root [15, 16, 17, 18, 19]
]
tree == tree'

private def numpyBasicToList (dims : List Nat) (basic : NumpyBasic) : Option (List (List Nat)) := do
let shape := Shape.mk dims
let (basic, _) <- (toBasic basic shape).toOption
Expand Down Expand Up @@ -544,21 +545,6 @@ private def numpyBasicToList (dims : List Nat) (basic : NumpyBasic) : Option (Li
#guard numpyBasicToList [2, 2] [.slice (Slice.make! .none .none (.some (-1))), .slice Slice.all] == some [[1, 0], [1, 1], [0, 0], [0, 1]]
#guard numpyBasicToList [4, 2] [.slice (Slice.make! .none .none (.some (-2))), .slice Slice.all] == some [[3, 0], [3, 1], [1, 0], [1, 1]]

-- Commented for easier debugging. Remove some day
-- #eval do
-- let shape := [4, 2]
-- let basic := get! $ toBasic [.slice (Slice.build! .none .none (.some (-2))), .slice Slice.all] shape
-- let iter0 := (get! $ make shape basic)
-- let (ns0, iter1) <- iter0.next
-- let (ns1, iter2) <- iter1.next
-- let (ns2, iter3) <- iter2.next
-- -- let (ns4, iter4) <- iter3.next
-- -- let (ns5, iter5) <- iter4.next
-- -- let (ns6, iter6) <- iter5.next
-- -- let (ns7, iter7) <- iter6.next
-- -- let (ns8, iter8) <- iter7.next
-- -- let (ns9, iter9) <- iter8.next
-- return (basic, iter0, ns0, iter1, ns1, iter2, ns2, iter3) -- , ns4, iter4) -- , ns5, iter5, ns6, iter6, ns7, iter7, ns8, iter8, ns9, iter9)
end Test

end Index
Expand Down
Loading

0 comments on commit ef0fca5

Please sign in to comment.