Skip to content

Commit

Permalink
Factor this into separate function.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Dec 6, 2023
1 parent 6b08f3c commit 4a9da07
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
13 changes: 7 additions & 6 deletions src/Futhark/CodeGen/Backends/GPU.hs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ genKernelFunction kernel_name safety arg_params arg_set = do
([C.cinit|&ctx->global_failure_args|], [C.cinit|sizeof(ctx->global_failure_args)|])
]

getParamByKey :: Name -> C.Exp
getParamByKey key = [C.cexp|*ctx->tuning_params.$id:key|]

kernelConstToExp :: KernelConst -> C.Exp
kernelConstToExp (SizeConst key) =
[C.cexp|*ctx->tuning_params.$id:key|]
getParamByKey key
kernelConstToExp (SizeMaxConst size_class) =
[C.cexp|ctx->$id:field|]
where
Expand Down Expand Up @@ -133,13 +136,11 @@ genLaunchKernel safety kernel_name local_memory args num_groups group_size = do
)

callKernel :: GC.OpCompiler OpenCL ()
callKernel (GetSize v key) = do
let e = kernelConstToExp $ SizeConst key
GC.stm [C.cstm|$id:v = $exp:e;|]
callKernel (GetSize v key) =
GC.stm [C.cstm|$id:v = $exp:(getParamByKey key);|]
callKernel (CmpSizeLe v key x) = do
let e = kernelConstToExp $ SizeConst key
x' <- GC.compileExp x
GC.stm [C.cstm|$id:v = $exp:e <= $exp:x';|]
GC.stm [C.cstm|$id:v = $exp:(getParamByKey key) <= $exp:x';|]
-- Output size information if logging is enabled. The autotuner
-- depends on the format of this output, so use caution if changing
-- it.
Expand Down
11 changes: 6 additions & 5 deletions src/Futhark/CodeGen/Backends/PyOpenCL.hs
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,12 @@ compileProg mode class_name prog = do
asLong :: PyExp -> PyExp
asLong x = simpleCall "np.int64" [x]

getParamByKey :: Name -> PyExp
getParamByKey key = Index (Var "self.sizes") (IdxExp $ String $ prettyText key)

kernelConstToExp :: Imp.KernelConst -> PyExp
kernelConstToExp (Imp.SizeConst key) =
Index (Var "self.sizes") (IdxExp $ String $ prettyText key)
getParamByKey key
kernelConstToExp (Imp.SizeMaxConst size_class) =
Var $ "self.max_" <> prettyString size_class

Expand All @@ -215,13 +218,11 @@ compileGroupDim (Right kc) = pure $ kernelConstToExp kc
callKernel :: OpCompiler Imp.OpenCL ()
callKernel (Imp.GetSize v key) = do
v' <- compileVar v
stm $ Assign v' $ kernelConstToExp $ Imp.SizeConst key
stm $ Assign v' $ getParamByKey key
callKernel (Imp.CmpSizeLe v key x) = do
v' <- compileVar v
x' <- compileExp x
stm $
Assign v' $
BinOp "<=" (kernelConstToExp (Imp.SizeConst key)) x'
stm $ Assign v' $ BinOp "<=" (getParamByKey key) x'
callKernel (Imp.GetSizeMax v size_class) = do
v' <- compileVar v
stm $ Assign v' $ kernelConstToExp $ Imp.SizeMaxConst size_class
Expand Down

0 comments on commit 4a9da07

Please sign in to comment.