11module LambdaBuffers.Codegen.LamVal.Eq (deriveEqImpl ) where
22
3+ import Control.Exception qualified as Exception
34import Data.Map.Ordered qualified as OMap
45import LambdaBuffers.Codegen.LamVal (Field , QProduct , QRecord , QSum , ValueE (CaseE , FieldE , LamE , LetE , RefE ), (@) )
56import LambdaBuffers.Codegen.LamVal.Derive (deriveImpl )
@@ -23,11 +24,7 @@ eqSum qsum =
2324 r
2425 ( \ (ctorTyR, rxs) ->
2526 if fst ctorTyL == fst ctorTyR
26- then
27- foldl
28- (\ tot (lx, rx, ty) -> andE @ tot @ (eqE ty @ lx @ rx))
29- trueE
30- (zip3 lxs rxs (snd ctorTyL))
27+ then eqListHelper lxs rxs (snd ctorTyL)
3128 else falseE
3229 )
3330 )
@@ -48,11 +45,7 @@ eqProduct qprod@(_, prodTy) =
4845 LetE
4946 qprod
5047 r
51- ( \ rxs ->
52- foldl
53- (\ tot (lx, rx, ty) -> andE @ tot @ (eqE ty @ lx @ rx))
54- trueE
55- (zip3 lxs rxs prodTy)
48+ ( \ rxs -> eqListHelper lxs rxs prodTy
5649 )
5750 )
5851 )
@@ -65,16 +58,39 @@ eqRecord (qtyN, recTy) =
6558 ( \ l ->
6659 LamE
6760 ( \ r ->
68- foldl
69- (\ tot field -> andE @ tot @ eqField qtyN field l r)
70- trueE
71- (OMap. assocs recTy)
61+ let eqFieldExprs = map (\ field -> eqField qtyN field l r) $ OMap. assocs recTy
62+ in if null eqFieldExprs
63+ then trueE
64+ else
65+ foldl1
66+ (\ tot eqFieldExpr -> andE @ tot @ eqFieldExpr)
67+ eqFieldExprs
7268 )
7369 )
7470
7571eqField :: PC. QTyName -> Field -> ValueE -> ValueE -> ValueE
7672eqField qtyN (fieldName, fieldTy) l r = eqE fieldTy @ FieldE (qtyN, fieldName) l @ FieldE (qtyN, fieldName) r
7773
74+ {- | 'eqListHelper' is an internal function which equates two lists of 'ValueE'
75+ with their type pairwise.
76+
77+ Preconditions:
78+ - All input lists are the same length
79+ -}
80+ eqListHelper :: [ValueE ] -> [ValueE ] -> [LT. Ty ] -> ValueE
81+ eqListHelper lxs rxs tys =
82+ Exception. assert preconditionAssertion $
83+ let eqedLxsRxsTys = map (\ (lx, rx, ty) -> eqE ty @ lx @ rx) $ zip3 lxs rxs tys
84+ in if null eqedLxsRxsTys
85+ then trueE
86+ else foldl1 (\ tot eqExpr -> andE @ tot @ eqExpr) eqedLxsRxsTys
87+ where
88+ preconditionAssertion =
89+ let lxsLength = length lxs
90+ rxsLength = length rxs
91+ tysLength = length tys
92+ in lxsLength == rxsLength && rxsLength == tysLength
93+
7894-- | Hooks
7995deriveEqImpl :: PC. ModuleName -> PC. TyDefs -> PC. Ty -> Either P. InternalError ValueE
8096deriveEqImpl mn tydefs = deriveImpl mn tydefs eqSum eqProduct eqRecord
0 commit comments