Skip to content

Commit

Permalink
Mark index space transformation functions as inline.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Nov 23, 2023
1 parent 90f5a1a commit 12d870c
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions prelude/array.fut
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,56 @@ def null [n] 't (_: [n]t) = n == 0
-- | The first element of the array.
--
-- **Complexity:** O(1).
#[inline]
def head [n] 't (x: [n]t) = x[0]

-- | The last element of the array.
--
-- **Complexity:** O(1).
#[inline]
def last [n] 't (x: [n]t) = x[n-1]

-- | Everything but the first element of the array.
--
-- **Complexity:** O(1).
#[inline]
def tail [n] 't (x: [n]t): [n-1]t = x[1:]

-- | Everything but the last element of the array.
--
-- **Complexity:** O(1).
#[inline]
def init [n] 't (x: [n]t): [n-1]t = x[0:n-1]

-- | Take some number of elements from the head of the array.
--
-- **Complexity:** O(1).
#[inline]
def take [n] 't (i: i64) (x: [n]t): [i]t = x[0:i]

-- | Remove some number of elements from the head of the array.
--
-- **Complexity:** O(1).
#[inline]
def drop [n] 't (i: i64) (x: [n]t): [n-i]t = x[i:]

-- | Statically change the size of an array. Fail at runtime if the
-- imposed size does not match the actual size. Essentially syntactic
-- sugar for a size coercion.
#[inline]
def sized [m] 't (n: i64) (xs: [m]t) : [n]t = xs :> [n]t

-- | Split an array at a given position.
--
-- **Complexity:** O(1).
#[inline]
def split [n][m] 't (xs: [n+m]t): ([n]t, [m]t) =
(xs[0:n], xs[n:n+m] :> [m]t)

-- | Return the elements of the array in reverse order.
--
-- **Complexity:** O(1).
#[inline]
def reverse [n] 't (x: [n]t): [n]t = x[::-1]

-- | Concatenate two arrays. Warning: never try to perform a reduction
Expand All @@ -67,9 +76,11 @@ def reverse [n] 't (x: [n]t): [n]t = x[::-1]
-- **Work:** O(n).
--
-- **Span:** O(1).
#[inline]
def (++) [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = intrinsics.concat xs ys

-- | An old-fashioned way of saying `++`.
#[inline]
def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = xs ++ ys

-- | Construct an array of consecutive integers of the given length,
Expand All @@ -78,6 +89,7 @@ def concat [n] [m] 't (xs: [n]t) (ys: [m]t): *[n+m]t = xs ++ ys
-- **Work:** O(n).
--
-- **Span:** O(1).
#[inline]
def iota (n: i64): *[n]i64 =
0..1..<n

Expand All @@ -87,6 +99,7 @@ def iota (n: i64): *[n]i64 =
-- **Work:** O(n).
--
-- **Span:** O(1).
#[inline]
def indices [n] 't (_: [n]t) : *[n]i64 =
iota n

Expand All @@ -101,6 +114,7 @@ def indices [n] 't (_: [n]t) : *[n]i64 =
--
-- Note: In most cases, `rotate` will be fused with subsequent
-- operations such as `map`, in which case it is free.
#[inline]
def rotate [n] 't (r: i64) (a: [n]t) =
map (\i -> #[unsafe] a[(i+r)%n]) (iota n)

Expand All @@ -110,6 +124,7 @@ def rotate [n] 't (r: i64) (a: [n]t) =
-- **Work:** O(n).
--
-- **Span:** O(1).
#[inline]
def replicate 't (n: i64) (x: t): *[n]t =
map (const x) (iota n)

Expand All @@ -118,40 +133,48 @@ def replicate 't (n: i64) (x: t): *[n]t =
-- **Work:** O(n).
--
-- **Span:** O(1).
#[inline]
def copy 't (a: t): *t =
([a])[0]

-- | Combines the outer two dimensions of an array.
--
-- **Complexity:** O(1).
#[inline]
def flatten [n][m] 't (xs: [n][m]t): [n*m]t =
intrinsics.flatten xs

-- | Like `flatten`, but on the outer three dimensions of an array.
#[inline]
def flatten_3d [n][m][l] 't (xs: [n][m][l]t): [n*m*l]t =
flatten (flatten xs)

-- | Like `flatten`, but on the outer four dimensions of an array.
#[inline]
def flatten_4d [n][m][l][k] 't (xs: [n][m][l][k]t): [n*m*l*k]t =
flatten (flatten_3d xs)

-- | Splits the outer dimension of an array in two.
--
-- **Complexity:** O(1).
#[inline]
def unflatten 't [n][m] (xs: [n*m]t): [n][m]t =
intrinsics.unflatten n m xs

-- | Like `unflatten`, but produces three dimensions.
#[inline]
def unflatten_3d 't [n][m][l] (xs: [n*m*l]t): [n][m][l]t =
unflatten (unflatten xs)

-- | Like `unflatten`, but produces four dimensions.
#[inline]
def unflatten_4d 't [n][m][l][k] (xs: [n*m*l*k]t): [n][m][l][k]t =
unflatten (unflatten_3d xs)

-- | Transpose an array.
--
-- **Complexity:** O(1).
#[inline]
def transpose [n] [m] 't (a: [n][m]t): [m][n]t =
intrinsics.transpose a

Expand Down

0 comments on commit 12d870c

Please sign in to comment.