From e9d7dd0da9036a25df56851617e39434e616e64a Mon Sep 17 00:00:00 2001
From: Sean McLaughlin
Date: Thu, 27 Feb 2025 13:10:52 -0800
Subject: [PATCH] feat: support Float32

Common.lean: add the largest exactly-representable Nat bounds for Float and
Float32, Int -> Float/Float32 conversions, Float/Float32 -> ByteArray
encoders for both byte orders, and ByteArray -> UInt32 readers.

Dtype.lean: byteArrayOfNat now encodes float32 and float64 values (rejecting
Nats above the exact-representation bound), byteArrayToFloat32 and
byteArrayOfFloat32 are added, and add/sub/mul/div handle the float32 dtype
instead of returning an error.
---
 TensorLib/Common.lean | 41 +++++++++++++++++++++++++++++++++
 TensorLib/Dtype.lean  | 53 +++++++++++++++++++++++++++++++++++--------
 2 files changed, 85 insertions(+), 9 deletions(-)

diff --git a/TensorLib/Common.lean b/TensorLib/Common.lean
index e9ec5d2..f61dbfa 100644
--- a/TensorLib/Common.lean
+++ b/TensorLib/Common.lean
@@ -685,6 +685,47 @@ def BV64.toByteArray (x : BV64) (ord : ByteOrder) : ByteArray :=
   | .oneByte => []
   (arr.map BV8.toUInt8).toByteArray
 
+/-
+The largest Nat such that it and every smaller Nat can be represented exactly in a 64-bit (resp. 32-bit) IEEE-754 float
+https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER
+
+The number is one less than 2^(mantissa size)
+https://en.wikipedia.org/wiki/Floating-point_arithmetic
+https://en.wikipedia.org/wiki/Double-precision_floating-point_format
+-/
+private def floatMantissaBits : Nat := 52
+private def float32MantissaBits : Nat := 23
+-- Add 1 to the mantissa length because of the implicit leading 1
+def maxSafeNatForFloat : Nat := Nat.pow 2 (floatMantissaBits + 1) - 1
+def maxSafeNatForFloat32 : Nat := Nat.pow 2 (float32MantissaBits + 1) - 1
+
+def _root_.Int.toFloat (n : Int) : Float := Float.ofInt n
+-- TODO: Use Float32.ofInt when https://github.com/leanprover/lean4/pull/7277 is merged, probably in 4.17.0
+def _root_.Int.toFloat32 (n : Int) : Float32 := match n with
+| Int.ofNat n => Float32.ofNat n
+| Int.negSucc n => Float32.neg (Float32.ofNat (Nat.succ n))
+
+def _root_.Float.toLEByteArray (f : Float) : ByteArray := BV64.toByteArray f.toBits.toBitVec ByteOrder.littleEndian
+def _root_.Float.toBEByteArray (f : Float) : ByteArray := BV64.toByteArray f.toBits.toBitVec ByteOrder.bigEndian
+def _root_.Float32.toLEByteArray (f : Float32) : ByteArray := BV32.toByteArray f.toBits.toBitVec ByteOrder.littleEndian
+def _root_.Float32.toBEByteArray (f : Float32) : ByteArray := BV32.toByteArray f.toBits.toBitVec ByteOrder.bigEndian
+
+/-- Interpret a `ByteArray` of size 4 as a little-endian `UInt32`. Missing from Lean stdlib. -/
+def _root_.ByteArray.toUInt32LE! (bs : ByteArray) : UInt32 :=
+  assert! bs.size == 4
+  (bs.get! 3).toUInt32 <<< 0x18 |||
+  (bs.get! 2).toUInt32 <<< 0x10 |||
+  (bs.get! 1).toUInt32 <<< 0x8 |||
+  (bs.get! 0).toUInt32
+
+/-- Interpret a `ByteArray` of size 4 as a big-endian `UInt32`. Missing from Lean stdlib. -/
+def _root_.ByteArray.toUInt32BE! (bs : ByteArray) : UInt32 :=
+  assert! bs.size == 4
+  (bs.get! 0).toUInt32 <<< 0x18 |||
+  (bs.get! 1).toUInt32 <<< 0x10 |||
+  (bs.get! 2).toUInt32 <<< 0x8 |||
+  (bs.get! 3).toUInt32
+
 #guard (
   let n : BV64 := 0x3FFAB851EB851EB8
   do
diff --git a/TensorLib/Dtype.lean b/TensorLib/Dtype.lean
index dd8b69c..f8e5f2d 100644
--- a/TensorLib/Dtype.lean
+++ b/TensorLib/Dtype.lean
@@ -122,8 +122,16 @@ def byteArrayOfNat (dtype : Dtype) (n : Nat) : Err ByteArray := match dtype.name
 | .int32 => if n <= 0x7FFFFFFF then .ok $ BV32.toByteArray n.toInt32.toBitVec dtype.order else .error "Nat out of bounds for int32"
 | .uint64 => if n <= 0xFFFFFFFFFFFFFFFF then .ok $ BV64.toByteArray n.toUInt64.toBitVec dtype.order else .error "Nat out of bounds for uint64"
 | .int64 => if n <= 0x7FFFFFFFFFFFFFFF then .ok $ BV64.toByteArray n.toInt64.toBitVec dtype.order else .error "Nat out of bounds for int64"
-| .float32 => .error "Sub-word floats are not yet supported by lean"
-| .float64 => .error "Float not yet supported"
+| .float32 => if maxSafeNatForFloat32 < n then .error "Nat may not be exactly representable by Float32" else
+  match dtype.order with
+  | .littleEndian => .ok n.toFloat32.toLEByteArray
+  | .bigEndian => .ok n.toFloat32.toBEByteArray
+  | .oneByte => impossible -- implies an illegal dtype, which should be impossible
+| .float64 => if maxSafeNatForFloat < n then .error "Nat may not be exactly representable by Float64" else
+  match dtype.order with
+  | .littleEndian => .ok n.toFloat.toLEByteArray
+  | .bigEndian => .ok n.toFloat.toBEByteArray
+  | .oneByte => impossible -- implies an illegal dtype, which should be impossible
 
 def byteArrayOfNat! (dtype : Dtype) (n : Nat) : ByteArray := get! $ byteArrayOfNat dtype n
 
@@ -183,16 +191,27 @@ def byteArrayToFloat (dtype : Dtype) (arr : ByteArray) : Err Float := match dtyp
   if arr.size != 8 then .error "byte size mismatch" else
   match dtype.order with
   | .littleEndian => .ok $ Float.ofBits arr.toUInt64LE! -- toUInt64LE! is ok here because we already checked the size
-  | .bigEndian => .ok $ Float.ofBits arr.toUInt64BE! -- toUInt64BE! is ok here because we already checked the size
+  | .bigEndian => .ok $ Float.ofBits arr.toUInt64BE!
+  | .oneByte => impossible "Illegal dtype. Creation shouldn't have been possible"
+| _ => .error "Illegal type conversion"
+
+def byteArrayToFloat32 (dtype : Dtype) (arr : ByteArray) : Err Float32 := match dtype.name with
+| .float32 =>
+  if arr.size != 4 then .error "byte size mismatch" else
+  match dtype.order with
+  | .littleEndian => .ok $ Float32.ofBits arr.toUInt32LE!
+  | .bigEndian => .ok $ Float32.ofBits arr.toUInt32BE!
   | .oneByte => impossible "Illegal dtype. Creation shouldn't have been possible"
-| .float32 => .error "Unsupported float type. Requires Lean support."
 | _ => .error "Illegal type conversion"
 
 def byteArrayToFloat! (dtype : Dtype) (arr : ByteArray) : Float := get! $ byteArrayToFloat dtype arr
 
 def byteArrayOfFloat (dtype : Dtype) (f : Float) : Err ByteArray := match dtype.name with
 | .float64 => .ok $ BV64.toByteArray f.toBits.toBitVec dtype.order
-| .float32 => .error "Unsupported float type. Requires Lean support."
+| _ => .error "Illegal type conversion"
+
+def byteArrayOfFloat32 (dtype : Dtype) (f : Float32) : Err ByteArray := match dtype.name with
+| .float32 => .ok $ BV32.toByteArray f.toBits.toBitVec dtype.order
 | _ => .error "Illegal type conversion"
 
 def byteArrayOfFloat! (dtype : Dtype) (f : Float) : ByteArray := get! $ byteArrayOfFloat dtype f
@@ -266,11 +285,15 @@ def add (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
     let x <- byteArrayToInt dtype x
     let y <- byteArrayToInt dtype y
     byteArrayOfInt dtype (x + y)
+  | .float32 => do
+    let x <- byteArrayToFloat32 dtype x
+    let y <- byteArrayToFloat32 dtype y
+    byteArrayOfFloat32 dtype (x + y)
   | .float64 => do
     let x <- byteArrayToFloat dtype x
     let y <- byteArrayToFloat dtype y
     byteArrayOfFloat dtype (x + y)
-  | .bool | .float32 => .error s!"`add` not supported at type ${dtype.name}"
+  | .bool => .error s!"`add` not supported at type ${dtype.name}"
 
 def sub (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
   if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "sub: byte size mismatch" else
@@ -283,11 +306,15 @@ def sub (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
     let x <- byteArrayToInt dtype x
     let y <- byteArrayToInt dtype y
     byteArrayOfInt dtype (x - y)
+  | .float32 => do
+    let x <- byteArrayToFloat32 dtype x
+    let y <- byteArrayToFloat32 dtype y
+    byteArrayOfFloat32 dtype (x - y)
   | .float64 => do
     let x <- byteArrayToFloat dtype x
     let y <- byteArrayToFloat dtype y
     byteArrayOfFloat dtype (x - y)
-  | .bool | .float32 => .error s!"`sub` not supported at type ${dtype.name}"
+  | .bool => .error s!"`sub` not supported at type ${dtype.name}"
 
 def mul (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
   if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "mul: byte size mismatch" else
@@ -300,11 +327,15 @@ def mul (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
     let x <- byteArrayToInt dtype x
     let y <- byteArrayToInt dtype y
     byteArrayOfInt dtype (x * y)
+  | .float32 => do
+    let x <- byteArrayToFloat32 dtype x
+    let y <- byteArrayToFloat32 dtype y
+    byteArrayOfFloat32 dtype (x * y)
   | .float64 => do
     let x <- byteArrayToFloat dtype x
     let y <- byteArrayToFloat dtype y
     byteArrayOfFloat dtype (x * y)
-  | .bool | .float32 => .error s!"`mul` not supported at type ${dtype.name}"
+  | .bool => .error s!"`mul` not supported at type ${dtype.name}"
 
 def div (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
   if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "div: byte size mismatch" else
@@ -317,11 +348,15 @@ def div (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
     let x <- byteArrayToInt dtype x
     let y <- byteArrayToInt dtype y
     byteArrayOfInt dtype (x / y)
+  | .float32 => do
+    let x <- byteArrayToFloat32 dtype x
+    let y <- byteArrayToFloat32 dtype y
+    byteArrayOfFloat32 dtype (x / y)
   | .float64 => do
     let x <- byteArrayToFloat dtype x
     let y <- byteArrayToFloat dtype y
     byteArrayOfFloat dtype (x / y)
-  | .bool | .float32 => .error s!"`div` not supported at type ${dtype.name}"
+  | .bool => .error s!"`div` not supported at type ${dtype.name}"
 
 /-
 This works for int/uint/bool/float. Keep an eye out when we start implementing unusual floating point types.