@@ -3,12 +3,13 @@ module LambdaBuffers.Codegen.Rust.Print.TyDef (printTyDef, printTyTopLevel, prin
33import Control.Lens (view )
44import Control.Monad.Reader.Class (asks )
55import Data.Foldable (Foldable (toList ))
6+ import Data.Map (Map )
67import Data.Map qualified as Map
78import Data.Map.Ordered (OMap )
89import Data.Map.Ordered qualified as OMap
10+ import Data.Maybe (fromMaybe )
911import Data.Set (Set )
1012import Data.Set qualified as Set
11- import Data.Text (Text )
1213import Data.Traversable (for )
1314import LambdaBuffers.Codegen.Config (cfgOpaques )
1415import LambdaBuffers.Codegen.Print (throwInternalError )
@@ -115,7 +116,10 @@ printTyBody _ tyN args (PC.OpaqueI si) = do
115116
116117printSum :: MonadPrint m => R. PkgMap -> PC. TyName -> [PC. TyArg ] -> PC. Sum -> m (Doc ann )
117118printSum pkgs parentTyN tyArgs (PC. Sum ctors _) = do
118- let phantomTyArgs = collectPhantomTyArgs (sumCtorTys ctors) tyArgs
119+ ci <- asks (view Print. ctxCompilerInput)
120+ let iTyDefs = indexTyDefs ci
121+ mn <- asks (view $ Print. ctxModule . # moduleName)
122+ let phantomTyArgs = collectPhantomTyArgs iTyDefs mn parentTyN (sumCtorTys ctors) tyArgs
119123 phantomCtor = if null phantomTyArgs then mempty else [printPhantomDataCtor phantomTyArgs]
120124 ctorDocs <- traverse (printCtor pkgs parentTyN) (toList ctors)
121125 if null ctors
@@ -139,7 +143,10 @@ printCtorInner pkgs parentTyN (PC.Product fields _) = do
139143
140144printRec :: MonadPrint m => R. PkgMap -> PC. TyName -> [PC. TyArg ] -> PC. Record -> m (Doc ann )
141145printRec pkgs parentTyN tyArgs (PC. Record fields _) = do
142- let phantomTyArgs = collectPhantomTyArgs (recFieldTys fields) tyArgs
146+ ci <- asks (view Print. ctxCompilerInput)
147+ let iTyDefs = indexTyDefs ci
148+ mn <- asks (view $ Print. ctxModule . # moduleName)
149+ let phantomTyArgs = collectPhantomTyArgs iTyDefs mn parentTyN (recFieldTys fields) tyArgs
143150 phantomFields = printPhantomDataField <$> phantomTyArgs
144151 if null fields && null phantomTyArgs
145152 then return semi
@@ -149,21 +156,55 @@ printRec pkgs parentTyN tyArgs (PC.Record fields _) = do
149156
150157printProd :: MonadPrint m => R. PkgMap -> PC. TyName -> [PC. TyArg ] -> PC. Product -> m (Doc ann )
151158printProd pkgs parentTyN tyArgs (PC. Product fields _) = do
159+ ci <- asks (view Print. ctxCompilerInput)
160+ let iTyDefs = indexTyDefs ci
161+ mn <- asks (view $ Print. ctxModule . # moduleName)
152162 tyDocs <- for fields (printTyTopLevel pkgs parentTyN)
153- let phantomTyArgs = collectPhantomTyArgs fields tyArgs
163+ let phantomTyArgs = collectPhantomTyArgs iTyDefs mn parentTyN fields tyArgs
154164 phantomFields = printPhantomData <$> phantomTyArgs
155165 if null fields && null phantomTyArgs
156166 then return semi
157167 else return $ encloseSep lparen rparen comma (tyDocs <> phantomFields) <> semi
158168
159- -- | Filter out unused type arguments in order to make PhantomData fields for them
160- collectPhantomTyArgs :: [PC. Ty ] -> [PC. TyArg ] -> [PC. TyArg ]
161- collectPhantomTyArgs tys tyArgs = foldr go tyArgs tys
169+ {- | Filter out unused type arguments in order to make PhantomData fields for them
170+ This is done in a recursive manner: if we encounter a type application, we resolve the type from TyDefs, and substitute
171+ all type variable with the arguments from the parent types type abstraction. We're also keeping track of all
172+ the type names already seen to avoid infinite recursions.
173+ -}
174+ collectPhantomTyArgs :: PC. TyDefs -> PC. ModuleName -> PC. TyName -> [PC. Ty ] -> [PC. TyArg ] -> [PC. TyArg ]
175+ collectPhantomTyArgs iTyDefs ownMn parentTyN tys tyArgs = foldr (go (Set. singleton (PC. mkInfoLess parentTyN))) tyArgs tys
162176 where
163- go :: PC. Ty -> [PC. TyArg ] -> [PC. TyArg ]
164- go (PC. TyVarI (PC. TyVar (PC. VarName varName _))) tyArgs' = filter (\ (PC. TyArg (PC. VarName varName' _) _ _) -> varName /= varName') tyArgs'
165- go (PC. TyAppI (PC. TyApp _ tys' _)) tyArgs' = foldr go tyArgs' tys'
166- go (PC. TyRefI _) tyArgs' = tyArgs'
177+ go :: Set (PC. InfoLess PC. TyName ) -> PC. Ty -> [PC. TyArg ] -> [PC. TyArg ]
178+ go _ (PC. TyVarI (PC. TyVar (PC. VarName varName _))) tyArgs' = filter (\ (PC. TyArg (PC. VarName varName' _) _ _) -> varName /= varName') tyArgs'
179+ go seenTys (PC. TyAppI (PC. TyApp tyFunc tys' _)) tyArgs' =
180+ case tyFunc of
181+ PC. TyRefI ref ->
182+ let qtyN@ (_, tyN) =
183+ case ref of
184+ PC. LocalI (PC. LocalRef tyN' _) -> (mkInfoLess ownMn, mkInfoLess tyN')
185+ PC. ForeignI (PC. ForeignRef tyN' mn _) -> (mkInfoLess mn, mkInfoLess tyN')
186+
187+ resolvedChildrenTys =
188+ case Map. lookup qtyN iTyDefs of
189+ Nothing -> [] -- TODO(szg251): Gracefully failing, but this should be an error instead
190+ Just (PC. TyDef _ (PC. TyAbs omap tyBody _) _) ->
191+ let tyAbsArgs = fst <$> OMap. assocs omap
192+ resolvedArgs = Map. fromList $ zip tyAbsArgs tys'
193+ tyBodyTys = case tyBody of
194+ PC. OpaqueI _ -> []
195+ PC. SumI (PC. Sum ctors _) -> sumCtorTys ctors
196+ PC. ProductI (PC. Product fields _) -> fields
197+ PC. RecordI (PC. Record fields _) -> recFieldTys fields
198+ in resolveTyVar resolvedArgs <$> tyBodyTys
199+ in if Set. member tyN seenTys
200+ then tyArgs'
201+ else foldr (go (Set. insert tyN seenTys)) tyArgs' resolvedChildrenTys
202+ _ -> tyArgs'
203+ go _ (PC. TyRefI _) tyArgs' = tyArgs'
204+
205+ resolveTyVar :: Map (PC. InfoLess PC. VarName ) PC. Ty -> PC. Ty -> PC. Ty
206+ resolveTyVar resolvedArgs ty@ (PC. TyVarI (PC. TyVar varName)) = fromMaybe ty $ Map. lookup (PC. mkInfoLess varName) resolvedArgs -- TODO(szg251): Should this be an error too? Guess so..
207+ resolveTyVar _ ty = ty
167208
168209-- | Returns Ty information of all record fields, sorted by field name
169210recFieldTys :: OMap (PC. InfoLess PC. FieldName ) PC. Field -> [PC. Ty ]
@@ -241,30 +282,30 @@ printTyInner pkgs (PC.TyAppI a) = printTyApp pkgs a
241282 This is done by resolving references, and searching for reoccurances of the parent type name
242283-}
243284isRecursive :: PC. TyDefs -> PC. ModuleName -> PC. TyName -> PC. Ty -> Bool
244- isRecursive iTyDefs mn ( PC. TyName parentTyNameT _) = go mempty
285+ isRecursive iTyDefs ownMn parentTyName = go mempty
245286 where
246- go :: Set Text -> PC. Ty -> Bool
287+ go :: Set ( PC. InfoLess PC. TyName ) -> PC. Ty -> Bool
247288 go _ (PC. TyVarI _) = False
248- go otherTys (PC. TyAppI (PC. TyApp _ tyArgs _)) = any (go otherTys) tyArgs
289+ go otherTys (PC. TyAppI (PC. TyApp tyFunc tyArgs _)) = any (go otherTys) $ tyFunc : tyArgs
249290 go otherTys (PC. TyRefI ref) = do
250- let ( qtyN, tyN) =
291+ let qtyN@ (_ , tyN) =
251292 case ref of
252- PC. LocalI (PC. LocalRef tyN' _) ->
253- let (PC. TyName tyNameT _) = tyN'
254- in ((mkInfoLess mn, mkInfoLess tyN'), tyNameT)
255- PC. ForeignI (PC. ForeignRef tyN' mn' _) ->
256- let (PC. TyName tyNameT _) = tyN'
257- in ((mkInfoLess mn', mkInfoLess tyN'), tyNameT)
258-
259- let childrenTys =
260- case Map. lookup qtyN iTyDefs of
261- Nothing -> [] -- TODO(szg251): Gracefully failing, but this should be an error instead
262- Just (PC. TyDef _ (PC. TyAbs _ tyBody _) _) ->
263- case tyBody of
264- PC. OpaqueI _ -> []
265- PC. SumI (PC. Sum ctors _) -> sumCtorTys ctors
266- PC. ProductI (PC. Product fields _) -> fields
267- PC. RecordI (PC. Record fields _) -> recFieldTys fields
268-
269- (parentTyNameT == tyN)
293+ PC. LocalI (PC. LocalRef tyN' _) -> (mkInfoLess ownMn, mkInfoLess tyN')
294+ PC. ForeignI (PC. ForeignRef tyN' mn _) -> (mkInfoLess mn, mkInfoLess tyN')
295+
296+ let childrenTys = findChildren iTyDefs qtyN
297+
298+ (PC. mkInfoLess parentTyName == tyN)
270299 || (not (Set. member tyN otherTys) && any (go (Set. insert tyN otherTys)) childrenTys)
300+
301+ -- | Resolve a qualified type name and return all it's children types
302+ findChildren :: PC. TyDefs -> PC. QTyName -> [PC. Ty ]
303+ findChildren iTyDefs qtyN =
304+ case Map. lookup qtyN iTyDefs of
305+ Nothing -> [] -- TODO(szg251): Gracefully failing, but this should be an error instead
306+ Just (PC. TyDef _ (PC. TyAbs _ tyBody _) _) ->
307+ case tyBody of
308+ PC. OpaqueI _ -> []
309+ PC. SumI (PC. Sum ctors _) -> sumCtorTys ctors
310+ PC. ProductI (PC. Product fields _) -> fields
311+ PC. RecordI (PC. Record fields _) -> recFieldTys fields
0 commit comments