Skip to content

Commit

Permalink
chore: Remove float16
Browse files Browse the repository at this point in the history
There is no short term plan to have fp16 support in Lean.
Let's remove for now. We have fp32 support in 4.16, so I'll
implement that next. After that we won't have any TODO types.
  • Loading branch information
seanmcl committed Feb 27, 2025
1 parent e1a1f2f commit 84a43d1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 15 deletions.
20 changes: 7 additions & 13 deletions TensorLib/Dtype.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ inductive Name where
| uint16
| uint32
| uint64
| float16
| float32
| float64
deriving BEq, Repr, Inhabited
Expand All @@ -39,7 +38,6 @@ instance : ToString Name where
| uint16 => "uint16"
| uint32 => "uint32"
| uint64 => "uint64"
| float16 => "float16"
| float32 => "float32"
| float64 => "float64"

Expand All @@ -63,7 +61,7 @@ def isIntLike (x : Name) : Bool := x.isInt || x.isUint
def itemsize (x : Name) : Nat := match x with
| float64 | int64 | uint64 => 8
| float32 | int32 | uint32 => 4
| float16 | int16 | uint16 => 2
| int16 | uint16 => 2
| bool | int8 | uint8 => 1

end Name
Expand All @@ -90,7 +88,6 @@ def uint8 : Dtype := Dtype.mk .uint8 .oneByte
def uint16 : Dtype := Dtype.mk .uint16 .littleEndian
def uint32 : Dtype := Dtype.mk .uint32 .littleEndian
def uint64 : Dtype := Dtype.mk .uint64 .littleEndian
def float16 : Dtype := Dtype.mk .float16 .littleEndian
def float32 : Dtype := Dtype.mk .float32 .littleEndian
def float64 : Dtype := Dtype.mk .float64 .littleEndian

Expand Down Expand Up @@ -125,7 +122,6 @@ def byteArrayOfNat (dtype : Dtype) (n : Nat) : Err ByteArray := match dtype.name
| .int32 => if n <= 0x7FFFFFFF then .ok $ BV32.toByteArray n.toInt32.toBitVec dtype.order else .error "Nat out of bounds for int32"
| .uint64 => if n <= 0xFFFFFFFFFFFFFFFF then .ok $ BV64.toByteArray n.toUInt64.toBitVec dtype.order else .error "Nat out of bounds for uint64"
| .int64 => if n <= 0x7FFFFFFFFFFFFFFF then .ok $ BV64.toByteArray n.toInt64.toBitVec dtype.order else .error "Nat out of bounds for int64"
| .float16
| .float32 => .error "Sub-word floats are not yet supported by lean"
| .float64 => .error "Float not yet supported"

Expand Down Expand Up @@ -159,7 +155,6 @@ def byteArrayOfInt (dtype : Dtype) (n : Int) : Err ByteArray := match dtype.name
| .int32 => if -0x80000000 <= n && n <= 0x7FFFFFFF then .ok $ BV32.toByteArray n.toInt32.toBitVec dtype.order else .error "out of bounds"
| .uint64
| .int64 => if -0x8000000000000000 <= n && n <= 0x7FFFFFFFFFFFFFFF then .ok $ BV64.toByteArray n.toInt64.toBitVec dtype.order else .error "out of bounds"
| .float16
| .float32 => .error "Sub-word floats are not yet supported by lean"
| .float64 => .error "Float not yet supported"

Expand Down Expand Up @@ -190,14 +185,14 @@ def byteArrayToFloat (dtype : Dtype) (arr : ByteArray) : Err Float := match dtyp
| .littleEndian => .ok $ Float.ofBits arr.toUInt64LE! -- toUInt64LE! is ok here because we already checked the size
| .bigEndian => .ok $ Float.ofBits arr.toUInt64BE! -- toUInt64BE! is ok here because we already checked the size
| .oneByte => impossible "Illegal dtype. Creation shouldn't have been possible"
| .float16 | .float32 => .error "Unsupported float type. Requires Lean support."
| .float32 => .error "Unsupported float type. Requires Lean support."
| _ => .error "Illegal type conversion"

def byteArrayToFloat! (dtype : Dtype) (arr : ByteArray) : Float := get! $ byteArrayToFloat dtype arr

def byteArrayOfFloat (dtype : Dtype) (f : Float) : Err ByteArray := match dtype.name with
| .float64 => .ok $ BV64.toByteArray f.toBits.toBitVec dtype.order
| .float16 | .float32 => .error "Unsupported float type. Requires Lean support."
| .float32 => .error "Unsupported float type. Requires Lean support."
| _ => .error "Illegal type conversion"

def byteArrayOfFloat! (dtype : Dtype) (f : Float) : ByteArray := get! $ byteArrayOfFloat dtype f
Expand All @@ -214,7 +209,6 @@ private def byteArrayToFloatRoundTrip (dtype : Dtype) (f : Float) : Bool :=
#guard float64.byteArrayToFloatRoundTrip (-0)
#guard float64.byteArrayToFloatRoundTrip 17
#guard float64.byteArrayToFloatRoundTrip (Float.sqrt 2)
#guard !float16.byteArrayToFloatRoundTrip 0
#guard !float32.byteArrayToFloatRoundTrip 0

private def floatToByteArray (f : Float) : Array UInt8 :=
Expand Down Expand Up @@ -276,7 +270,7 @@ def add (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
let x <- byteArrayToFloat dtype x
let y <- byteArrayToFloat dtype y
byteArrayOfFloat dtype (x + y)
| .bool | .float16 | .float32 => .error s!"`add` not supported at type ${dtype.name}"
| .bool | .float32 => .error s!"`add` not supported at type ${dtype.name}"

def sub (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "sub: byte size mismatch" else
Expand All @@ -293,7 +287,7 @@ def sub (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
let x <- byteArrayToFloat dtype x
let y <- byteArrayToFloat dtype y
byteArrayOfFloat dtype (x - y)
| .bool | .float16 | .float32 => .error s!"`sub` not supported at type ${dtype.name}"
| .bool | .float32 => .error s!"`sub` not supported at type ${dtype.name}"

def mul (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "mul: byte size mismatch" else
Expand All @@ -310,7 +304,7 @@ def mul (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
let x <- byteArrayToFloat dtype x
let y <- byteArrayToFloat dtype y
byteArrayOfFloat dtype (x * y)
| .bool | .float16 | .float32 => .error s!"`mul` not supported at type ${dtype.name}"
| .bool | .float32 => .error s!"`mul` not supported at type ${dtype.name}"

def div (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
if dtype.itemsize != x.size || dtype.itemsize != y.size then .error "div: byte size mismatch" else
Expand All @@ -327,7 +321,7 @@ def div (dtype : Dtype) (x y : ByteArray) : Err ByteArray :=
let x <- byteArrayToFloat dtype x
let y <- byteArrayToFloat dtype y
byteArrayOfFloat dtype (x / y)
| .bool | .float16 | .float32 => .error s!"`div` not supported at type ${dtype.name}"
| .bool | .float32 => .error s!"`div` not supported at type ${dtype.name}"

/-
This works for int/uint/bool/float. Keep an eye out when we start implementing unusual floating point types.
Expand Down
2 changes: 0 additions & 2 deletions TensorLib/Npy.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def dtypeNameFromNpyString (s : String) : Err Dtype.Name := match s with
| "u2" => .ok .uint16
| "u4" => .ok .uint32
| "u8" => .ok .uint64
| "f2" => .ok .float16
| "f4" => .ok .float32
| "f8" => .ok .float64
| _ => .error s!"Can't parse {s} as a dtype"
Expand All @@ -105,7 +104,6 @@ def dtypeNameToNpyString (t : Dtype.Name) : String := match t with
| .uint16 => "u2"
| .uint32 => "u3"
| .uint64 => "u4"
| .float16 => "f2"
| .float32 => "f4"
| .float64 => "f8"

Expand Down

0 comments on commit 84a43d1

Please sign in to comment.