diff --git a/TensorLib/Common.lean b/TensorLib/Common.lean index b609db2..0f2c33e 100644 --- a/TensorLib/Common.lean +++ b/TensorLib/Common.lean @@ -74,97 +74,49 @@ instance ByteArrayRepr : Repr ByteArray where def _root_.ByteArray.reverse (arr : ByteArray) : ByteArray := ⟨ arr.data.reverse ⟩ -/-! -NumPy arrays can be stored in big-endian or little-endian order on disk, regardless -of the architecture of the machine saving the array. Since we read these arrays -into memory at certain data types, for multi-byte data types we need to know the -endian-ness. --/ -inductive ByteOrder where -| oneByte -| littleEndian -| bigEndian -deriving BEq, Repr, Inhabited - -namespace ByteOrder - -@[simp] -def isMultiByte (x : ByteOrder) : Bool := match x with -| .oneByte => false -| .littleEndian | .bigEndian => true - -def bytesToNat (order : ByteOrder) (bytes : ByteArray) : Nat := Id.run do +def _root_.ByteArray.toNat (arr : ByteArray) : Nat := Id.run do let mut n : Nat := 0 - let nbytes := bytes.size + let nbytes := arr.size for i in [0:nbytes] do - let v : UInt8 := bytes.get! i - let p := match order with - | .oneByte => 0 -- nbytes = 1 - | .littleEndian => i - | .bigEndian => nbytes - 1 - i - n := n + Pow.pow 2 (8 * p) * v.toNat + let v : UInt8 := arr.get! i + n := n + Pow.pow 2 (8 * i) * v.toNat return n -#guard bytesToNat .littleEndian (ByteArray.mk #[1, 1]) == 257 -#guard bytesToNat .bigEndian (ByteArray.mk #[1, 1]) == 257 -#guard bytesToNat .littleEndian (ByteArray.mk #[0, 1]) == 256 -#guard bytesToNat .bigEndian (ByteArray.mk #[0, 1]) == 1 -#guard bytesToNat .littleEndian (ByteArray.mk #[0xFF, 0xFF]) == 65535 -#guard bytesToNat .bigEndian (ByteArray.mk #[0xFF, 0xFF]) == 65535 -#guard bytesToNat .bigEndian (ByteArray.mk #[0x80, 0]) == 32768 -#guard bytesToNat .littleEndian (ByteArray.mk #[0x80, 0]) == 0x80 +#guard (ByteArray.mk #[1, 1]).toNat == 257 +#guard (ByteArray.mk #[0, 1]).toNat == 256 +#guard (ByteArray.mk #[0xFF, 0xFF]).toNat == 65535 +#guard (ByteArray.mk #[0, 0x80]).toNat == 32768 +#guard (ByteArray.mk #[0x80, 0]).toNat == 0x80 -def bytesToInt (order : ByteOrder) (bytes : ByteArray) : Int := Id.run do +def _root_.ByteArray.toInt (arr : ByteArray) : Int := Id.run do let mut n : Nat := 0 - let nbytes := bytes.size - let signByte := match order with - | .littleEndian => bytes.get! (nbytes - 1) - | .bigEndian | oneByte => bytes.get! 0 + let nbytes := arr.size + let signByte := arr.get! (nbytes - 1) let negative := 128 <= signByte for i in [0:nbytes] do - let v : UInt8 := bytes.get! i + let v : UInt8 := arr.get! i let v := if negative then UInt8.complement v else v - let p := match order with - | .oneByte => 0 -- nbytes = 1 - | .littleEndian => i - | .bigEndian => nbytes - 1 - i - n := n + Pow.pow 2 (8 * p) * v.toNat + n := n + Pow.pow 2 (8 * i) * v.toNat return if 128 <= signByte then -(n + 1) else n -#guard bytesToInt .littleEndian (ByteArray.mk #[1, 1]) == 257 -#guard bytesToInt .bigEndian (ByteArray.mk #[1, 1]) == 257 -#guard bytesToInt .littleEndian (ByteArray.mk #[0, 1]) == 256 -#guard bytesToInt .bigEndian (ByteArray.mk #[0, 1]) == 1 -#guard bytesToInt .littleEndian (ByteArray.mk #[0xFF, 0xFF]) == -1 -#guard bytesToInt .bigEndian (ByteArray.mk #[0xFF, 0xFF]) == -1 -#guard bytesToInt .bigEndian (ByteArray.mk #[0x80, 0]) == -32768 -#guard bytesToInt .littleEndian (ByteArray.mk #[0x80, 0]) == 0x80 +#guard (ByteArray.mk #[1, 1]).toInt == 257 +#guard (ByteArray.mk #[0, 1]).toInt == 256 +#guard (ByteArray.mk #[1, 0]).toInt == 1 +#guard (ByteArray.mk #[0xFF, 0xFF]).toInt == -1 +#guard (ByteArray.mk #[0, 0x80]).toInt == -32768 +#guard (ByteArray.mk #[0x80, 0]).toInt == 0x80 -def bitVecToByteArray (order : ByteOrder) (n : Nat) (v : BitVec n) : ByteArray := Id.run do +def bitVecToByteArray (n : Nat) (v : BitVec n) : ByteArray := Id.run do let numBytes := natDivCeil n 8 let mut arr := ByteArray.mkEmpty numBytes - match order with - | .oneByte => - let byte := (v &&& 0xFF).toNat.toUInt8 - return arr.push byte - | .littleEndian => - for i in [0 : numBytes] do - let byte := (v.ushiftRight (i * 8) &&& 0xFF).toNat.toUInt8 - arr := arr.push byte - return arr - | .bigEndian => - for i in [0 : numBytes] do - let byte := (v.ushiftRight ((numBytes - i - 1) * 8) &&& 0xFF).toNat.toUInt8 - arr := arr.push byte - return arr - -#guard (bitVecToByteArray .bigEndian 16 0x0100).toList == [1, 0] -#guard (bitVecToByteArray .littleEndian 16 0x0100).toList == [0, 1] -#guard (bitVecToByteArray .bigEndian 20 0x01000).toList == [0, 16, 0] -#guard (bitVecToByteArray .littleEndian 32 0x1).toList == [1, 0, 0, 0] -#guard (bitVecToByteArray .bigEndian 32 0x1).toList == [0, 0, 0, 1] - -end ByteOrder + for i in [0 : numBytes] do + let byte := (v.ushiftRight (i * 8) &&& 0xFF).toNat.toUInt8 + arr := arr.push byte + return arr + +#guard (bitVecToByteArray 16 0x0100).toList == [0, 1] +#guard (bitVecToByteArray 20 0x01000).toList == [0, 16, 0] +#guard (bitVecToByteArray 32 0x1).toList == [1, 0, 0, 0] /-! The strides are how many bytes you need to skip to get to the next element in that @@ -492,7 +444,7 @@ def BV8.ofNat (i : Nat) : BV8 := i.toUInt8.toBitVec def _root_.UInt8.toBV8 (n : UInt8) : BV8 := BitVec.ofFin n.val def BV8.toUInt8 (n : BV8) : UInt8 := UInt8.ofNat n.toFin -def BV8.toByteArray (x : BV8) : ByteArray := [x.toUInt8].toByteArray +def BV8.toByteArray (x : BV8) : ByteArray := bitVecToByteArray 8 x def _root_.ByteArray.toBV8 (x : ByteArray) (startIndex : Nat) : Err BV8 := let n := startIndex @@ -504,6 +456,9 @@ def _root_.ByteArray.toBV8 (x : ByteArray) (startIndex : Nat) : Err BV8 := abbrev BV16 := BitVec 16 +def BV16.toByteArray (x : BV16) : ByteArray := bitVecToByteArray 16 x + + def BV16.toBytes (n : BV16) : BV8 × BV8 := let n0 := (n >>> 0o00 &&& 0xFF).truncate 8 let n1 := (n >>> 0o10 &&& 0xFF).truncate 8 @@ -528,29 +483,20 @@ theorem BV16.BytesRoundTrip1 (x0 x1 : BV8) : unfold BV16.toBytes BV16.ofBytes bv_decide -def _root_.ByteArray.toBV16 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV16 := +def _root_.ByteArray.toBV16 (x : ByteArray) (startIndex : Nat) : Err BV16 := let n := startIndex if H7 : n + 1 < x.size then let H0 : n + 0 < x.size := by omega let H1 : n + 1 < x.size := by omega let x0 := x.get (Fin.mk _ H0) let x1 := x.get (Fin.mk _ H1) - match order with - | .oneByte => .error "illegal byte order" - | .littleEndian => .ok (BV16.ofBytes x0.toBV8 x1.toBV8) - | .bigEndian => .ok (BV16.ofBytes x1.toBV8 x0.toBV8) + .ok (BV16.ofBytes x0.toBV8 x1.toBV8) else .error s!"Index out of range: {n}" -def BV16.toByteArray (x : BV16) (ord : ByteOrder) : ByteArray := - let (x0, x1) := x.toBytes - let arr := match ord with - | .littleEndian => [x0, x1] - | .bigEndian => [x1, x0] - | .oneByte => [] -- Avoid Err for now - (arr.map BV8.toUInt8).toByteArray - abbrev BV32 := BitVec 32 +def BV32.toByteArray (x : BV32) : ByteArray := bitVecToByteArray 32 x + def BV32.toBytes (n : BV32) : BV8 × BV8 × BV8 × BV8 := let n0 := (n >>> 0o00 &&& 0xFF).truncate 8 let n1 := (n >>> 0o10 &&& 0xFF).truncate 8 @@ -581,7 +527,7 @@ theorem BV32.BytesRoundTrip1 (x0 x1 x2 x3 : BV8) : unfold BV32.toBytes BV32.ofBytes bv_decide -def _root_.ByteArray.toBV32 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV32 := +def _root_.ByteArray.toBV32 (x : ByteArray) (startIndex : Nat) : Err BV32 := let n := startIndex if H7 : n + 3 < x.size then let H0 : n + 0 < x.size := by omega @@ -592,22 +538,13 @@ def _root_.ByteArray.toBV32 (x : ByteArray) (startIndex : Nat) (order : ByteOrde let x1 := x.get (Fin.mk _ H1) let x2 := x.get (Fin.mk _ H2) let x3 := x.get (Fin.mk _ H3) - match order with - | .oneByte => .error "illegal byte order" - | .littleEndian => .ok (BV32.ofBytes x0.toBV8 x1.toBV8 x2.toBV8 x3.toBV8) - | .bigEndian => .ok (BV32.ofBytes x3.toBV8 x2.toBV8 x1.toBV8 x0.toBV8) + .ok (BV32.ofBytes x0.toBV8 x1.toBV8 x2.toBV8 x3.toBV8) else .error s!"Index out of range: {n}" -def BV32.toByteArray (x : BV32) (ord : ByteOrder) : ByteArray := - let (x0, x1, x2, x3) := x.toBytes - let arr := match ord with - | .littleEndian => [x0, x1, x2, x3] - | .bigEndian => [x3, x2, x1, x0] - | .oneByte => [] - (arr.map BV8.toUInt8).toByteArray - abbrev BV64 := BitVec 64 +def BV64.toByteArray (x : BV64) : ByteArray := bitVecToByteArray 64 x + def BV64.ofNat (i : Nat) : BV64 := i.toUInt64.toBitVec def BV64.ofInt (i : Int) : BV64 := i.toInt64.toBitVec @@ -654,7 +591,7 @@ theorem BV64.BytesRoundTrip1 (x0 x1 x2 x3 x4 x5 x6 x7 : BV8) : unfold BV64.toBytes BV64.ofBytes bv_decide -def _root_.ByteArray.toBV64 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV64 := +def _root_.ByteArray.toBV64 (x : ByteArray) (startIndex : Nat) : Err BV64 := let n := startIndex if H7 : n + 7 < x.size then let H0 : n + 0 < x.size := by omega @@ -672,20 +609,9 @@ def _root_.ByteArray.toBV64 (x : ByteArray) (startIndex : Nat) (order : ByteOrde let x5 := x.get (Fin.mk _ H5) let x6 := x.get (Fin.mk _ H6) let x7 := x.get (Fin.mk _ H7) - match order with - | .oneByte => .error "illegal byte order" - | .littleEndian => .ok (BV64.ofBytes x0.toBV8 x1.toBV8 x2.toBV8 x3.toBV8 x4.toBV8 x5.toBV8 x6.toBV8 x7.toBV8) - | .bigEndian => .ok (BV64.ofBytes x7.toBV8 x6.toBV8 x5.toBV8 x4.toBV8 x3.toBV8 x2.toBV8 x1.toBV8 x0.toBV8) + .ok (BV64.ofBytes x0.toBV8 x1.toBV8 x2.toBV8 x3.toBV8 x4.toBV8 x5.toBV8 x6.toBV8 x7.toBV8) else .error s!"Index out of range: {n}" -def BV64.toByteArray (x : BV64) (ord : ByteOrder) : ByteArray := - let (x0, x1, x2, x3, x4, x5, x6, x7) := x.toBytes - let arr := match ord with - | .littleEndian => [x0, x1, x2, x3, x4, x5, x6, x7] - | .bigEndian => [x7, x6, x5, x4, x3, x2, x1, x0] - | .oneByte => [] - (arr.map BV8.toUInt8).toByteArray - /- The largest Nat such that it and every smaller Nat can be represented exactly in a 64-bit IEEE-754 float https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER @@ -706,10 +632,9 @@ def _root_.Int.toFloat32 (n : Int) : Float32 := match n with | Int.ofNat n => Float32.ofNat n | Int.negSucc n => Float32.neg (Float32.ofNat (Nat.succ n)) -def _root_.Float.toLEByteArray (f : Float) : ByteArray := BV64.toByteArray f.toBits.toBitVec ByteOrder.littleEndian -def _root_.Float.toBEByteArray (f : Float) : ByteArray := BV64.toByteArray f.toBits.toBitVec ByteOrder.bigEndian -def _root_.Float32.toLEByteArray (f : Float32) : ByteArray := BV32.toByteArray f.toBits.toBitVec ByteOrder.littleEndian -def _root_.Float32.toBEByteArray (f : Float32) : ByteArray := BV32.toByteArray f.toBits.toBitVec ByteOrder.bigEndian + +def _root_.Float32.toLEByteArray (f : Float32) : ByteArray := bitVecToByteArray 32 f.toBits.toBitVec +def _root_.Float.toLEByteArray (f : Float) : ByteArray := bitVecToByteArray 64 f.toBits.toBitVec /-- Interpret a `ByteArray` of size 4 as a little-endian `UInt32`. Missing from Lean stdlib. -/ def _root_.ByteArray.toUInt32LE! (bs : ByteArray) : UInt32 := @@ -754,8 +679,8 @@ def _root_.Float.toInt (f : Float) : Int := #guard ( let n : BV64 := 0x3FFAB851EB851EB8 do - let arr := n.toByteArray .littleEndian - let n' <- ByteArray.toBV64 arr 0 .littleEndian + let arr := n.toByteArray + let n' <- ByteArray.toBV64 arr 0 return n == n') == .ok true end TensorLib diff --git a/TensorLib/Dtype.lean b/TensorLib/Dtype.lean index 5bd8d5a..4563609 100644 --- a/TensorLib/Dtype.lean +++ b/TensorLib/Dtype.lean @@ -7,10 +7,9 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin import TensorLib.Common namespace TensorLib -namespace Dtype /-! The subset of types NumPy supports that we care about -/ -inductive Name where +inductive Dtype where | bool | int8 | int16 @@ -24,10 +23,10 @@ inductive Name where | float64 deriving BEq, Repr, Inhabited -namespace Name +namespace Dtype -- Should match the NumPy name of the dtype. We use toString to generate NumPy test code. -instance : ToString Name where +instance : ToString Dtype where toString | .bool => "bool" | int8 => "int8" @@ -41,24 +40,24 @@ instance : ToString Name where | float32 => "float32" | float64 => "float64" -def isOneByte (x : Name) : Bool := match x with +def isOneByte (x : Dtype) : Bool := match x with | bool | int8 | uint8 => true | _ => false -def isMultiByte (x : Name) : Bool := ! x.isOneByte +def isMultiByte (x : Dtype) : Bool := ! x.isOneByte -def isInt (x : Name) : Bool := match x with +def isInt (x : Dtype) : Bool := match x with | int8 | int16 | int32 | int64 => true | _ => false -def isUint (x : Name) : Bool := match x with +def isUint (x : Dtype) : Bool := match x with | uint8 | uint16 | uint32 | uint64 => true | _ => false -def isIntLike (x : Name) : Bool := x.isInt || x.isUint +def isIntLike (x : Dtype) : Bool := x.isInt || x.isUint --! Number of bytes used by each element of the given dtype -def itemsize (x : Name) : Nat := match x with +def itemsize (x : Dtype) : Nat := match x with | float64 | int64 | uint64 => 8 | float32 | int32 | uint32 => 4 | int16 | uint16 => 2 @@ -86,7 +85,7 @@ OverflowError: Python integer 128 out of bounds for int8 Float types have named safe nat upper bounds. -/ -private def maxSafeNat : Name -> Option Nat +private def maxSafeNat : Dtype -> Option Nat | .bool => none | .uint8 => some 0xFF | .int8 => some 0x7F @@ -99,14 +98,14 @@ private def maxSafeNat : Name -> Option Nat | .float32 => maxSafeNatForFloat32 | .float64 => maxSafeNatForFloat -private def canCastFromNat (dtype : Name) (n : Nat) : Bool := n <= dtype.maxSafeNat.getD n +private def canCastFromNat (dtype : Dtype) (n : Nat) : Bool := n <= dtype.maxSafeNat.getD n /- NumPy doesn't allow casts from negative numbers to uint types, even if they fit. # np.array(-0x1, dtype='uint8') OverflowError: Python integer -1 out of bounds for uint8 -/ -private def minSafeInt : Name -> Option Int +private def minSafeInt : Dtype -> Option Int | .bool => none | .uint8 | .uint16 | .uint32 | .uint64 => some 0 | .int8 => some (-0x80) @@ -116,93 +115,35 @@ private def minSafeInt : Name -> Option Int | .float32 => some (-maxSafeNatForFloat32) | .float64 => some (-maxSafeNatForFloat) -private def canCastFromInt (dtype : Name) (n : Int) : Bool := +private def canCastFromInt (dtype : Dtype) (n : Int) : Bool := if n < 0 then dtype.minSafeInt.getD n <= n else n <= dtype.maxSafeNat.getD n.toNat -end Name -end Dtype - -/- -We have a fixed number of Dtype values, defined at the bottom of the namespace. -Constructor is private so we can't make any more by accident. --/ -structure Dtype where - private mk :: - name : Dtype.Name - order : ByteOrder -deriving BEq, Repr, Inhabited - -namespace Dtype - -def bool : Dtype := Dtype.mk .bool .oneByte -def int8 : Dtype := Dtype.mk .int8 .oneByte -def int16 : Dtype := Dtype.mk .int16 .littleEndian -def int32 : Dtype := Dtype.mk .int32 .littleEndian -def int64 : Dtype := Dtype.mk .int64 .littleEndian -def uint8 : Dtype := Dtype.mk .uint8 .oneByte -def uint16 : Dtype := Dtype.mk .uint16 .littleEndian -def uint32 : Dtype := Dtype.mk .uint32 .littleEndian -def uint64 : Dtype := Dtype.mk .uint64 .littleEndian -def float32 : Dtype := Dtype.mk .float32 .littleEndian -def float64 : Dtype := Dtype.mk .float64 .littleEndian - -def isInt (dtype : Dtype) : Bool := dtype.name.isInt -def isUint (dtype : Dtype) : Bool := dtype.name.isUint -def isIntLike (dtype : Dtype) : Bool := dtype.isInt || dtype.isUint - -def make (name : Name) (order : ByteOrder) : Err Dtype := match order with -| .oneByte => if name.isOneByte then .ok $ mk name order else .error "illegal dtype" -| .littleEndian | .bigEndian => if name.isMultiByte then .ok $ mk name order else .error "illegal dtype" - -private def byteOrderOk (dtype : Dtype) : Prop := !dtype.name.isMultiByte || (dtype.name.isMultiByte && dtype.order.isMultiByte) - -private theorem makeOk (name : Name) (order : ByteOrder) : match make name order with -| .ok dtype => dtype.byteOrderOk -| .error _ => true := by - unfold make byteOrderOk Name.isMultiByte Name.isOneByte - cases name <;> cases order <;> simp - -def itemsize (dtype : Dtype) := dtype.name.itemsize - def sizedStrides (dtype : Dtype) (s : Shape) : Strides := List.map (fun x => x * dtype.itemsize) s.unitStrides -private def byteArrayOfNatOverflow (dtype : Dtype) (n : Nat) : ByteArray := match dtype.name with +private def byteArrayOfNatOverflow (dtype : Dtype) (n : Nat) : ByteArray := match dtype with | .bool => (BV8.ofNat $ if n == 0 then 0 else 1).toByteArray | .uint8 => (BV8.ofNat n).toByteArray | .int8 => [(Int8.ofNat n).toUInt8].toByteArray -| .uint16 => BV16.toByteArray n.toUInt16.toBitVec dtype.order -| .int16 => BV16.toByteArray n.toInt16.toBitVec dtype.order -| .uint32 => BV32.toByteArray n.toUInt32.toBitVec dtype.order -| .int32 => BV32.toByteArray n.toInt32.toBitVec dtype.order -| .uint64 => BV64.toByteArray n.toUInt64.toBitVec dtype.order -| .int64 => BV64.toByteArray n.toInt64.toBitVec dtype.order -| .float32 => match dtype.order with - | .littleEndian => n.toFloat32.toLEByteArray - | .bigEndian => n.toFloat32.toBEByteArray - | .oneByte => impossible -- implies an illegal dtype, which should be impossible -| .float64 => match dtype.order with - | .littleEndian => n.toFloat.toLEByteArray - | .bigEndian => n.toFloat.toBEByteArray - | .oneByte => impossible +| .uint16 => BV16.toByteArray n.toUInt16.toBitVec +| .int16 => BV16.toByteArray n.toInt16.toBitVec +| .uint32 => BV32.toByteArray n.toUInt32.toBitVec +| .int32 => BV32.toByteArray n.toInt32.toBitVec +| .uint64 => BV64.toByteArray n.toUInt64.toBitVec +| .int64 => BV64.toByteArray n.toInt64.toBitVec +| .float32 => n.toFloat32.toLEByteArray +| .float64 => n.toFloat.toLEByteArray def byteArrayOfNat (dtype : Dtype) (n : Nat) : Err ByteArray := - let name := dtype.name - if name.canCastFromNat n then .ok (dtype.byteArrayOfNatOverflow n) - else .error s!"Nat {n} out of bounds for {name}" + if dtype.canCastFromNat n then .ok (dtype.byteArrayOfNatOverflow n) + else .error s!"Nat {n} out of bounds for {dtype}" def byteArrayOfNat! (dtype : Dtype) (n : Nat) : ByteArray := get! $ byteArrayOfNat dtype n -def byteArrayToNat (dtype : Dtype) (arr : ByteArray) : Err Nat := - if dtype.itemsize != arr.size then .error "byte size mismatch" - else .ok $ dtype.order.bytesToNat arr - -def byteArrayToNat! (dtype : Dtype) (arr : ByteArray) : Nat := get! $ byteArrayToNat dtype arr - private def byteArrayToNatRoundTrip (dtype : Dtype) (n : Nat) : Bool := let res := do let arr <- dtype.byteArrayOfNat n - let n' <- dtype.byteArrayToNat arr + let n' := arr.toNat return n == n' res.toOption.getD false @@ -211,38 +152,25 @@ private def byteArrayToNatRoundTrip (dtype : Dtype) (n : Nat) : Bool := #guard uint8.byteArrayToNatRoundTrip 255 #guard !uint8.byteArrayToNatRoundTrip 256 -private def byteArrayOfIntOverflow (dtype : Dtype) (n : Int) : ByteArray := match dtype.name with +private def byteArrayOfIntOverflow (dtype : Dtype) (n : Int) : ByteArray := match dtype with | .bool => (BV8.ofNat $ if n == 0 then 0 else 1).toByteArray | .uint8 | .int8 => [n.toInt8.toUInt8].toByteArray -| .uint16 | .int16 => BV16.toByteArray n.toInt16.toBitVec dtype.order -| .uint32 | .int32 => BV32.toByteArray n.toInt32.toBitVec dtype.order -| .uint64 | .int64 => BV64.toByteArray n.toInt64.toBitVec dtype.order -| .float32 => match dtype.order with - | .littleEndian => n.toFloat.toLEByteArray - | .bigEndian => n.toFloat.toBEByteArray - | .oneByte => impossible -| .float64 => match dtype.order with - | .littleEndian => n.toFloat.toLEByteArray - | .bigEndian => n.toFloat.toBEByteArray - | .oneByte => impossible +| .uint16 | .int16 => BV16.toByteArray n.toInt16.toBitVec +| .uint32 | .int32 => BV32.toByteArray n.toInt32.toBitVec +| .uint64 | .int64 => BV64.toByteArray n.toInt64.toBitVec +| .float32 => n.toFloat.toLEByteArray +| .float64 => n.toFloat.toLEByteArray def byteArrayOfInt (dtype : Dtype) (n : Int) : Err ByteArray := - let name := dtype.name - if name.canCastFromInt n then .ok (dtype.byteArrayOfIntOverflow n) - else .error s!"Int {n} out of bounds for {name}" + if dtype.canCastFromInt n then .ok (dtype.byteArrayOfIntOverflow n) + else .error s!"Int {n} out of bounds for {dtype}" def byteArrayOfInt! (dtype : Dtype) (n : Int) : ByteArray := get! $ byteArrayOfInt dtype n -def byteArrayToInt (dtype : Dtype) (arr : ByteArray) : Err Int := - if dtype.itemsize != arr.size then .error "byte size mismatch" - else .ok $ dtype.order.bytesToInt arr - -def byteArrayToInt! (dtype : Dtype) (arr : ByteArray) : Int := get! $ byteArrayToInt dtype arr - private def byteArrayToIntRoundTrip (dtype : Dtype) (n : Int) : Bool := let res := do let arr <- dtype.byteArrayOfInt n - let n' <- dtype.byteArrayToInt arr + let n' := arr.toInt return n == n' res.toOption.getD false @@ -251,25 +179,22 @@ private def byteArrayToIntRoundTrip (dtype : Dtype) (n : Int) : Bool := #guard int8.byteArrayToIntRoundTrip 127 #guard !int8.byteArrayToIntRoundTrip 255 -private def byteArrayToFloat (dtype : Dtype) (arr : ByteArray) : Err Float := match dtype.name with +private def byteArrayToFloat (dtype : Dtype) (arr : ByteArray) : Err Float := match dtype with | .float64 => if arr.size != 8 then .error "byte size mismatch" else - match dtype.order with - | .littleEndian => .ok $ Float.ofBits arr.toUInt64LE! -- toUInt64LE! is ok here because we already checked the size - | .bigEndian => .ok $ Float.ofBits arr.toUInt64BE! - | .oneByte => impossible "Illegal dtype. Creation shouldn't have been possible" + .ok $ Float.ofBits arr.toUInt64LE! -- toUInt64LE! is ok here because we already checked the size | _ => .error "Illegal type conversion" private def byteArrayToFloat! (dtype : Dtype) (arr : ByteArray) : Float := get! $ byteArrayToFloat dtype arr -private def byteArrayOfFloat (dtype : Dtype) (f : Float) : Err ByteArray := match dtype.name with -| .float64 => .ok $ BV64.toByteArray f.toBits.toBitVec dtype.order +private def byteArrayOfFloat (dtype : Dtype) (f : Float) : Err ByteArray := match dtype with +| .float64 => .ok $ BV64.toByteArray f.toBits.toBitVec | _ => .error "Illegal type conversion" private def byteArrayOfFloat! (dtype : Dtype) (f : Float) : ByteArray := get! $ byteArrayOfFloat dtype f -def byteArrayOfFloat32 (dtype : Dtype) (f : Float32) : Err ByteArray := match dtype.name with -| .float32 => .ok $ BV32.toByteArray f.toBits.toBitVec dtype.order +def byteArrayOfFloat32 (dtype : Dtype) (f : Float32) : Err ByteArray := match dtype with +| .float32 => .ok $ BV32.toByteArray f.toBits.toBitVec | _ => .error "Illegal type conversion" private def byteArrayOfFloat32! (dtype : Dtype) (f : Float32) : ByteArray := get! $ byteArrayOfFloat32 dtype f @@ -288,13 +213,10 @@ private def byteArrayToFloatRoundTrip (dtype : Dtype) (f : Float) : Bool := #guard float64.byteArrayToFloatRoundTrip (Float.sqrt 2) #guard !float32.byteArrayToFloatRoundTrip 0 -def byteArrayToFloat32 (dtype : Dtype) (arr : ByteArray) : Err Float32 := match dtype.name with +def byteArrayToFloat32 (dtype : Dtype) (arr : ByteArray) : Err Float32 := match dtype with | .float32 => if arr.size != 4 then .error "byte size mismatch" else - match dtype.order with - | .littleEndian => .ok $ Float32.ofBits arr.toUInt32LE! - | .bigEndian => .ok $ Float32.ofBits arr.toUInt32BE! - | .oneByte => impossible "Illegal dtype. Creation shouldn't have been possible" + .ok $ Float32.ofBits arr.toUInt32LE! | _ => .error "Illegal type conversion" def byteArrayToFloat32! (dtype : Dtype) (arr : ByteArray) : Float32 := get! $ byteArrayToFloat32 dtype arr @@ -318,21 +240,17 @@ NumPy addition overflows and underflows without complaint. We will do the same. -/ def add (dtype : Dtype) (x y : ByteArray) : Err ByteArray := if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "add: byte size mismatch" else - match dtype.name with + match dtype with | .bool => do - let x <- dtype.byteArrayToNat x - let y <- dtype.byteArrayToNat y + let x := x.toNat + let y := y.toNat if x == 1 || y == 1 then dtype.byteArrayOfInt 1 else if x == 0 && y == 0 then dtype.byteArrayOfInt 0 else .error "illegal bool bytes" | .uint8 | .uint16 | .uint32 | .uint64 => do - let x <- dtype.byteArrayToNat x - let y <- dtype.byteArrayToNat y - return dtype.byteArrayOfNatOverflow (x + y) + return dtype.byteArrayOfNatOverflow (x.toNat + y.toNat) | .int8 | .int16| .int32 | .int64 => do - let x <- dtype.byteArrayToInt x - let y <- dtype.byteArrayToInt y - dtype.byteArrayOfInt (x + y) + dtype.byteArrayOfInt (x.toInt + y.toInt) | .float32 => do let x <- dtype.byteArrayToFloat32 x let y <- dtype.byteArrayToFloat32 y @@ -344,15 +262,11 @@ def add (dtype : Dtype) (x y : ByteArray) : Err ByteArray := def sub (dtype : Dtype) (x y : ByteArray) : Err ByteArray := if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "sub: byte size mismatch" else - match dtype.name with + match dtype with | .uint8 | .uint16 | .uint32 | .uint64 => do - let x <- dtype.byteArrayToNat x - let y <- dtype.byteArrayToNat y - return dtype.byteArrayOfNatOverflow (x - y) + return dtype.byteArrayOfNatOverflow (x.toNat - y.toNat) | .int8 | .int16| .int32 | .int64 => do - let x <- dtype.byteArrayToInt x - let y <- dtype.byteArrayToInt y - return dtype.byteArrayOfIntOverflow (x - y) + return dtype.byteArrayOfIntOverflow (x.toInt - y.toInt) | .float32 => do let x <- dtype.byteArrayToFloat32 x let y <- dtype.byteArrayToFloat32 y @@ -361,19 +275,15 @@ def sub (dtype : Dtype) (x y : ByteArray) : Err ByteArray := let x <- dtype.byteArrayToFloat x let y <- dtype.byteArrayToFloat y dtype.byteArrayOfFloat (x - y) - | .bool => .error s!"`sub` not supported at type ${dtype.name}" + | .bool => .error s!"`sub` not supported at type ${dtype}" def mul (dtype : Dtype) (x y : ByteArray) : Err ByteArray := if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "mul: byte size mismatch" else - match dtype.name with + match dtype with | .uint8 | .uint16 | .uint32 | .uint64 => do - let x <- dtype.byteArrayToNat x - let y <- dtype.byteArrayToNat y - return dtype.byteArrayOfNatOverflow (x * y) + return dtype.byteArrayOfNatOverflow (x.toNat * y.toNat) | .int8 | .int16| .int32 | .int64 => do - let x <- dtype.byteArrayToInt x - let y <- dtype.byteArrayToInt y - return dtype.byteArrayOfIntOverflow (x * y) + return dtype.byteArrayOfIntOverflow (x.toInt * y.toInt) | .float32 => do let x <- dtype.byteArrayToFloat32 x let y <- dtype.byteArrayToFloat32 y @@ -382,19 +292,15 @@ def mul (dtype : Dtype) (x y : ByteArray) : Err ByteArray := let x <- dtype.byteArrayToFloat x let y <- dtype.byteArrayToFloat y dtype.byteArrayOfFloat (x * y) - | .bool => .error s!"`mul` not supported at type ${dtype.name}" + | .bool => .error s!"`mul` not supported at type ${dtype}" def div (dtype : Dtype) (x y : ByteArray) : Err ByteArray := if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "div: byte size mismatch" else - match dtype.name with + match dtype with | .uint8 | .uint16 | .uint32 | .uint64 => do - let x <- dtype.byteArrayToNat x - let y <- dtype.byteArrayToNat y - return dtype.byteArrayOfNatOverflow (x / y) + return dtype.byteArrayOfNatOverflow (x.toNat / y.toNat) | .int8 | .int16| .int32 | .int64 => do - let x <- dtype.byteArrayToInt x - let y <- dtype.byteArrayToInt y - return dtype.byteArrayOfIntOverflow (x / y) + return dtype.byteArrayOfIntOverflow (x.toInt / y.toInt) | .float32 => do let x <- dtype.byteArrayToFloat32 x let y <- dtype.byteArrayToFloat32 y @@ -403,24 +309,21 @@ def div (dtype : Dtype) (x y : ByteArray) : Err ByteArray := let x <- dtype.byteArrayToFloat x let y <- dtype.byteArrayToFloat y dtype.byteArrayOfFloat (x / y) - | .bool => .error s!"`div` not supported at type ${dtype.name}" + | .bool => .error s!"`div` not supported at type ${dtype}" /- This works for int/uint/bool/float. Keep an eye out when we start implementing unusual floating point types. -/ def zero (dtype : Dtype) : ByteArray := ByteArray.mk $ (List.replicate dtype.itemsize (0 : UInt8)).toArray -private def castLEOverflow (fromDtype : Dtype) (data : ByteArray) (toDtype : Dtype) : ByteArray := - if fromDtype.order == ByteOrder.bigEndian || toDtype.order == ByteOrder.bigEndian then impossible "needs littleEndian" else +def castOverflow (fromDtype : Dtype) (data : ByteArray) (toDtype : Dtype) : ByteArray := if fromDtype == toDtype then data else - match fromDtype.name, toDtype.name with + match fromDtype, toDtype with | _, .bool => ByteArray.mk #[if data.data.all fun x => x == 0 then 0 else 1] | .bool, _ | .uint8, _ | .uint16, _ | .uint32, _ | .uint64, _ => - let n := ByteOrder.littleEndian.bytesToNat data - toDtype.byteArrayOfNatOverflow n + toDtype.byteArrayOfNatOverflow data.toNat | .int8, _ | .int16, _ | .int32, _ | .int64, _ => - let n := ByteOrder.littleEndian.bytesToInt data - toDtype.byteArrayOfIntOverflow n + toDtype.byteArrayOfIntOverflow data.toInt | .float32, .uint8 | .float32, .uint16 | .float32, .uint32 | .float32, .uint64 => toDtype.byteArrayOfNatOverflow (Float32.ofLEByteArray! data).toNat | .float32, .int8 | .float32, .int16 | .float32, .int32 | .float32, .int64 => @@ -433,17 +336,5 @@ private def castLEOverflow (fromDtype : Dtype) (data : ByteArray) (toDtype : Dty | .float64, .float32 => (Float.ofLEByteArray! data).toFloat32.toLEByteArray | .float32, .float32 | .float64, .float64 => impossible -def castOverflow (fromDtype : Dtype) (data : ByteArray) (toDtype : Dtype) : ByteArray := - match fromDtype.order, toDtype.order with - | .littleEndian, .littleEndian - | .oneByte, .littleEndian - | .littleEndian, .oneByte - | .oneByte, .oneByte => castLEOverflow fromDtype data toDtype - | .bigEndian, .littleEndian - | .bigEndian, .oneByte => castLEOverflow { fromDtype with order := .littleEndian } data.reverse toDtype - | .bigEndian, .bigEndian => (castLEOverflow { fromDtype with order := .littleEndian } data.reverse toDtype).reverse - | .littleEndian, .bigEndian - | .oneByte, .bigEndian => (castLEOverflow fromDtype data toDtype).reverse - end Dtype end TensorLib diff --git a/TensorLib/Npy.lean b/TensorLib/Npy.lean index 39243a3..bdd5473 100644 --- a/TensorLib/Npy.lean +++ b/TensorLib/Npy.lean @@ -48,12 +48,6 @@ deriving BEq, Repr, Inhabited namespace ByteOrder -def toByteOrder (x : ByteOrder) : Option TensorLib.ByteOrder := match x with -| .native => none -| .littleEndian => some .littleEndian -| .bigEndian => some .bigEndian -| .notApplicable => some .oneByte - def toChar (x : ByteOrder) := match x with | native => '=' | littleEndian => '<' @@ -70,7 +64,7 @@ def fromChar (c : Char) : Err ByteOrder := match c with end ByteOrder structure Dtype where - name : TensorLib.Dtype.Name + name : TensorLib.Dtype order : ByteOrder deriving BEq, Repr, Inhabited @@ -80,7 +74,7 @@ namespace Dtype Parse a numpy dtype. The first character represents the byte order: https://numpy.org/doc/2.1/reference/generated/numpy.dtype.byteorder.html -/ -def dtypeNameFromNpyString (s : String) : Err Dtype.Name := match s with +def dtypeNameFromNpyString (s : String) : Err TensorLib.Dtype := match s with | "b1" => .ok .bool | "i1" => .ok .int8 | "i2" => .ok .int16 @@ -94,7 +88,7 @@ def dtypeNameFromNpyString (s : String) : Err Dtype.Name := match s with | "f8" => .ok .float64 | _ => .error s!"Can't parse {s} as a dtype" -def dtypeNameToNpyString (t : Dtype.Name) : String := match t with +def dtypeNameToNpyString (t : TensorLib.Dtype) : String := match t with | .bool => "b1" | .int8 => "i1" | .int16 => "i2" @@ -154,6 +148,8 @@ def nbytes (x : Ndarray) : Nat := x.header.descr.itemsize * x.header.shape.count def dtype (arr : Ndarray) : Dtype := arr.header.descr +def itemsize (arr : Ndarray) : Nat := arr.dtype.itemsize + def order (arr : Ndarray) : ByteOrder := arr.dtype.order end Ndarray diff --git a/TensorLib/Tensor.lean b/TensorLib/Tensor.lean index 409eea2..bb4440d 100644 --- a/TensorLib/Tensor.lean +++ b/TensorLib/Tensor.lean @@ -103,15 +103,10 @@ def ones (dtype : Dtype) (shape : Shape) : Tensor := Id.run do let itemsize := dtype.itemsize let mut data := ByteArray.mkEmpty size for i in [0:size] do - let byte := match dtype.order with - | .oneByte => 1 - | .littleEndian => if i.mod itemsize == 0 then 1 else 0 - | .bigEndian => if i.mod itemsize == itemsize - 1 then 1 else 0 + let byte := if i.mod itemsize == 0 then 1 else 0 data := data.push byte { dtype := dtype, shape := shape, data := data } -def byteOrder (arr : Tensor) : ByteOrder := arr.dtype.order - --! number of dimensions def ndim (x : Tensor) : Nat := x.shape.ndim @@ -157,7 +152,7 @@ where we must have an int/uint Tensor as an argument. def intAtDimIndex (arr : Tensor) (dimIndex : DimIndex) : Err Int := do if !arr.isIntLike then .error "natAt expects an int tensor" else let bytes <- byteArrayAtDimIndex arr dimIndex - .ok $ arr.byteOrder.bytesToInt bytes + .ok $ bytes.toInt /-! Copy a Tensor's data to new, contiguous storage. @@ -554,13 +549,13 @@ def toByteArrayTree (arr : Tensor) : Err (Format.Tree ByteArray) := do def toIntTree (arr : Tensor) : Err (Format.Tree Int) := do let t <- arr.toByteArrayTree - t.mapM arr.dtype.byteArrayToInt + return t.map ByteArray.toInt def toIntTree! (arr : Tensor) : Format.Tree Int := get! $ toIntTree arr def toNatTree (arr : Tensor) : Err (Format.Tree Nat) := do let t <- arr.toByteArrayTree - t.mapM arr.dtype.byteArrayToNat + return t.map ByteArray.toNat def toNatTree! (arr : Tensor) : Format.Tree Nat := get! $ toNatTree arr @@ -572,9 +567,25 @@ def formatNat (arr : Tensor) : Err Std.Format := do let t <- arr.toNatTree t.format -private def dataOfNpy (arr : Npy.Ndarray) : ByteArray := +private def reverseEndianness (arr : ByteArray) (itemsize : Nat) : Err ByteArray := do + if arr.size.mod itemsize != 0 then .error "Bytearray size mismatch" else + let mut res := ByteArray.mkEmpty arr.size + for i in [0:arr.size / itemsize] do + let offset := itemsize * i + let bytes := arr.extract offset (offset + itemsize) + let bytes := bytes.reverse + res := res.append bytes + return res + +private def dataOfNpy (arr : Npy.Ndarray) : Err ByteArray := do let dst := ByteArray.mkEmpty arr.nbytes - arr.data.copySlice arr.startIndex dst 0 arr.nbytes + let copied := arr.data.copySlice arr.startIndex dst 0 arr.nbytes + let res <- match arr.order with + | .notApplicable + | .littleEndian => .ok copied + | .bigEndian => reverseEndianness copied arr.itemsize + | .native => .error "Native byte ordering is not supported. Please force a byte order when you save the array." + return res /- Makes a copy of the data, dropping the header and padding. @@ -582,31 +593,21 @@ Probably not a great choice, but sticking with it for now. I want to avoid writing .npy files with wrong header data. -/ def ofNpy (arr : Npy.Ndarray) : Err Tensor := do - match arr.order.toByteOrder with - | .none => .error "can't convert byte order" - | .some order => - let dtype <- Dtype.make arr.dtype.name order + let dtype := arr.dtype.name let shape := arr.header.shape - let data := dataOfNpy arr + let data <- dataOfNpy arr let startIndex := 0 return { dtype, shape, data, startIndex } -private def dtypeToNpy (dtype : Dtype) : Npy.Dtype := - let order := match dtype.order with - | .bigEndian => .bigEndian - | .littleEndian => .littleEndian - | .oneByte => .notApplicable - Npy.Dtype.mk dtype.name order - /- If we have a non-trivial view, we will need a copy, since strides and start positions are not included in the .npy file format -/ private def toNpy (arr : Tensor) : Npy.Ndarray := let arr := if arr.isTriviallyReshapable then arr else arr.copy - let descr := dtypeToNpy arr.dtype + let descr := Npy.Dtype.mk arr.dtype Npy.ByteOrder.littleEndian let shape := arr.shape - let header : Npy.Header := { descr, shape } + let header : Npy.Header := { descr := descr, shape := shape } let data := arr.data let startIndex := 0 { header, data, startIndex } diff --git a/TensorLib/Test.lean b/TensorLib/Test.lean index 3c0ffb0..ab09675 100644 --- a/TensorLib/Test.lean +++ b/TensorLib/Test.lean @@ -20,7 +20,7 @@ private def saveNumpyArray (expr : String) : IO System.FilePath := do return file.addExtension "npy" private def testTensorElementBV (dtype : Dtype) : IO Bool := do - let file <- saveNumpyArray s!"np.arange(20, dtype='{dtype.name}').reshape(5, 4)" + let file <- saveNumpyArray s!"np.arange(20, dtype='{dtype}').reshape(5, 4)" let npy <- Npy.parseFile file let arr <- IO.ofExcept (Tensor.ofNpy npy) let _ <- IO.FS.removeFile file