11{-# LANGUAGE OverloadedLabels #-}
22{-# LANGUAGE OverloadedStrings #-}
33
4- module LambdaBuffers.Compiler.TypeClassCheck (detectSuperclassCycles , detectSuperclassCycles' ) where
4+ module LambdaBuffers.Compiler.TypeClassCheck (detectSuperclassCycles , detectSuperclassCycles' , runDeriveCheck , validateTypeClasses ) where
55
66import Control.Lens.Combinators (view )
77import Control.Lens.Operators ((^.) )
8+ import Control.Monad (void )
89import Data.Generics.Labels ()
910import Data.List (foldl' )
11+ import Data.Map (traverseWithKey )
1012import Data.Map qualified as M
13+ import Data.Set qualified as S
1114import Data.Text (Text )
15+ import LambdaBuffers.Compiler.ProtoCompat qualified as P
1216import LambdaBuffers.Compiler.ProtoCompat.Types (
1317 ClassDef (),
1418 ForeignClassRef (ForeignClassRef ),
1519 LocalClassRef (LocalClassRef ),
1620 TyClassRef (ForeignCI , LocalCI ),
1721 )
22+ import LambdaBuffers.Compiler.TypeClassCheck.Pretty (spaced , (<//>) )
23+ import LambdaBuffers.Compiler.TypeClassCheck.Utils (
24+ Instance ,
25+ ModuleBuilder (mbInstances ),
26+ TypeClassError (FailedToSolveConstraints ),
27+ checkInstance ,
28+ mkBuilders ,
29+ )
30+ import LambdaBuffers.Compiler.TypeClassCheck.Validate (checkDerive )
1831import Prettyprinter (
1932 Doc ,
2033 Pretty (pretty ),
@@ -23,6 +36,7 @@ import Prettyprinter (
2336 line ,
2437 punctuate ,
2538 vcat ,
39+ (<+>) ,
2640 )
2741
2842data ClassInfo = ClassInfo { ciName :: Text , ciSupers :: [Text ]}
@@ -65,3 +79,39 @@ detectSuperclassCycles cds = case detectSuperclassCycles' cds of
6579 where
6680 format :: [Text ] -> Doc a
6781 format = hcat . punctuate " => " . map pretty
82+
83+ runDeriveCheck :: P. ModuleName -> ModuleBuilder -> Either TypeClassError ()
84+ runDeriveCheck mn mb = mconcat <$> traverse go (S. toList $ mbInstances mb)
85+ where
86+ go :: Instance -> Either TypeClassError ()
87+ go i =
88+ checkInstance i
89+ >> checkDerive mn mb i
90+ >>= \ case
91+ [] -> pure ()
92+ xs -> Left $ FailedToSolveConstraints mn xs i
93+
94+ -- ModuleBuilder is suitable codegen input,
95+ -- and is (relatively) computationally expensive to
96+ -- construct, so we return it here if successful.
97+ validateTypeClasses' :: P. CompilerInput -> Either TypeClassError (M. Map P. ModuleName ModuleBuilder )
98+ validateTypeClasses' ci = do
99+ -- detectSuperclassCycles ci
100+ moduleBuilders <- mkBuilders ci
101+ void $ traverseWithKey runDeriveCheck moduleBuilders
102+ pure moduleBuilders
103+
104+ -- maybe use Control.Exception? Tho if we're not gonna catch it i guess this is fine
105+ validateTypeClasses :: P. CompilerInput -> IO (M. Map P. ModuleName ModuleBuilder )
106+ validateTypeClasses ci = case validateTypeClasses' ci of
107+ Left err -> print (spaced $ pretty err) >> error " \n Compilation aborted due to TypeClass Error"
108+ Right mbs -> print (prettyBuilders mbs) >> pure mbs
109+
110+ prettyBuilders :: forall a . M. Map P. ModuleName ModuleBuilder -> Doc a
111+ prettyBuilders = spaced . vcat . punctuate line . map (uncurry go) . M. toList
112+ where
113+ go :: P. ModuleName -> ModuleBuilder -> Doc a
114+ go mn mb =
115+ " MODULE"
116+ <+> pretty mn
117+ <//> indent 2 (pretty mb)
0 commit comments