Skip to content

Commit

Permalink
chore: fix build, address comment from previous PR
Browse files Browse the repository at this point in the history
As I was testing testing, I forgot to fix my test.
Also responded to an unresolved comment from
#39
  • Loading branch information
seanmcl committed Feb 26, 2025
1 parent ef0fca5 commit b9ee965
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
15 changes: 14 additions & 1 deletion TensorLib/Dtype.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,21 @@ deriving BEq, Repr, Inhabited

namespace Name

-- Should match the NumPy name of the dtype. We use toString to generate NumPy test code.
instance : ToString Name where
toString x := ((repr x).pretty.splitOn ".").getLast!
toString
| .bool => "bool"
| int8 => "int8"
| int16 => "int16"
| int32 => "int32"
| int64 => "int64"
| uint8 => "uint8"
| uint16 => "uint16"
| uint32 => "uint32"
| uint64 => "uint64"
| float16 => "float16"
| float32 => "float32"
| float64 => "float64"

def isOneByte (x : Name) : Bool := match x with
| bool | int8 | uint8 => true
Expand Down
2 changes: 1 addition & 1 deletion TensorLib/Test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ private def saveNumpyArray (expr : String) : IO System.FilePath := do
return file.addExtension "npy"

private def testTensorElementBV (dtype : Dtype) : IO Bool := do
let file <- saveNumpyArray s!"np.arange(20, dtype='{dtype.name}').reshape(, 4)"
let file <- saveNumpyArray s!"np.arange(20, dtype='{dtype.name}').reshape(5, 4)"
let npy <- Npy.parseFile file
let arr <- IO.ofExcept (Tensor.ofNpy npy)
let _ <- IO.FS.removeFile file
Expand Down

0 comments on commit b9ee965

Please sign in to comment.