Skip to content

Commit

Permalink
chore: noop tidying
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmcl committed Mar 3, 2025
1 parent e9d7dd0 commit 46fad4e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 31 deletions.
11 changes: 6 additions & 5 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,14 @@ end Test

def natProd (shape : List Nat) : Nat := shape.foldl (fun x y => x * y) 1


-- We generally have large tensors, so don't show them by default
instance ByteArrayRepr : Repr ByteArray where
reprPrec x _ :=
if x.size < 100 then x.toList.repr 100 else
s!"ByteArray of size {x.size}"

def _root_.ByteArray.reverse (arr : ByteArray) : ByteArray := ⟨ arr.data.reverse ⟩

/-!
NumPy arrays can be stored in big-endian or little-endian order on disk, regardless
of the architecture of the machine saving the array. Since we read these arrays
Expand Down Expand Up @@ -493,7 +494,7 @@ def BV8.toUInt8 (n : BV8) : UInt8 := UInt8.ofNat n.toFin

def BV8.toByteArray (x : BV8) : ByteArray := [x.toUInt8].toByteArray

def ByteArray.toBV8 (x : ByteArray) (startIndex : Nat) : Err BV8 :=
def _root_.ByteArray.toBV8 (x : ByteArray) (startIndex : Nat) : Err BV8 :=
let n := startIndex
if H7 : n < x.size then
let H0 : n + 0 < x.size := by omega
Expand Down Expand Up @@ -527,7 +528,7 @@ theorem BV16.BytesRoundTrip1 (x0 x1 : BV8) :
unfold BV16.toBytes BV16.ofBytes
bv_decide

def ByteArray.toBV16 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV16 :=
def _root_.ByteArray.toBV16 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV16 :=
let n := startIndex
if H7 : n + 1 < x.size then
let H0 : n + 0 < x.size := by omega
Expand Down Expand Up @@ -580,7 +581,7 @@ theorem BV32.BytesRoundTrip1 (x0 x1 x2 x3 : BV8) :
unfold BV32.toBytes BV32.ofBytes
bv_decide

def ByteArray.toBV32 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV32 :=
def _root_.ByteArray.toBV32 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV32 :=
let n := startIndex
if H7 : n + 3 < x.size then
let H0 : n + 0 < x.size := by omega
Expand Down Expand Up @@ -653,7 +654,7 @@ theorem BV64.BytesRoundTrip1 (x0 x1 x2 x3 x4 x5 x6 x7 : BV8) :
unfold BV64.toBytes BV64.ofBytes
bv_decide

def ByteArray.toBV64 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV64 :=
def _root_.ByteArray.toBV64 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV64 :=
let n := startIndex
if H7 : n + 7 < x.size then
let H0 : n + 0 < x.size := by omega
Expand Down
20 changes: 10 additions & 10 deletions TensorLib/Dtype.lean
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,17 @@ def uint64 : Dtype := Dtype.mk .uint64 .littleEndian
def float32 : Dtype := Dtype.mk .float32 .littleEndian
def float64 : Dtype := Dtype.mk .float64 .littleEndian

def isInt (dtype : Dtype) : Bool := dtype.name.isInt
def isUint (dtype : Dtype) : Bool := dtype.name.isUint
def isIntLike (dtype : Dtype) : Bool := dtype.isInt || dtype.isUint

def make (name : Name) (order : ByteOrder) : Err Dtype := match order with
| .oneByte => if name.isOneByte then .ok $ mk name order else .error "illegal dtype"
| .littleEndian | .bigEndian => if name.isMultiByte then .ok $ mk name order else .error "illegal dtype"

def byteOrderOk (dtype : Dtype) : Prop := !dtype.name.isMultiByte || (dtype.name.isMultiByte && dtype.order.isMultiByte)
private def byteOrderOk (dtype : Dtype) : Prop := !dtype.name.isMultiByte || (dtype.name.isMultiByte && dtype.order.isMultiByte)

theorem makeOk (name : Name) (order : ByteOrder) : match make name order with
private theorem makeOk (name : Name) (order : ByteOrder) : match make name order with
| .ok dtype => dtype.byteOrderOk
| .error _ => true := by
unfold make byteOrderOk Name.isMultiByte Name.isOneByte
Expand Down Expand Up @@ -186,7 +190,7 @@ private def byteArrayToIntRoundTrip (dtype : Dtype) (n : Int) : Bool :=
#guard int8.byteArrayToIntRoundTrip 127
#guard !int8.byteArrayToIntRoundTrip 255

def byteArrayToFloat (dtype : Dtype) (arr : ByteArray) : Err Float := match dtype.name with
private def byteArrayToFloat (dtype : Dtype) (arr : ByteArray) : Err Float := match dtype.name with
| .float64 =>
if arr.size != 8 then .error "byte size mismatch" else
match dtype.order with
Expand All @@ -204,17 +208,17 @@ def byteArrayToFloat32 (dtype : Dtype) (arr : ByteArray) : Err Float32 := match
| .oneByte => impossible "Illegal dtype. Creation shouldn't have been possible"
| _ => .error "Illegal type conversion"

def byteArrayToFloat! (dtype : Dtype) (arr : ByteArray) : Float := get! $ byteArrayToFloat dtype arr
private def byteArrayToFloat! (dtype : Dtype) (arr : ByteArray) : Float := get! $ byteArrayToFloat dtype arr

def byteArrayOfFloat (dtype : Dtype) (f : Float) : Err ByteArray := match dtype.name with
private def byteArrayOfFloat (dtype : Dtype) (f : Float) : Err ByteArray := match dtype.name with
| .float64 => .ok $ BV64.toByteArray f.toBits.toBitVec dtype.order
| _ => .error "Illegal type conversion"

def byteArrayOfFloat32 (dtype : Dtype) (f : Float32) : Err ByteArray := match dtype.name with
| .float32 => .ok $ BV32.toByteArray f.toBits.toBitVec dtype.order
| _ => .error "Illegal type conversion"

def byteArrayOfFloat! (dtype : Dtype) (f : Float) : ByteArray := get! $ byteArrayOfFloat dtype f
private def byteArrayOfFloat! (dtype : Dtype) (f : Float) : ByteArray := get! $ byteArrayOfFloat dtype f

private def byteArrayToFloatRoundTrip (dtype : Dtype) (f : Float) : Bool :=
let res := do
Expand Down Expand Up @@ -270,10 +274,6 @@ private def unsignedBEByteArrayToNat (arr : Array UInt8) : Nat := unsignedLEByte
#guard unsignedLEByteArrayToNat #[1, 0, 1, 1] == 13
#guard unsignedBEByteArrayToNat #[1, 0, 1, 1] == 11

def isInt (dtype : Dtype) : Bool := dtype.name.isInt
def isUint (dtype : Dtype) : Bool := dtype.name.isUint
def isIntLike (dtype : Dtype) : Bool := dtype.isInt || dtype.isUint

def add (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "add: byte size mismatch" else
match dtype.name with
Expand Down
2 changes: 1 addition & 1 deletion TensorLib/Mgrid.lean
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def mgrid (slices : List Slice) : Err Tensor := do
mgridIter := mgridIter'
if values.length != sliceCount then .error "Invariant failure: value length mismatch"
for (i, v) in (List.range sliceCount).zip values do
let value <- Dtype.byteArrayOfInt Dtype.int64 v
let value <- Dtype.int64.byteArrayOfInt v
arr <- arr.setDimIndex (i :: index) value
return arr

Expand Down
35 changes: 20 additions & 15 deletions TensorLib/Tensor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ to capture that pattern. My guess is that there is we can figure out if we need
a copy by looking at startPosition, shape, and strides.
-/
private def copyAndReshape (arr : Tensor) (shape : Shape) : Err Tensor :=
if arr.shape.count != shape.count then .error "Incompatible shapes" else
if arr.shape.count != shape.count then .error s!"Incompatible shapes: {arr.shape} {shape}" else
let arr := arr.copy
.ok { arr with shape, unitStrides := shape.unitStrides }

Expand All @@ -242,7 +242,7 @@ def copyAndReshape! (arr : Tensor) (shape : Shape) : Tensor :=

def reshape (arr : Tensor) (shape : Shape) : Err Tensor :=
if arr.shape == shape then .ok arr else
if arr.shape.count != shape.count then .error "Incompatible shapes" else
if arr.shape.count != shape.count then .error s!"Incompatible shapes: {arr.shape} {shape}" else
if arr.isTriviallyReshapable
then .ok { arr with shape, unitStrides := shape.unitStrides }
else copyAndReshape arr shape
Expand Down Expand Up @@ -304,6 +304,8 @@ def broadcastTo (arr : Tensor) (shape : Shape) : Err Tensor :=
let strides <- broadcastStrides (arr.shape.val.zip arr.strides) shape
.ok $ Tensor.mk arr.dtype shape arr.data arr.startIndex strides

def broadcastTo! (arr : Tensor) (shape : Shape) : Tensor := get! $ broadcastTo arr shape

def broadcast (arr1 : Tensor) (arr2 : Tensor) : Err (Tensor × Tensor) :=
match Broadcast.broadcast { left := arr1.shape, right := arr2.shape } with
| none => .error "Can't broadcast"
Expand All @@ -322,6 +324,12 @@ def arrayScalarNat (dtype : Dtype) (n : Nat) : Err Tensor := do

def arrayScalarNat! (dtype : Dtype) (n : Nat) : Tensor := get! $ arrayScalarNat dtype n

def arrayScalarInt (dtype : Dtype) (n : Int) : Err Tensor := do
let arr <- dtype.byteArrayOfInt n
arrayScalar dtype arr

def arrayScalarInt! (dtype : Dtype) (n : Int) : Tensor := get! $ arrayScalarInt dtype n

def arange (dtype : Dtype) (n : Nat) : Err Tensor := do
let size := dtype.itemsize
let mut data := ByteArray.mkEmpty (n * size)
Expand Down Expand Up @@ -584,14 +592,13 @@ section Test
open TensorLib.Tensor.Format.Tree

#guard
let arr := get! (arrayScalarNat Dtype.uint8 5)
let t := get! arr.toNatTree
let arr := arrayScalarNat! Dtype.uint8 5
let t := arr.toNatTree!
t == .root [5]

#guard
let arr := get! $ arange Dtype.uint16 10
let arr := get! $ arr.reshape (Shape.mk [2, 5])
let t := get! $ arr.toNatTree
let arr := (arange! Dtype.uint16 10).reshape! (Shape.mk [2, 5])
let t := arr.toNatTree!
t == node [root [0, 1, 2, 3, 4], root [5, 6, 7, 8, 9]]

#guard (zeros Dtype.float64 $ Shape.mk [2, 2]).nbytes == 2 * 2 * 8
Expand All @@ -600,19 +607,17 @@ open TensorLib.Tensor.Format.Tree
#guard (ones Dtype.float64 $ Shape.mk [2, 2]).data.toList.count 1 == 2 * 2

#guard
let t1 := get! $ arange Dtype.uint8 6
let t2 := get! $ t1.reshape (Shape.mk [2, 3])
let t3 := get! $ t2.broadcastTo (Shape.mk [2, 2, 3])
let tree := get! $ t3.toNatTree
let t1 := (arange! Dtype.uint8 6).reshape! (Shape.mk [2, 3])
let t2 := t1.broadcastTo! (Shape.mk [2, 2, 3])
let tree := t2.toNatTree!
let n1 := node [ root [0, 1, 2], root [3, 4, 5] ]
let tree' := node [ n1, n1 ]
tree == tree'

#guard
let t1 := get! $ arange Dtype.uint8 8
let t2 := get! $ t1.reshape (Shape.mk [2, 1, 1, 4])
let t3 := get! $ t2.broadcastTo (Shape.mk [2, 3, 3, 4])
let tree := get! $ t3.toNatTree
let t1 := (arange! Dtype.uint8 8).reshape! (Shape.mk [2, 1, 1, 4])
let t2 := t1.broadcastTo! (Shape.mk [2, 3, 3, 4])
let tree := t2.toNatTree!
let r1 := root [0, 1, 2, 3]
let r2 := root [4, 5, 6, 7]
let n1 := node [ r1, r1, r1 ]
Expand Down

0 comments on commit 46fad4e

Please sign in to comment.