Skip to content

Commit

Permalink
chore: Make all arrays little-endian
Browse files Browse the repository at this point in the history
Endian-ness is a nice feature in NumPy; it's clear what the data should
represent, it handles operations on arrays with different orderings, etc.
But it's largely just an optimization. In our code, which is slow to
begin with, we do not need this optimization. We remove considerable
complexity by forcing all arrays in TensorLib to be little-endian.

One outcome of this is that we will save all .npy files as littleendian
and rely on the Python programmer to convert if necessary.
  • Loading branch information
seanmcl committed Mar 4, 2025
1 parent 8c131d0 commit 9e794a9
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 330 deletions.
171 changes: 48 additions & 123 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -74,97 +74,49 @@ instance ByteArrayRepr : Repr ByteArray where

def _root_.ByteArray.reverse (arr : ByteArray) : ByteArray := ⟨ arr.data.reverse ⟩

/-!
NumPy arrays can be stored in big-endian or little-endian order on disk, regardless
of the architecture of the machine saving the array. Since we read these arrays
into memory at certain data types, for multi-byte data types we need to know the
endian-ness.
-/
inductive ByteOrder where
| oneByte
| littleEndian
| bigEndian
deriving BEq, Repr, Inhabited

namespace ByteOrder

@[simp]
def isMultiByte (x : ByteOrder) : Bool := match x with
| .oneByte => false
| .littleEndian | .bigEndian => true

def bytesToNat (order : ByteOrder) (bytes : ByteArray) : Nat := Id.run do
def _root_.ByteArray.toNat (arr : ByteArray) : Nat := Id.run do
let mut n : Nat := 0
let nbytes := bytes.size
let nbytes := arr.size
for i in [0:nbytes] do
let v : UInt8 := bytes.get! i
let p := match order with
| .oneByte => 0 -- nbytes = 1
| .littleEndian => i
| .bigEndian => nbytes - 1 - i
n := n + Pow.pow 2 (8 * p) * v.toNat
let v : UInt8 := arr.get! i
n := n + Pow.pow 2 (8 * i) * v.toNat
return n

#guard bytesToNat .littleEndian (ByteArray.mk #[1, 1]) == 257
#guard bytesToNat .bigEndian (ByteArray.mk #[1, 1]) == 257
#guard bytesToNat .littleEndian (ByteArray.mk #[0, 1]) == 256
#guard bytesToNat .bigEndian (ByteArray.mk #[0, 1]) == 1
#guard bytesToNat .littleEndian (ByteArray.mk #[0xFF, 0xFF]) == 65535
#guard bytesToNat .bigEndian (ByteArray.mk #[0xFF, 0xFF]) == 65535
#guard bytesToNat .bigEndian (ByteArray.mk #[0x80, 0]) == 32768
#guard bytesToNat .littleEndian (ByteArray.mk #[0x80, 0]) == 0x80
#guard (ByteArray.mk #[1, 1]).toNat == 257
#guard (ByteArray.mk #[0, 1]).toNat == 256
#guard (ByteArray.mk #[0xFF, 0xFF]).toNat == 65535
#guard (ByteArray.mk #[0, 0x80]).toNat == 32768
#guard (ByteArray.mk #[0x80, 0]).toNat == 0x80

def bytesToInt (order : ByteOrder) (bytes : ByteArray) : Int := Id.run do
def _root_.ByteArray.toInt (arr : ByteArray) : Int := Id.run do
let mut n : Nat := 0
let nbytes := bytes.size
let signByte := match order with
| .littleEndian => bytes.get! (nbytes - 1)
| .bigEndian | oneByte => bytes.get! 0
let nbytes := arr.size
let signByte := arr.get! (nbytes - 1)
let negative := 128 <= signByte
for i in [0:nbytes] do
let v : UInt8 := bytes.get! i
let v : UInt8 := arr.get! i
let v := if negative then UInt8.complement v else v
let p := match order with
| .oneByte => 0 -- nbytes = 1
| .littleEndian => i
| .bigEndian => nbytes - 1 - i
n := n + Pow.pow 2 (8 * p) * v.toNat
n := n + Pow.pow 2 (8 * i) * v.toNat
return if 128 <= signByte then -(n + 1) else n

#guard bytesToInt .littleEndian (ByteArray.mk #[1, 1]) == 257
#guard bytesToInt .bigEndian (ByteArray.mk #[1, 1]) == 257
#guard bytesToInt .littleEndian (ByteArray.mk #[0, 1]) == 256
#guard bytesToInt .bigEndian (ByteArray.mk #[0, 1]) == 1
#guard bytesToInt .littleEndian (ByteArray.mk #[0xFF, 0xFF]) == -1
#guard bytesToInt .bigEndian (ByteArray.mk #[0xFF, 0xFF]) == -1
#guard bytesToInt .bigEndian (ByteArray.mk #[0x80, 0]) == -32768
#guard bytesToInt .littleEndian (ByteArray.mk #[0x80, 0]) == 0x80
#guard (ByteArray.mk #[1, 1]).toInt == 257
#guard (ByteArray.mk #[0, 1]).toInt == 256
#guard (ByteArray.mk #[1, 0]).toInt == 1
#guard (ByteArray.mk #[0xFF, 0xFF]).toInt == -1
#guard (ByteArray.mk #[0, 0x80]).toInt == -32768
#guard (ByteArray.mk #[0x80, 0]).toInt == 0x80

def bitVecToByteArray (order : ByteOrder) (n : Nat) (v : BitVec n) : ByteArray := Id.run do
def bitVecToByteArray (n : Nat) (v : BitVec n) : ByteArray := Id.run do
let numBytes := natDivCeil n 8
let mut arr := ByteArray.mkEmpty numBytes
match order with
| .oneByte =>
let byte := (v &&& 0xFF).toNat.toUInt8
return arr.push byte
| .littleEndian =>
for i in [0 : numBytes] do
let byte := (v.ushiftRight (i * 8) &&& 0xFF).toNat.toUInt8
arr := arr.push byte
return arr
| .bigEndian =>
for i in [0 : numBytes] do
let byte := (v.ushiftRight ((numBytes - i - 1) * 8) &&& 0xFF).toNat.toUInt8
arr := arr.push byte
return arr

#guard (bitVecToByteArray .bigEndian 16 0x0100).toList == [1, 0]
#guard (bitVecToByteArray .littleEndian 16 0x0100).toList == [0, 1]
#guard (bitVecToByteArray .bigEndian 20 0x01000).toList == [0, 16, 0]
#guard (bitVecToByteArray .littleEndian 32 0x1).toList == [1, 0, 0, 0]
#guard (bitVecToByteArray .bigEndian 32 0x1).toList == [0, 0, 0, 1]

end ByteOrder
for i in [0 : numBytes] do
let byte := (v.ushiftRight (i * 8) &&& 0xFF).toNat.toUInt8
arr := arr.push byte
return arr

#guard (bitVecToByteArray 16 0x0100).toList == [0, 1]
#guard (bitVecToByteArray 20 0x01000).toList == [0, 16, 0]
#guard (bitVecToByteArray 32 0x1).toList == [1, 0, 0, 0]

/-!
The strides are how many bytes you need to skip to get to the next element in that
Expand Down Expand Up @@ -492,7 +444,7 @@ def BV8.ofNat (i : Nat) : BV8 := i.toUInt8.toBitVec
def _root_.UInt8.toBV8 (n : UInt8) : BV8 := BitVec.ofFin n.val
def BV8.toUInt8 (n : BV8) : UInt8 := UInt8.ofNat n.toFin

def BV8.toByteArray (x : BV8) : ByteArray := [x.toUInt8].toByteArray
def BV8.toByteArray (x : BV8) : ByteArray := bitVecToByteArray 8 x

def _root_.ByteArray.toBV8 (x : ByteArray) (startIndex : Nat) : Err BV8 :=
let n := startIndex
Expand All @@ -504,6 +456,9 @@ def _root_.ByteArray.toBV8 (x : ByteArray) (startIndex : Nat) : Err BV8 :=

abbrev BV16 := BitVec 16

def BV16.toByteArray (x : BV16) : ByteArray := bitVecToByteArray 16 x


def BV16.toBytes (n : BV16) : BV8 × BV8 :=
let n0 := (n >>> 0o00 &&& 0xFF).truncate 8
let n1 := (n >>> 0o10 &&& 0xFF).truncate 8
Expand All @@ -528,29 +483,20 @@ theorem BV16.BytesRoundTrip1 (x0 x1 : BV8) :
unfold BV16.toBytes BV16.ofBytes
bv_decide

def _root_.ByteArray.toBV16 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV16 :=
def _root_.ByteArray.toBV16 (x : ByteArray) (startIndex : Nat) : Err BV16 :=
let n := startIndex
if H7 : n + 1 < x.size then
let H0 : n + 0 < x.size := by omega
let H1 : n + 1 < x.size := by omega
let x0 := x.get (Fin.mk _ H0)
let x1 := x.get (Fin.mk _ H1)
match order with
| .oneByte => .error "illegal byte order"
| .littleEndian => .ok (BV16.ofBytes x0.toBV8 x1.toBV8)
| .bigEndian => .ok (BV16.ofBytes x1.toBV8 x0.toBV8)
.ok (BV16.ofBytes x0.toBV8 x1.toBV8)
else .error s!"Index out of range: {n}"

def BV16.toByteArray (x : BV16) (ord : ByteOrder) : ByteArray :=
let (x0, x1) := x.toBytes
let arr := match ord with
| .littleEndian => [x0, x1]
| .bigEndian => [x1, x0]
| .oneByte => [] -- Avoid Err for now
(arr.map BV8.toUInt8).toByteArray

abbrev BV32 := BitVec 32

def BV32.toByteArray (x : BV32) : ByteArray := bitVecToByteArray 32 x

def BV32.toBytes (n : BV32) : BV8 × BV8 × BV8 × BV8 :=
let n0 := (n >>> 0o00 &&& 0xFF).truncate 8
let n1 := (n >>> 0o10 &&& 0xFF).truncate 8
Expand Down Expand Up @@ -581,7 +527,7 @@ theorem BV32.BytesRoundTrip1 (x0 x1 x2 x3 : BV8) :
unfold BV32.toBytes BV32.ofBytes
bv_decide

def _root_.ByteArray.toBV32 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV32 :=
def _root_.ByteArray.toBV32 (x : ByteArray) (startIndex : Nat) : Err BV32 :=
let n := startIndex
if H7 : n + 3 < x.size then
let H0 : n + 0 < x.size := by omega
Expand All @@ -592,22 +538,13 @@ def _root_.ByteArray.toBV32 (x : ByteArray) (startIndex : Nat) (order : ByteOrde
let x1 := x.get (Fin.mk _ H1)
let x2 := x.get (Fin.mk _ H2)
let x3 := x.get (Fin.mk _ H3)
match order with
| .oneByte => .error "illegal byte order"
| .littleEndian => .ok (BV32.ofBytes x0.toBV8 x1.toBV8 x2.toBV8 x3.toBV8)
| .bigEndian => .ok (BV32.ofBytes x3.toBV8 x2.toBV8 x1.toBV8 x0.toBV8)
.ok (BV32.ofBytes x0.toBV8 x1.toBV8 x2.toBV8 x3.toBV8)
else .error s!"Index out of range: {n}"

def BV32.toByteArray (x : BV32) (ord : ByteOrder) : ByteArray :=
let (x0, x1, x2, x3) := x.toBytes
let arr := match ord with
| .littleEndian => [x0, x1, x2, x3]
| .bigEndian => [x3, x2, x1, x0]
| .oneByte => []
(arr.map BV8.toUInt8).toByteArray

abbrev BV64 := BitVec 64

def BV64.toByteArray (x : BV64) : ByteArray := bitVecToByteArray 64 x

def BV64.ofNat (i : Nat) : BV64 := i.toUInt64.toBitVec

def BV64.ofInt (i : Int) : BV64 := i.toInt64.toBitVec
Expand Down Expand Up @@ -654,7 +591,7 @@ theorem BV64.BytesRoundTrip1 (x0 x1 x2 x3 x4 x5 x6 x7 : BV8) :
unfold BV64.toBytes BV64.ofBytes
bv_decide

def _root_.ByteArray.toBV64 (x : ByteArray) (startIndex : Nat) (order : ByteOrder) : Err BV64 :=
def _root_.ByteArray.toBV64 (x : ByteArray) (startIndex : Nat) : Err BV64 :=
let n := startIndex
if H7 : n + 7 < x.size then
let H0 : n + 0 < x.size := by omega
Expand All @@ -672,20 +609,9 @@ def _root_.ByteArray.toBV64 (x : ByteArray) (startIndex : Nat) (order : ByteOrde
let x5 := x.get (Fin.mk _ H5)
let x6 := x.get (Fin.mk _ H6)
let x7 := x.get (Fin.mk _ H7)
match order with
| .oneByte => .error "illegal byte order"
| .littleEndian => .ok (BV64.ofBytes x0.toBV8 x1.toBV8 x2.toBV8 x3.toBV8 x4.toBV8 x5.toBV8 x6.toBV8 x7.toBV8)
| .bigEndian => .ok (BV64.ofBytes x7.toBV8 x6.toBV8 x5.toBV8 x4.toBV8 x3.toBV8 x2.toBV8 x1.toBV8 x0.toBV8)
.ok (BV64.ofBytes x0.toBV8 x1.toBV8 x2.toBV8 x3.toBV8 x4.toBV8 x5.toBV8 x6.toBV8 x7.toBV8)
else .error s!"Index out of range: {n}"

def BV64.toByteArray (x : BV64) (ord : ByteOrder) : ByteArray :=
let (x0, x1, x2, x3, x4, x5, x6, x7) := x.toBytes
let arr := match ord with
| .littleEndian => [x0, x1, x2, x3, x4, x5, x6, x7]
| .bigEndian => [x7, x6, x5, x4, x3, x2, x1, x0]
| .oneByte => []
(arr.map BV8.toUInt8).toByteArray

/-
The largest Nat such that it and every smaller Nat can be represented exactly in a 64-bit IEEE-754 float
https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER
Expand All @@ -706,10 +632,9 @@ def _root_.Int.toFloat32 (n : Int) : Float32 := match n with
| Int.ofNat n => Float32.ofNat n
| Int.negSucc n => Float32.neg (Float32.ofNat (Nat.succ n))

def _root_.Float.toLEByteArray (f : Float) : ByteArray := BV64.toByteArray f.toBits.toBitVec ByteOrder.littleEndian
def _root_.Float.toBEByteArray (f : Float) : ByteArray := BV64.toByteArray f.toBits.toBitVec ByteOrder.bigEndian
def _root_.Float32.toLEByteArray (f : Float32) : ByteArray := BV32.toByteArray f.toBits.toBitVec ByteOrder.littleEndian
def _root_.Float32.toBEByteArray (f : Float32) : ByteArray := BV32.toByteArray f.toBits.toBitVec ByteOrder.bigEndian

def _root_.Float32.toLEByteArray (f : Float32) : ByteArray := bitVecToByteArray 32 f.toBits.toBitVec
def _root_.Float.toLEByteArray (f : Float) : ByteArray := bitVecToByteArray 64 f.toBits.toBitVec

/-- Interpret a `ByteArray` of size 4 as a little-endian `UInt32`. Missing from Lean stdlib. -/
def _root_.ByteArray.toUInt32LE! (bs : ByteArray) : UInt32 :=
Expand Down Expand Up @@ -754,8 +679,8 @@ def _root_.Float.toInt (f : Float) : Int :=
#guard (
let n : BV64 := 0x3FFAB851EB851EB8
do
let arr := n.toByteArray .littleEndian
let n' <- ByteArray.toBV64 arr 0 .littleEndian
let arr := n.toByteArray
let n' <- ByteArray.toBV64 arr 0
return n == n') == .ok true

end TensorLib
Loading

0 comments on commit 9e794a9

Please sign in to comment.