Skip to content

Commit

Permalink
Remove impossible patterns no longer needlessly required by GHC (HEAD)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Feb 2, 2025
1 parent 850e8eb commit f2789c0
Show file tree
Hide file tree
Showing 9 changed files with 1 addition and 25 deletions.
3 changes: 0 additions & 3 deletions example/MnistRnnRanked2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ rnnMnistLayerR s x (wX, wS, b) = case rshape s of
let y = wX `rmatmul2` x + wS `rmatmul2` s
+ rtr (rreplicate batch_size b)
in tanh y
_ -> error "rnnMnistLayerR: wrong shape of the state"

rnnMnistTwoR
:: (ADReady target, GoodScalar r, Numeric r, Differentiable r)
Expand All @@ -95,7 +94,6 @@ rnnMnistTwoR s' x ((wX, wS, b), (wX2, wS2, b2)) = case rshape s' of
vec2 = rnnMnistLayerR s2 vec1 (wX2, wS2, b2)
in rappend vec1 vec2
in (rslice out_width out_width s3, s3)
_ -> error "rnnMnistTwoR: wrong shape of the state"

rnnMnistZeroR
:: (ADReady target, GoodScalar r, Numeric r, Differentiable r)
Expand All @@ -110,7 +108,6 @@ rnnMnistZeroR batch_size xs
(out, _s) = zeroStateR sh (unrollLastR rnnMnistTwoR) xs
((wX, wS, b), (wX2, wS2, b2))
in w3 `rmatmul2` out + rtr (rreplicate batch_size b3)
_ -> error "rnnMnistZeroR: wrong shape"

rnnMnistLossFusedR
:: ( ADReady target, ADReady (PrimalOf target), GoodScalar r
Expand Down
1 change: 0 additions & 1 deletion src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,6 @@ astIndexBuild snat@SNat stk u i = case stk of
FTKX shBuild' _->
withKnownShX sh' $
withCastXS shBuild' $ \(shBuild :: ShS shBuild) -> case shBuild of
ZSS -> error "astIndexBuild: impossible empty shape"
(:$$) _ rest ->
withKnownShS rest $
astFromS (knownSTK @y)
Expand Down
1 change: 0 additions & 1 deletion src/HordeAd/Core/DeltaEval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,6 @@ evalRevSame !s !c = \case
FTKR (len :$: _) _ -> len
s2 = evalRevSame s (rslice 0 k cShared) d
in evalRevSame s2 (rslice k (n - k) cShared) e
ZSR -> error "evalRevSame: impossible pattern needlessly required"
DeltaSliceR i n d -> case tftk (knownSTK @y) c of
FTKR (n' :$: rest) x ->
assert (n' == n `blame` (n', n)) $
Expand Down
3 changes: 0 additions & 3 deletions src/HordeAd/Core/Ops.hs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,6 @@ class ( Num (IntOf target)
rsize = shrSize . rshape
rlength :: KnownSTK r => target (TKR2 (1 + n) r) -> Int
rlength v = case rshape v of
ZSR -> error "rlength: impossible pattern needlessly required"
k :$: _ -> k

rconcrete :: (KnownSTK r, KnownNat n)
Expand Down Expand Up @@ -492,7 +491,6 @@ class ( Num (IntOf target)
_ :$: width2 :$: ZSR ->
rsum (rtranspose [2,1,0] (rreplicate width2 m1)
* rtranspose [1,0] (rreplicate (rlength m1) m2))
_ -> error "rmatmul2: impossible pattern needlessly required"
rscaleByScalar :: (GoodScalar r, KnownNat n)
=> target (TKR 0 r) -> target (TKR n r) -> target (TKR n r)
rscaleByScalar s v = v * rreplicate0N (rshape v) s
Expand Down Expand Up @@ -574,7 +572,6 @@ class ( Num (IntOf target)
=> target (TKR2 (1 + n) r)
-> Maybe (target (TKR2 n r), target (TKR2 (1 + n) r))
runcons v = case rshape v of
ZSR -> Nothing
len :$: _ -> Just (v ! [0], rslice 1 (len - 1) v)
rreverse :: (KnownSTK r, KnownNat n)
=> target (TKR2 (1 + n) r) -> target (TKR2 (1 + n) r)
Expand Down
12 changes: 0 additions & 12 deletions src/HordeAd/Core/OpsAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,6 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where
astFromS @(TKS (Init sh) r2) (knownSTK @(TKX (Init sh') r2))
. fromPrimal . AstMinIndexS @rest @n
. primalPart . astSFromX @sh @sh' $ a
ZSS -> error "xminIndex: impossible shape"
xmaxIndex @_ @r2 a = case ftkAst a of
FTKX @sh' sh' _ ->
withKnownShX (ssxFromShape sh') $
Expand All @@ -554,7 +553,6 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where
astFromS @(TKS (Init sh) r2) (knownSTK @(TKX (Init sh') r2))
. fromPrimal . AstMaxIndexS @rest @n
. primalPart . astSFromX @sh @sh' $ a
ZSS -> error "xmaxIndex: impossible shape"
xiota @n @r = astFromS (knownSTK @(TKX '[Just n] r))
$ fromPrimal $ AstIotaS @n @r
xappend @r @sh u v = case ftkAst u of
Expand All @@ -572,8 +570,6 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where
(astSFromX @shv @shv' v)
_ -> error $ "xappend: shapes don't match: "
++ show (restu, restv)
ZSS -> error "xappend: impossible shape"
ZSS -> error "xappend: impossible shape"
xslice @r @i @n @k @sh2 Proxy Proxy a = case ftkAst a of
FTKX @sh' @x sh'@(_ :$% _) _ ->
withCastXS sh' $ \(sh :: ShS sh) -> case sh of
Expand All @@ -587,7 +583,6 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where
_ -> error $ "xslice: argument tensor too narrow: "
++ show ( valueOf @i :: Int, valueOf @n :: Int
, valueOf @k :: Int, sNatValue msnat )
ZSS -> error "xslice: impossible shape"
xreverse a = case ftkAst a of
FTKX @sh' @x sh' _ ->
withKnownShX (ssxFromShape sh') $
Expand All @@ -596,7 +591,6 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where
withKnownShS rest $
astFromS @(TKS2 sh x) (knownSTK @(TKX2 sh' x))
. astReverseS . astSFromX @sh @sh' $ a
ZSS -> error "xreverse: impossible shape"
xtranspose @perm perm a = case ftkAst a of
FTKX @sh' @x sh' _ -> case shxPermutePrefix perm sh' of
(sh2' :: IShX sh2') ->
Expand Down Expand Up @@ -1207,7 +1201,6 @@ instance AstSpan s => BaseTensor (AstRaw s) where
AstFromS @(TKS (Init sh) r2) (knownSTK @(TKX (Init sh') r2))
. fromPrimal . AstMinIndexS @rest @n
. primalPart . AstSFromX @sh @sh' $ a
ZSS -> error "xminIndex: impossible shape"
xmaxIndex @_ @r2 (AstRaw a) = AstRaw $ case ftkAst a of
FTKX @sh' sh' _ ->
withKnownShX (ssxFromShape sh') $
Expand All @@ -1220,7 +1213,6 @@ instance AstSpan s => BaseTensor (AstRaw s) where
AstFromS @(TKS (Init sh) r2) (knownSTK @(TKX (Init sh') r2))
. fromPrimal . AstMaxIndexS @rest @n
. primalPart . AstSFromX @sh @sh' $ a
ZSS -> error "xmaxIndex: impossible shape"
xiota @n @r = AstRaw $ AstFromS (knownSTK @(TKX '[Just n] r))
$ fromPrimal $ AstIotaS @n @r
xappend @r @sh (AstRaw u) (AstRaw v) = AstRaw $ case ftkAst u of
Expand All @@ -1238,8 +1230,6 @@ instance AstSpan s => BaseTensor (AstRaw s) where
(AstSFromX @shv @shv' v)
_ -> error $ "xappend: shapes don't match: "
++ show (restu, restv)
ZSS -> error "xappend: impossible shape"
ZSS -> error "xappend: impossible shape"
xslice @r @i @n @k @sh2 Proxy Proxy (AstRaw a) = AstRaw $ case ftkAst a of
FTKX @sh' @x sh'@(_ :$% _) _ ->
withCastXS sh' $ \(sh :: ShS sh) -> case sh of
Expand All @@ -1253,7 +1243,6 @@ instance AstSpan s => BaseTensor (AstRaw s) where
_ -> error $ "xslice: argument tensor too narrow: "
++ show ( valueOf @i :: Int, valueOf @n :: Int
, valueOf @k :: Int, sNatValue msnat )
ZSS -> error "xslice: impossible shape"
xreverse (AstRaw a) = AstRaw $ case ftkAst a of
FTKX @sh' @x sh' _ ->
withKnownShX (ssxFromShape sh') $
Expand All @@ -1262,7 +1251,6 @@ instance AstSpan s => BaseTensor (AstRaw s) where
withKnownShS rest $
AstFromS @(TKS2 sh x) (knownSTK @(TKX2 sh' x))
. AstReverseS . AstSFromX @sh @sh' $ a
ZSS -> error "xreverse: impossible shape"
xtranspose @perm perm (AstRaw a) = AstRaw $ case ftkAst a of
FTKX @sh' @x sh' _ -> case shxPermutePrefix perm sh' of
(sh2' :: IShX sh2') ->
Expand Down
2 changes: 0 additions & 2 deletions src/HordeAd/Core/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,6 @@ toLinearIdx fromInt = \sh idx -> go sh idx (fromInt 0)
go :: ShR (m1 + n) Int -> IxR m1 j -> j -> j
go sh ZIR tensidx = fromInt (shrSize sh) * tensidx
go (n :$: sh) (i :.: idx) tensidx = go sh idx (fromInt n * tensidx + i)
go _ _ _ = error "toLinearIdx: impossible pattern needlessly required"

-- | Given a linear index into the buffer, get the corresponding
-- multidimensional index. Here we require an index pointing at a scalar.
Expand Down Expand Up @@ -744,7 +743,6 @@ shCastSX ((:!%) @_ @restx (Nested.SUnknown ()) restx)
((:$$) @_ @rest snat2 rest) =
gcastWith (unsafeCoerceRefl :: Rank restx :~: Rank rest) $ -- why!
Nested.SUnknown (sNatValue snat2) :$% shCastSX restx rest
shCastSX _ _ = error "shCastSX: shapes don't match"

-- TODO; make more typed, ensure ranks match, use singletons instead of constraints,
-- give better names and do the same for ListS, etc.
Expand Down
1 change: 0 additions & 1 deletion test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1330,7 +1330,6 @@ fooD :: forall r. r ~ Double
fooD (x ::: y ::: z ::: ZR) =
let w = x * sin y
in atan2F z w + z * w
fooD _ = error "wrong number of arguments"

testFooD :: Assertion
testFooD =
Expand Down
2 changes: 1 addition & 1 deletion test/simplified/TestConvSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ testCNNOPP4 = do
afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
afcnn2T = maxPool2dUnpadded4 $ conv2dUnpadded4 blackGlyph
printAstPretty IM.empty afcnn2T
@?= "rfromS (sreplicate (sreplicate (let w41 = sgather (sfromVector (fromList [let w17 = stranspose (sreplicate (sreplicate (sreplicate (sreplicate (treplicate 1 1 + siota))))) ; w16 = stranspose (sreplicate (sreplicate (sreplicate (stranspose (sreplicate (treplicate 1 2 * siota)) + sreplicate siota)))) ; w8 = stranspose (sreplicate (sreplicate (sreplicate (stranspose (sreplicate (treplicate 1 2 * siota)) + sreplicate siota)))) in sgather (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [7.0,0.0])) (\\[i54, i47, i36, i31, i30, i25] -> [ifF ((0 <=. kfromS (w17 !$ [i54, i47, i36, i30, i25]) &&* 1 >. kfromS (w17 !$ [i54, i47, i36, i30, i25])) &&* ((0 <=. kfromS (w16 !$ [i54, i47, i36, i30, i25]) &&* 2 >. kfromS (w16 !$ [i54, i47, i36, i30, i25])) &&* (0 <=. kfromS (w8 !$ [i54, i47, i36, i30, i25]) &&* 2 >. kfromS (w8 !$ [i54, i47, i36, i30, i25])))) 0 1]), sreplicate (sreplicate (sreplicate (sgather (sreplicate (sreplicate (sscalar 0.0))) (\\[i31, i26, i22] -> [i26, i22]))))])) (\\[i50, i43, i35, i32, i33, i34] -> [ifF ((0 <=. 1 + i35 &&* 1 >. 1 + i35) &&* ((0 <=. 1 + i32 &&* 1 >. 1 + i32) &&* ((0 <=. 2 * i50 + i33 &&* 2 >. 2 * i50 + i33) &&* (0 <=. 2 * i43 + i34 &&* 2 >. 2 * i43 + i34)))) 0 1, i50, i43, i35, i32, i33, i34]) in sgather w41 (\\[i49, i42] -> [i49, i42, 0, 0, 0, 0]))))"
@?= "rfromS (sreplicate (sreplicate (let w41 = sgather (sfromVector (fromList [let w21 = stranspose (sreplicate (sreplicate (sreplicate (sreplicate (treplicate 1 1 + siota))))) ; w20 = stranspose (sreplicate (sreplicate (sreplicate (stranspose (sreplicate (treplicate 1 2 * siota)) + sreplicate siota)))) ; w12 = stranspose (sreplicate (sreplicate (sreplicate (stranspose (sreplicate (treplicate 1 2 * siota)) + sreplicate siota)))) in sgather (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [7.0,0.0])) (\\[i54, i47, i36, i31, i30, i25] -> [ifF ((0 <=. kfromS (w21 !$ [i54, i47, i36, i30, i25]) &&* 1 >. kfromS (w21 !$ [i54, i47, i36, i30, i25])) &&* ((0 <=. kfromS (w20 !$ [i54, i47, i36, i30, i25]) &&* 2 >. kfromS (w20 !$ [i54, i47, i36, i30, i25])) &&* (0 <=. kfromS (w12 !$ [i54, i47, i36, i30, i25]) &&* 2 >. kfromS (w12 !$ [i54, i47, i36, i30, i25])))) 0 1]), sreplicate (sreplicate (sreplicate (sgather (sreplicate (sreplicate (sscalar 0.0))) (\\[i31, i26, i22] -> [i26, i22]))))])) (\\[i50, i43, i35, i32, i33, i34] -> [ifF ((0 <=. 1 + i35 &&* 1 >. 1 + i35) &&* ((0 <=. 1 + i32 &&* 1 >. 1 + i32) &&* ((0 <=. 2 * i50 + i33 &&* 2 >. 2 * i50 + i33) &&* (0 <=. 2 * i43 + i34 &&* 2 >. 2 * i43 + i34)))) 0 1, i50, i43, i35, i32, i33, i34]) in sgather w41 (\\[i49, i42] -> [i49, i42, 0, 0, 0, 0]))))"
printAstPretty IM.empty (simplifyInlineContract afcnn2T)
@?= "rfromS (sreplicate (sreplicate (sreplicate (sreplicate (sscalar 0.0)))))"

Expand Down
1 change: 0 additions & 1 deletion test/simplified/TestHighRankSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ fooD :: forall r n. (RealFloatF (ADVal RepN (TKR n r)))
fooD (x ::: y ::: z ::: ZR) =
let w = x * sin y
in atan2F z w + z * w
fooD _ = error "wrong number of arguments"

testFooD :: Assertion
testFooD =
Expand Down

0 comments on commit f2789c0

Please sign in to comment.