Skip to content

Commit

Permalink
New formatter.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Oct 23, 2023
1 parent 6096315 commit 45b3403
Show file tree
Hide file tree
Showing 28 changed files with 396 additions and 309 deletions.
3 changes: 2 additions & 1 deletion src/Futhark/CLI/REPL.hs
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ typeCommand = genTypeCommand parseExp T.checkExp $ \(ps, e) ->
then
annotate italicized $
"\n\nPolymorphic in"
<+> mconcat (intersperse " " $ map pretty ps) <> "."
<+> mconcat (intersperse " " $ map pretty ps)
<> "."
else mempty

mtypeCommand :: Command
Expand Down
26 changes: 13 additions & 13 deletions src/Futhark/CodeGen/Backends/GenericPython/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,26 +127,26 @@ instance Pretty PyStmt where
"if"
<+> pretty cond
<> ":"
</> indent 2 "pass"
</> indent 2 "pass"
pretty (If cond [] fbranch) =
"if"
<+> pretty cond
<> ":"
</> indent 2 "pass"
</> "else:"
</> indent 2 (stack $ map pretty fbranch)
</> indent 2 "pass"
</> "else:"
</> indent 2 (stack $ map pretty fbranch)
pretty (If cond tbranch []) =
"if"
<+> pretty cond
<> ":"
</> indent 2 (stack $ map pretty tbranch)
</> indent 2 (stack $ map pretty tbranch)
pretty (If cond tbranch fbranch) =
"if"
<+> pretty cond
<> ":"
</> indent 2 (stack $ map pretty tbranch)
</> "else:"
</> indent 2 (stack $ map pretty fbranch)
</> indent 2 (stack $ map pretty tbranch)
</> "else:"
</> indent 2 (stack $ map pretty fbranch)
pretty (Try pystms pyexcepts) =
"try:"
</> indent 2 (stack $ map pretty pystms)
Expand All @@ -155,19 +155,19 @@ instance Pretty PyStmt where
"while"
<+> pretty cond
<> ":"
</> indent 2 (stack $ map pretty body)
</> indent 2 (stack $ map pretty body)
pretty (For i what body) =
"for"
<+> pretty i
<+> "in"
<+> pretty what
<> ":"
</> indent 2 (stack $ map pretty body)
</> indent 2 (stack $ map pretty body)
pretty (With what body) =
"with"
<+> pretty what
<> ":"
</> indent 2 (stack $ map pretty body)
</> indent 2 (stack $ map pretty body)
pretty (Assign e1 e2) = pretty e1 <+> "=" <+> pretty e2
pretty (AssignOp op e1 e2) = pretty e1 <+> pretty (op ++ "=") <+> pretty e2
pretty (Comment s body) = "#" <> pretty s </> stack (map pretty body)
Expand All @@ -190,14 +190,14 @@ instance Pretty PyFunDef where
<+> pretty fname
<> parens (commasep $ map pretty params)
<> ":"
</> indent 2 (stack (map pretty body))
</> indent 2 (stack (map pretty body))

instance Pretty PyClassDef where
pretty (Class cname body) =
"class"
<+> pretty cname
<> ":"
</> indent 2 (stack (map pretty body))
</> indent 2 (stack (map pretty body))

instance Pretty PyExcept where
pretty (Catch pyexp stms) =
Expand Down
34 changes: 25 additions & 9 deletions src/Futhark/CodeGen/ImpCode.hs
Original file line number Diff line number Diff line change
Expand Up @@ -574,17 +574,29 @@ instance (Pretty op) => Pretty (Code op) where
pretty (Free name space) =
"free" <> parens (pretty name) <> pretty space
pretty (Write name i bt space vol val) =
pretty name <> langle <> vol' <> pretty bt <> pretty space <> rangle <> brackets (pretty i)
<+> "<-"
<+> pretty val
pretty name
<> langle
<> vol'
<> pretty bt
<> pretty space
<> rangle
<> brackets (pretty i)
<+> "<-"
<+> pretty val
where
vol' = case vol of
Volatile -> "volatile "
Nonvolatile -> mempty
pretty (Read name v is bt space vol) =
pretty name
<+> "<-"
<+> pretty v <> langle <> vol' <> pretty bt <> pretty space <> rangle <> brackets (pretty is)
<+> pretty v
<> langle
<> vol'
<> pretty bt
<> pretty space
<> rangle
<> brackets (pretty is)
where
vol' = case vol of
Volatile -> "volatile "
Expand All @@ -602,14 +614,17 @@ instance (Pretty op) => Pretty (Code op) where
<> (parens . align)
( foldMap (brackets . pretty) shape
<> ","
</> p dst dstspace dstoffset dststrides
</> p dst dstspace dstoffset dststrides
<> ","
</> p src srcspace srcoffset srcstrides
</> p src srcspace srcoffset srcstrides
)
where
p mem space offset strides =
pretty mem <> pretty space <> "+" <> pretty offset
<+> foldMap (brackets . pretty) strides
pretty mem
<> pretty space
<> "+"
<> pretty offset
<+> foldMap (brackets . pretty) strides
pretty (If cond tbranch fbranch) =
"if"
<+> pretty cond
Expand All @@ -626,7 +641,8 @@ instance (Pretty op) => Pretty (Code op) where
"call"
<+> commasep (map pretty dests)
<+> "<-"
<+> pretty fname <> parens (commasep $ map pretty args)
<+> pretty fname
<> parens (commasep $ map pretty args)
pretty (Comment s code) =
"--" <+> pretty s </> pretty code
pretty (DebugPrint desc (Just e)) =
Expand Down
63 changes: 34 additions & 29 deletions src/Futhark/CodeGen/ImpCode/GPU.hs
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,17 @@ instance Pretty HostOp where
pretty (GetSize dest key size_class) =
pretty dest
<+> "<-"
<+> "get_size" <> parens (commasep [pretty key, pretty size_class])
<+> "get_size"
<> parens (commasep [pretty key, pretty size_class])
pretty (GetSizeMax dest size_class) =
pretty dest <+> "<-" <+> "get_size_max" <> parens (pretty size_class)
pretty (CmpSizeLe dest name size_class x) =
pretty dest
<+> "<-"
<+> "get_size" <> parens (commasep [pretty name, pretty size_class])
<+> "<"
<+> pretty x
<+> "get_size"
<> parens (commasep [pretty name, pretty size_class])
<+> "<"
<+> pretty x
pretty (CallKernel c) =
pretty c

Expand Down Expand Up @@ -211,15 +213,18 @@ instance Pretty KernelOp where
pretty (GetGroupId dest i) =
pretty dest
<+> "<-"
<+> "get_group_id" <> parens (pretty i)
<+> "get_group_id"
<> parens (pretty i)
pretty (GetLocalId dest i) =
pretty dest
<+> "<-"
<+> "get_local_id" <> parens (pretty i)
<+> "get_local_id"
<> parens (pretty i)
pretty (GetLocalSize dest i) =
pretty dest
<+> "<-"
<+> "get_local_size" <> parens (pretty i)
<+> "get_local_size"
<> parens (pretty i)
pretty (GetLockstepWidth dest) =
pretty dest
<+> "<-"
Expand All @@ -242,68 +247,68 @@ instance Pretty KernelOp where
pretty old
<+> "<-"
<+> "atomic_add_"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
pretty (Atomic _ (AtomicFAdd t old arr ind x)) =
pretty old
<+> "<-"
<+> "atomic_fadd_"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
pretty (Atomic _ (AtomicSMax t old arr ind x)) =
pretty old
<+> "<-"
<+> "atomic_smax"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
pretty (Atomic _ (AtomicSMin t old arr ind x)) =
pretty old
<+> "<-"
<+> "atomic_smin"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
pretty (Atomic _ (AtomicUMax t old arr ind x)) =
pretty old
<+> "<-"
<+> "atomic_umax"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
pretty (Atomic _ (AtomicUMin t old arr ind x)) =
pretty old
<+> "<-"
<+> "atomic_umin"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
pretty (Atomic _ (AtomicAnd t old arr ind x)) =
pretty old
<+> "<-"
<+> "atomic_and"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
pretty (Atomic _ (AtomicOr t old arr ind x)) =
pretty old
<+> "<-"
<+> "atomic_or"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
pretty (Atomic _ (AtomicXor t old arr ind x)) =
pretty old
<+> "<-"
<+> "atomic_xor"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
pretty (Atomic _ (AtomicCmpXchg t old arr ind x y)) =
pretty old
<+> "<-"
<+> "atomic_cmp_xchg"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x, pretty y])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x, pretty y])
pretty (Atomic _ (AtomicXchg t old arr ind x)) =
pretty old
<+> "<-"
<+> "atomic_xchg"
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])
<> pretty t
<> parens (commasep [pretty arr <> brackets (pretty ind), pretty x])

instance FreeIn KernelOp where
freeIn' (Atomic _ op) = freeIn' op
Expand Down
21 changes: 14 additions & 7 deletions src/Futhark/CodeGen/ImpGen/GPU/SegHist.hs
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,10 @@ histKernelGlobalPass map_pes num_groups group_size space slugs kbody histograms
dest_shape' = map pe64 $ shapeDims dest_shape
flat_bucket = flattenIndex dest_shape' bucket'
bucket_in_bounds =
chk_beg .<=. flat_bucket
.&&. flat_bucket .<. (chk_beg + hist_H_chk)
chk_beg
.<=. flat_bucket
.&&. flat_bucket
.<. (chk_beg + hist_H_chk)
.&&. inBounds (Slice (map DimFix bucket')) dest_shape'
vs_params = takeLast (length vs') $ lambdaParams lam

Expand Down Expand Up @@ -760,8 +762,10 @@ histKernelLocalPass
flat_bucket = flattenIndex dest_shape' bucket'
bucket_in_bounds =
inBounds (Slice (map DimFix bucket')) dest_shape'
.&&. chk_beg .<=. flat_bucket
.&&. flat_bucket .<. (chk_beg + tvExp hist_H_chk)
.&&. chk_beg
.<=. flat_bucket
.&&. flat_bucket
.<. (chk_beg + tvExp hist_H_chk)
bucket_is =
[sExt64 thread_local_subhisto_i, flat_bucket - chk_beg]
vs_params = takeLast (length vs') $ lambdaParams lam
Expand Down Expand Up @@ -1025,11 +1029,14 @@ localMemoryCase map_pes hist_T space hist_H hist_el_size hist_N _ slugs kbody =
-- asymptotically efficient. This mostly matters for the segmented
-- case.
let pick_local =
hist_Nin .>=. hist_H
hist_Nin
.>=. hist_H
.&&. (local_mem_needed .<=. tvExp hist_L)
.&&. (hist_S .<=. max_S)
.&&. hist_C .<=. hist_B
.&&. tvExp hist_M .>. 0
.&&. hist_C
.<=. hist_B
.&&. tvExp hist_M
.>. 0

run = do
emit $ Imp.DebugPrint "## Using local memory" Nothing
Expand Down
10 changes: 5 additions & 5 deletions src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ smallSegmentsReduction (Pat segred_pes) num_groups group_size space reds body =
.>. 0
.&&. isActive (init $ zip gtids dims)
.&&. ltid
.<. segment_size
* segments_per_group
.<. segment_size
* segments_per_group
)
in_bounds
out_of_bounds
Expand All @@ -345,8 +345,8 @@ smallSegmentsReduction (Pat segred_pes) num_groups group_size space reds body =
( sExt64 group_id'
* segments_per_group
+ sExt64 ltid
.<. num_segments
.&&. ltid
.<. num_segments
.&&. ltid
.<. segments_per_group
)
$ forM_ (zip segred_pes (concat reds_arrs))
Expand Down Expand Up @@ -603,7 +603,7 @@ computeThreadChunkSize Noncommutative _ thread_index elements_per_thread num_ele
is_last_thread =
Imp.unCount num_elements
.<. (thread_index + 1)
* Imp.unCount elements_per_thread
* Imp.unCount elements_per_thread

reductionStageZero ::
KernelConstants ->
Expand Down
14 changes: 9 additions & 5 deletions src/Futhark/IR/GPU/Op.hs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,11 @@ instance PP.Pretty SegVirt where

instance PP.Pretty KernelGrid where
pretty (KernelGrid num_groups group_size) =
"groups=" <> pretty num_groups <> PP.semi
<+> "groupsize=" <> pretty group_size
"groups="
<> pretty num_groups
<> PP.semi
<+> "groupsize="
<> pretty group_size

instance PP.Pretty SegLevel where
pretty (SegThread virt grid) =
Expand Down Expand Up @@ -219,9 +222,10 @@ instance PP.Pretty SizeOp where
pretty (GetSizeMax size_class) =
"get_size_max" <> parens (commasep [pretty size_class])
pretty (CmpSizeLe name size_class x) =
"cmp_size" <> parens (commasep [pretty name, pretty size_class])
<+> "<="
<+> pretty x
"cmp_size"
<> parens (commasep [pretty name, pretty size_class])
<+> "<="
<+> pretty x
pretty (CalcNumGroups w max_num_groups group_size) =
"calc_num_groups" <> parens (commasep [pretty w, pretty max_num_groups, pretty group_size])

Expand Down
Loading

0 comments on commit 45b3403

Please sign in to comment.