Skip to content

Commit 7173e0c

Browse files
committed
Rust codegen: recursively checking for phantom type arguments
1 parent 8752148 commit 7173e0c

File tree

2 files changed

+86
-43
lines changed

2 files changed

+86
-43
lines changed

lambda-buffers-codegen/src/LambdaBuffers/Codegen/Rust/Print/LamVal.hs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,23 @@ printCtorCase pkgs iTyDefs (_, tyn) ctorCont ctor@(ctorN, fields) = do
6060
else return $ group $ ctorNameDoc <+> encloseSep lparen rparen comma argDocs <+> "=>" <+> group bodyDoc
6161

6262
printCaseE :: MonadPrint m => R.PkgMap -> PC.TyDefs -> LV.QSum -> LV.ValueE -> ((LV.Ctor, [LV.ValueE]) -> LV.ValueE) -> m (Doc ann)
63-
printCaseE pkgs iTyDefs (qtyN@(_, tyN), sumTy) caseVal ctorCont = do
63+
printCaseE pkgs iTyDefs (qtyN@(mn', tyN'), sumTy) caseVal ctorCont = do
64+
let mn = withInfo mn'
65+
tyN = withInfo tyN'
6466
caseValDoc <- printValueE pkgs iTyDefs caseVal
6567
(tyArgs, fieldTys) <-
6668
case Map.lookup qtyN iTyDefs of
6769
Just (PC.TyDef _ (PC.TyAbs tyArgs (PC.SumI (PC.Sum ctors _)) _) _) -> do
6870
return (toList tyArgs, TD.sumCtorTys ctors)
6971
_ -> throwInternalError "Expected a SumE but got something else (TODO(szg251): Print got)"
7072

71-
let phantomFields = TD.collectPhantomTyArgs fieldTys tyArgs
73+
let phantomFields = TD.collectPhantomTyArgs iTyDefs mn tyN fieldTys tyArgs
7274
phantomCaseDoc =
7375
if null phantomFields
7476
then mempty
7577
else
7678
let phantomCtor =
77-
R.printTyName (withInfo tyN)
79+
R.printTyName tyN
7880
<> R.doubleColon
7981
<> TD.phantomDataCtorIdent
8082
<> encloseSep lparen rparen comma ("_" <$ phantomFields)
@@ -250,15 +252,15 @@ printRecordE pkgs iTyDefs (qtyN@(mn', tyN'), _) vals = do
250252
Just (PC.TyDef _ (PC.TyAbs tyArgs (PC.RecordI (PC.Record fields _)) _) _) -> return (toList tyArgs, TD.recFieldTys fields)
251253
_ -> throwInternalError "Expected a RecordE but got something else (TODO(szg251): Print got)"
252254

253-
let phantomFields = TD.collectPhantomTyArgs fieldTys tyArgs
255+
let phantomFields = TD.collectPhantomTyArgs iTyDefs mn tyN fieldTys tyArgs
254256
phantomFieldDocs =
255257
if null phantomFields
256258
then mempty
257259
else printPhantomDataField <$> phantomFields
258260
mayBoxedFields = zip (sortOn fst vals) $ TD.isRecursive iTyDefs mn tyN <$> fieldTys
259261

260-
fieldDocs <- for mayBoxedFields
261-
$ \(((fieldN, _), val), isBoxed) ->
262+
fieldDocs <- for mayBoxedFields $
263+
\(((fieldN, _), val), isBoxed) ->
262264
let fieldNDoc = R.printFieldName (withInfo fieldN)
263265
in do
264266
valDoc <- printMaybeBoxed pkgs iTyDefs (val, isBoxed)
@@ -280,7 +282,7 @@ printProductE pkgs iTyDefs (qtyN@(mn', tyN'), _) vals = do
280282
case Map.lookup qtyN iTyDefs of
281283
Just (PC.TyDef _ (PC.TyAbs tyArgs (PC.ProductI (PC.Product fields _)) _) _) -> return (toList tyArgs, fields)
282284
_ -> throwInternalError "Expected a ProductE but got something else (TODO(szg251): Print got)"
283-
let phantomFieldDocs = R.printRsQTyName RR.phantomData <$ TD.collectPhantomTyArgs fieldTys tyArgs
285+
let phantomFieldDocs = R.printRsQTyName RR.phantomData <$ TD.collectPhantomTyArgs iTyDefs mn tyN fieldTys tyArgs
284286
mayBoxedFields = zip vals $ TD.isRecursive iTyDefs mn tyN <$> fieldTys
285287

286288
fieldDocs <- for mayBoxedFields (printMaybeBoxed pkgs iTyDefs)
@@ -339,11 +341,11 @@ printRefE pkgs ref = do
339341
| builtin
340342
== "toPlutusData"
341343
|| builtin
342-
== "fromPlutusData"
344+
== "fromPlutusData"
343345
|| builtin
344-
== "toJson"
346+
== "toJson"
345347
|| builtin
346-
== "fromJson" -> do
348+
== "fromJson" -> do
347349
lamTyDoc <- printLamTy pkgs argTy
348350
methodDoc <- R.printRsValName . R.qualifiedEntity <$> LV.importValue qvn
349351
return $ angles lamTyDoc <> R.doubleColon <> methodDoc

lambda-buffers-codegen/src/LambdaBuffers/Codegen/Rust/Print/TyDef.hs

Lines changed: 74 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ module LambdaBuffers.Codegen.Rust.Print.TyDef (printTyDef, printTyTopLevel, prin
33
import Control.Lens (view)
44
import Control.Monad.Reader.Class (asks)
55
import Data.Foldable (Foldable (toList))
6+
import Data.Map (Map)
67
import Data.Map qualified as Map
78
import Data.Map.Ordered (OMap)
89
import Data.Map.Ordered qualified as OMap
10+
import Data.Maybe (fromMaybe)
911
import Data.Set (Set)
1012
import Data.Set qualified as Set
11-
import Data.Text (Text)
1213
import Data.Traversable (for)
1314
import LambdaBuffers.Codegen.Config (cfgOpaques)
1415
import LambdaBuffers.Codegen.Print (throwInternalError)
@@ -115,7 +116,10 @@ printTyBody _ tyN args (PC.OpaqueI si) = do
115116

116117
printSum :: MonadPrint m => R.PkgMap -> PC.TyName -> [PC.TyArg] -> PC.Sum -> m (Doc ann)
117118
printSum pkgs parentTyN tyArgs (PC.Sum ctors _) = do
118-
let phantomTyArgs = collectPhantomTyArgs (sumCtorTys ctors) tyArgs
119+
ci <- asks (view Print.ctxCompilerInput)
120+
let iTyDefs = indexTyDefs ci
121+
mn <- asks (view $ Print.ctxModule . #moduleName)
122+
let phantomTyArgs = collectPhantomTyArgs iTyDefs mn parentTyN (sumCtorTys ctors) tyArgs
119123
phantomCtor = if null phantomTyArgs then mempty else [printPhantomDataCtor phantomTyArgs]
120124
ctorDocs <- traverse (printCtor pkgs parentTyN) (toList ctors)
121125
if null ctors
@@ -139,7 +143,10 @@ printCtorInner pkgs parentTyN (PC.Product fields _) = do
139143

140144
printRec :: MonadPrint m => R.PkgMap -> PC.TyName -> [PC.TyArg] -> PC.Record -> m (Doc ann)
141145
printRec pkgs parentTyN tyArgs (PC.Record fields _) = do
142-
let phantomTyArgs = collectPhantomTyArgs (recFieldTys fields) tyArgs
146+
ci <- asks (view Print.ctxCompilerInput)
147+
let iTyDefs = indexTyDefs ci
148+
mn <- asks (view $ Print.ctxModule . #moduleName)
149+
let phantomTyArgs = collectPhantomTyArgs iTyDefs mn parentTyN (recFieldTys fields) tyArgs
143150
phantomFields = printPhantomDataField <$> phantomTyArgs
144151
if null fields && null phantomTyArgs
145152
then return semi
@@ -149,21 +156,55 @@ printRec pkgs parentTyN tyArgs (PC.Record fields _) = do
149156

150157
printProd :: MonadPrint m => R.PkgMap -> PC.TyName -> [PC.TyArg] -> PC.Product -> m (Doc ann)
151158
printProd pkgs parentTyN tyArgs (PC.Product fields _) = do
159+
ci <- asks (view Print.ctxCompilerInput)
160+
let iTyDefs = indexTyDefs ci
161+
mn <- asks (view $ Print.ctxModule . #moduleName)
152162
tyDocs <- for fields (printTyTopLevel pkgs parentTyN)
153-
let phantomTyArgs = collectPhantomTyArgs fields tyArgs
163+
let phantomTyArgs = collectPhantomTyArgs iTyDefs mn parentTyN fields tyArgs
154164
phantomFields = printPhantomData <$> phantomTyArgs
155165
if null fields && null phantomTyArgs
156166
then return semi
157167
else return $ encloseSep lparen rparen comma (tyDocs <> phantomFields) <> semi
158168

159-
-- | Filter out unused type arguments in order to make PhantomData fields for them
160-
collectPhantomTyArgs :: [PC.Ty] -> [PC.TyArg] -> [PC.TyArg]
161-
collectPhantomTyArgs tys tyArgs = foldr go tyArgs tys
169+
{- | Filter out unused type arguments in order to make PhantomData fields for them
170+
This is done in a recursive manner: if we encounter a type application, we resolve the type from TyDefs, and substitute
171+
all type variable with the arguments from the parent types type abstraction. We're also keeping track of all
172+
the type names already seen to avoid infinite recursions.
173+
-}
174+
collectPhantomTyArgs :: PC.TyDefs -> PC.ModuleName -> PC.TyName -> [PC.Ty] -> [PC.TyArg] -> [PC.TyArg]
175+
collectPhantomTyArgs iTyDefs ownMn parentTyN tys tyArgs = foldr (go (Set.singleton (PC.mkInfoLess parentTyN))) tyArgs tys
162176
where
163-
go :: PC.Ty -> [PC.TyArg] -> [PC.TyArg]
164-
go (PC.TyVarI (PC.TyVar (PC.VarName varName _))) tyArgs' = filter (\(PC.TyArg (PC.VarName varName' _) _ _) -> varName /= varName') tyArgs'
165-
go (PC.TyAppI (PC.TyApp _ tys' _)) tyArgs' = foldr go tyArgs' tys'
166-
go (PC.TyRefI _) tyArgs' = tyArgs'
177+
go :: Set (PC.InfoLess PC.TyName) -> PC.Ty -> [PC.TyArg] -> [PC.TyArg]
178+
go _ (PC.TyVarI (PC.TyVar (PC.VarName varName _))) tyArgs' = filter (\(PC.TyArg (PC.VarName varName' _) _ _) -> varName /= varName') tyArgs'
179+
go seenTys (PC.TyAppI (PC.TyApp tyFunc tys' _)) tyArgs' =
180+
case tyFunc of
181+
PC.TyRefI ref ->
182+
let qtyN@(_, tyN) =
183+
case ref of
184+
PC.LocalI (PC.LocalRef tyN' _) -> (mkInfoLess ownMn, mkInfoLess tyN')
185+
PC.ForeignI (PC.ForeignRef tyN' mn _) -> (mkInfoLess mn, mkInfoLess tyN')
186+
187+
resolvedChildrenTys =
188+
case Map.lookup qtyN iTyDefs of
189+
Nothing -> [] -- TODO(szg251): Gracefully failing, but this should be an error instead
190+
Just (PC.TyDef _ (PC.TyAbs omap tyBody _) _) ->
191+
let tyAbsArgs = fst <$> OMap.assocs omap
192+
resolvedArgs = Map.fromList $ zip tyAbsArgs tys'
193+
tyBodyTys = case tyBody of
194+
PC.OpaqueI _ -> []
195+
PC.SumI (PC.Sum ctors _) -> sumCtorTys ctors
196+
PC.ProductI (PC.Product fields _) -> fields
197+
PC.RecordI (PC.Record fields _) -> recFieldTys fields
198+
in resolveTyVar resolvedArgs <$> tyBodyTys
199+
in if Set.member tyN seenTys
200+
then tyArgs'
201+
else foldr (go (Set.insert tyN seenTys)) tyArgs' resolvedChildrenTys
202+
_ -> tyArgs'
203+
go _ (PC.TyRefI _) tyArgs' = tyArgs'
204+
205+
resolveTyVar :: Map (PC.InfoLess PC.VarName) PC.Ty -> PC.Ty -> PC.Ty
206+
resolveTyVar resolvedArgs ty@(PC.TyVarI (PC.TyVar varName)) = fromMaybe ty $ Map.lookup (PC.mkInfoLess varName) resolvedArgs -- TODO(szg251): Should this be an error too? Guess so..
207+
resolveTyVar _ ty = ty
167208

168209
-- | Returns Ty information of all record fields, sorted by field name
169210
recFieldTys :: OMap (PC.InfoLess PC.FieldName) PC.Field -> [PC.Ty]
@@ -241,30 +282,30 @@ printTyInner pkgs (PC.TyAppI a) = printTyApp pkgs a
241282
This is done by resolving references, and searching for reoccurances of the parent type name
242283
-}
243284
isRecursive :: PC.TyDefs -> PC.ModuleName -> PC.TyName -> PC.Ty -> Bool
244-
isRecursive iTyDefs mn (PC.TyName parentTyNameT _) = go mempty
285+
isRecursive iTyDefs ownMn parentTyName = go mempty
245286
where
246-
go :: Set Text -> PC.Ty -> Bool
287+
go :: Set (PC.InfoLess PC.TyName) -> PC.Ty -> Bool
247288
go _ (PC.TyVarI _) = False
248-
go otherTys (PC.TyAppI (PC.TyApp _ tyArgs _)) = any (go otherTys) tyArgs
289+
go otherTys (PC.TyAppI (PC.TyApp tyFunc tyArgs _)) = any (go otherTys) $ tyFunc : tyArgs
249290
go otherTys (PC.TyRefI ref) = do
250-
let (qtyN, tyN) =
291+
let qtyN@(_, tyN) =
251292
case ref of
252-
PC.LocalI (PC.LocalRef tyN' _) ->
253-
let (PC.TyName tyNameT _) = tyN'
254-
in ((mkInfoLess mn, mkInfoLess tyN'), tyNameT)
255-
PC.ForeignI (PC.ForeignRef tyN' mn' _) ->
256-
let (PC.TyName tyNameT _) = tyN'
257-
in ((mkInfoLess mn', mkInfoLess tyN'), tyNameT)
258-
259-
let childrenTys =
260-
case Map.lookup qtyN iTyDefs of
261-
Nothing -> [] -- TODO(szg251): Gracefully failing, but this should be an error instead
262-
Just (PC.TyDef _ (PC.TyAbs _ tyBody _) _) ->
263-
case tyBody of
264-
PC.OpaqueI _ -> []
265-
PC.SumI (PC.Sum ctors _) -> sumCtorTys ctors
266-
PC.ProductI (PC.Product fields _) -> fields
267-
PC.RecordI (PC.Record fields _) -> recFieldTys fields
268-
269-
(parentTyNameT == tyN)
293+
PC.LocalI (PC.LocalRef tyN' _) -> (mkInfoLess ownMn, mkInfoLess tyN')
294+
PC.ForeignI (PC.ForeignRef tyN' mn _) -> (mkInfoLess mn, mkInfoLess tyN')
295+
296+
let childrenTys = findChildren iTyDefs qtyN
297+
298+
(PC.mkInfoLess parentTyName == tyN)
270299
|| (not (Set.member tyN otherTys) && any (go (Set.insert tyN otherTys)) childrenTys)
300+
301+
-- | Resolve a qualified type name and return all it's children types
302+
findChildren :: PC.TyDefs -> PC.QTyName -> [PC.Ty]
303+
findChildren iTyDefs qtyN =
304+
case Map.lookup qtyN iTyDefs of
305+
Nothing -> [] -- TODO(szg251): Gracefully failing, but this should be an error instead
306+
Just (PC.TyDef _ (PC.TyAbs _ tyBody _) _) ->
307+
case tyBody of
308+
PC.OpaqueI _ -> []
309+
PC.SumI (PC.Sum ctors _) -> sumCtorTys ctors
310+
PC.ProductI (PC.Product fields _) -> fields
311+
PC.RecordI (PC.Record fields _) -> recFieldTys fields

0 commit comments

Comments
 (0)