Skip to content

Commit

Permalink
feat: support Float32
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmcl committed Feb 28, 2025
1 parent 84a43d1 commit e9d7dd0
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 9 deletions.
41 changes: 41 additions & 0 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,47 @@ def BV64.toByteArray (x : BV64) (ord : ByteOrder) : ByteArray :=
| .oneByte => []
(arr.map BV8.toUInt8).toByteArray

/-
The largest Nat such that it and every smaller Nat can be represented exactly in a 64-bit IEEE-754 float
https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER
The number is one less than 2^(mantissa size)
https://en.wikipedia.org/wiki/Floating-point_arithmetic
https://en.wikipedia.org/wiki/Double-precision_floating-point_format
-/
private def floatMantissaBits : Nat := 52
private def float32MantissaBits : Nat := 23
-- Add 1 to the mantissa length because of the implicit leading 1
def maxSafeNatForFloat : Nat := Nat.pow 2 (floatMantissaBits + 1) - 1
def maxSafeNatForFloat32 : Nat := Nat.pow 2 (float32MantissaBits + 1) - 1

def _root_.Int.toFloat (n : Int) : Float := Float.ofInt n
-- TODO: Use Flaot32.ofInt when https://github.com/leanprover/lean4/pull/7277 is merged, probably in 4.17.0
def _root_.Int.toFloat32 (n : Int) : Float32 := match n with
| Int.ofNat n => Float32.ofNat n
| Int.negSucc n => Float32.neg (Float32.ofNat (Nat.succ n))

def _root_.Float.toLEByteArray (f : Float) : ByteArray := BV64.toByteArray f.toBits.toBitVec ByteOrder.littleEndian
def _root_.Float.toBEByteArray (f : Float) : ByteArray := BV64.toByteArray f.toBits.toBitVec ByteOrder.bigEndian
def _root_.Float32.toLEByteArray (f : Float32) : ByteArray := BV32.toByteArray f.toBits.toBitVec ByteOrder.littleEndian
def _root_.Float32.toBEByteArray (f : Float32) : ByteArray := BV32.toByteArray f.toBits.toBitVec ByteOrder.bigEndian

/-- Interpret a `ByteArray` of size 4 as a little-endian `UInt32`. Missing from Lean stdlib. -/
def _root_.ByteArray.toUInt32LE! (bs : ByteArray) : UInt32 :=
assert! bs.size == 4
(bs.get! 3).toUInt32 <<< 0x18 |||
(bs.get! 2).toUInt32 <<< 0x10 |||
(bs.get! 1).toUInt32 <<< 0x8 |||
(bs.get! 0).toUInt32

/-- Interpret a `ByteArray` of size 4 as a big-endian `UInt32`. Missing from Lean stdlib. -/
def _root_.ByteArray.toUInt32BE! (bs : ByteArray) : UInt32 :=
assert! bs.size == 4
(bs.get! 0).toUInt32 <<< 0x38 |||
(bs.get! 1).toUInt32 <<< 0x30 |||
(bs.get! 2).toUInt32 <<< 0x28 |||
(bs.get! 3).toUInt32 <<< 0x20

#guard (
let n : BV64 := 0x3FFAB851EB851EB8
do
Expand Down
53 changes: 44 additions & 9 deletions TensorLib/Dtype.lean
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,16 @@ def byteArrayOfNat (dtype : Dtype) (n : Nat) : Err ByteArray := match dtype.name
| .int32 => if n <= 0x7FFFFFFF then .ok $ BV32.toByteArray n.toInt32.toBitVec dtype.order else .error "Nat out of bounds for int32"
| .uint64 => if n <= 0xFFFFFFFFFFFFFFFF then .ok $ BV64.toByteArray n.toUInt64.toBitVec dtype.order else .error "Nat out of bounds for uint64"
| .int64 => if n <= 0x7FFFFFFFFFFFFFFF then .ok $ BV64.toByteArray n.toInt64.toBitVec dtype.order else .error "Nat out of bounds for int64"
| .float32 => .error "Sub-word floats are not yet supported by lean"
| .float64 => .error "Float not yet supported"
| .float32 => if maxSafeNatForFloat32 < n then .error "Nat may not be exactly reprsentable by Float32" else
match dtype.order with
| .littleEndian => .ok n.toFloat32.toLEByteArray
| .bigEndian => .ok n.toFloat32.toBEByteArray
| .oneByte => impossible -- implies an illegal dtype, which should be impossible
| .float64 => if maxSafeNatForFloat < n then .error "Nat may not be exactly reprsentable by Float32" else
match dtype.order with
| .littleEndian => .ok n.toFloat.toLEByteArray
| .bigEndian => .ok n.toFloat.toBEByteArray
| .oneByte => impossible -- implies an illegal dtype, which should be impossible

def byteArrayOfNat! (dtype : Dtype) (n : Nat) : ByteArray := get! $ byteArrayOfNat dtype n

Expand Down Expand Up @@ -183,16 +191,27 @@ def byteArrayToFloat (dtype : Dtype) (arr : ByteArray) : Err Float := match dtyp
if arr.size != 8 then .error "byte size mismatch" else
match dtype.order with
| .littleEndian => .ok $ Float.ofBits arr.toUInt64LE! -- toUInt64LE! is ok here because we already checked the size
| .bigEndian => .ok $ Float.ofBits arr.toUInt64BE! -- toUInt64BE! is ok here because we already checked the size
| .bigEndian => .ok $ Float.ofBits arr.toUInt64BE!
| .oneByte => impossible "Illegal dtype. Creation shouldn't have been possible"
| _ => .error "Illegal type conversion"

def byteArrayToFloat32 (dtype : Dtype) (arr : ByteArray) : Err Float32 := match dtype.name with
| .float32 =>
if arr.size != 4 then .error "byte size mismatch" else
match dtype.order with
| .littleEndian => .ok $ Float32.ofBits arr.toUInt32LE!
| .bigEndian => .ok $ Float32.ofBits arr.toUInt32BE!
| .oneByte => impossible "Illegal dtype. Creation shouldn't have been possible"
| .float32 => .error "Unsupported float type. Requires Lean support."
| _ => .error "Illegal type conversion"

def byteArrayToFloat! (dtype : Dtype) (arr : ByteArray) : Float := get! $ byteArrayToFloat dtype arr

def byteArrayOfFloat (dtype : Dtype) (f : Float) : Err ByteArray := match dtype.name with
| .float64 => .ok $ BV64.toByteArray f.toBits.toBitVec dtype.order
| .float32 => .error "Unsupported float type. Requires Lean support."
| _ => .error "Illegal type conversion"

def byteArrayOfFloat32 (dtype : Dtype) (f : Float32) : Err ByteArray := match dtype.name with
| .float32 => .ok $ BV32.toByteArray f.toBits.toBitVec dtype.order
| _ => .error "Illegal type conversion"

def byteArrayOfFloat! (dtype : Dtype) (f : Float) : ByteArray := get! $ byteArrayOfFloat dtype f
Expand Down Expand Up @@ -266,11 +285,15 @@ def add (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
let x <- byteArrayToInt dtype x
let y <- byteArrayToInt dtype y
byteArrayOfInt dtype (x + y)
| .float32 => do
let x <- byteArrayToFloat32 dtype x
let y <- byteArrayToFloat32 dtype y
byteArrayOfFloat32 dtype (x + y)
| .float64 => do
let x <- byteArrayToFloat dtype x
let y <- byteArrayToFloat dtype y
byteArrayOfFloat dtype (x + y)
| .bool | .float32 => .error s!"`add` not supported at type ${dtype.name}"
| .bool => .error s!"`add` not supported at type ${dtype.name}"

def sub (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "sub: byte size mismatch" else
Expand All @@ -283,11 +306,15 @@ def sub (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
let x <- byteArrayToInt dtype x
let y <- byteArrayToInt dtype y
byteArrayOfInt dtype (x - y)
| .float32 => do
let x <- byteArrayToFloat32 dtype x
let y <- byteArrayToFloat32 dtype y
byteArrayOfFloat32 dtype (x - y)
| .float64 => do
let x <- byteArrayToFloat dtype x
let y <- byteArrayToFloat dtype y
byteArrayOfFloat dtype (x - y)
| .bool | .float32 => .error s!"`sub` not supported at type ${dtype.name}"
| .bool => .error s!"`sub` not supported at type ${dtype.name}"

def mul (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "mul: byte size mismatch" else
Expand All @@ -300,11 +327,15 @@ def mul (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
let x <- byteArrayToInt dtype x
let y <- byteArrayToInt dtype y
byteArrayOfInt dtype (x * y)
| .float32 => do
let x <- byteArrayToFloat32 dtype x
let y <- byteArrayToFloat32 dtype y
byteArrayOfFloat32 dtype (x * y)
| .float64 => do
let x <- byteArrayToFloat dtype x
let y <- byteArrayToFloat dtype y
byteArrayOfFloat dtype (x * y)
| .bool | .float32 => .error s!"`mul` not supported at type ${dtype.name}"
| .bool => .error s!"`mul` not supported at type ${dtype.name}"

def div (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "div: byte size mismatch" else
Expand All @@ -317,11 +348,15 @@ def div (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
let x <- byteArrayToInt dtype x
let y <- byteArrayToInt dtype y
byteArrayOfInt dtype (x / y)
| .float32 => do
let x <- byteArrayToFloat32 dtype x
let y <- byteArrayToFloat32 dtype y
byteArrayOfFloat32 dtype (x / y)
| .float64 => do
let x <- byteArrayToFloat dtype x
let y <- byteArrayToFloat dtype y
byteArrayOfFloat dtype (x / y)
| .bool | .float32 => .error s!"`div` not supported at type ${dtype.name}"
| .bool => .error s!"`div` not supported at type ${dtype.name}"

/-
This works for int/uint/bool/float. Keep an eye out when we start implementing unusual floating point types.
Expand Down

0 comments on commit e9d7dd0

Please sign in to comment.