Skip to content

Commit 3895701

Browse files
authored
Merge pull request #29 from mlabs-haskell/bladyjoker/compiler-typeclass-check
Compiler: TypeClass checks
2 parents 2e2a213 + c90bb2b commit 3895701

File tree

4 files changed

+151
-5
lines changed

4 files changed

+151
-5
lines changed

lambda-buffers-compiler/lambda-buffers-compiler.cabal

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ library
8383
import: common-language
8484
build-depends:
8585
, base >=4.16
86+
, containers >=0.6
8687
, freer-simple >=1.2
8788
, generic-lens >=2.2
8889
, lambda-buffers-compiler-pb >=0.1.0.0
@@ -99,6 +100,7 @@ library
99100
LambdaBuffers.Compiler.NamingCheck
100101
LambdaBuffers.Compiler.ProtoCompat
101102
LambdaBuffers.Compiler.ProtoCompat.Types
103+
LambdaBuffers.Compiler.TypeClassCheck
102104

103105
hs-source-dirs: src
104106

@@ -124,9 +126,15 @@ test-suite tests
124126
hs-source-dirs: test
125127
main-is: Test.hs
126128
build-depends:
127-
, base >=4.16
129+
, base >=4.16
128130
, lambda-buffers-compiler
129-
, tasty >=1.4
130-
, tasty-hunit >=0.10
131+
, lambda-buffers-compiler-pb >=0.1
132+
, lens >=5.2
133+
, proto-lens >=0.7
134+
, tasty >=1.4
135+
, tasty-hunit >=0.10
136+
, text >=1.2
131137

132-
other-modules: Test.KindCheck
138+
other-modules:
139+
Test.KindCheck
140+
Test.TypeClassCheck
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
{-# LANGUAGE OverloadedLabels #-}
2+
{-# LANGUAGE OverloadedStrings #-}
3+
4+
module LambdaBuffers.Compiler.TypeClassCheck (detectSuperclassCycles, detectSuperclassCycles') where
5+
6+
import Control.Lens.Operators ((^.))
7+
import Data.Generics.Labels ()
8+
import Data.List (foldl')
9+
import Data.Map qualified as M
10+
import Data.Text (Text)
11+
import LambdaBuffers.Compiler.ProtoCompat.Types (
12+
ClassDef (),
13+
)
14+
import Prettyprinter (
15+
Doc,
16+
Pretty (pretty),
17+
hcat,
18+
indent,
19+
line,
20+
punctuate,
21+
vcat,
22+
)
23+
24+
data ClassInfo = ClassInfo {ciName :: Text, ciSupers :: [Text]}
25+
deriving stock (Show, Eq, Ord)
26+
27+
detectSuperclassCycles' :: [ClassDef] -> [[Text]]
28+
detectSuperclassCycles' = detectCycles . mkClassGraph . map defToClassInfo
29+
where
30+
defToClassInfo :: ClassDef -> ClassInfo
31+
defToClassInfo cd =
32+
ClassInfo (cd ^. #className . #name) $
33+
map (\x -> x ^. #className . #name) (cd ^. #supers)
34+
35+
mkClassGraph :: [ClassInfo] -> M.Map Text [Text]
36+
mkClassGraph = foldl' (\acc (ClassInfo nm sups) -> M.insert nm sups acc) M.empty
37+
38+
detectCycles :: forall k. Ord k => M.Map k [k] -> [[k]]
39+
detectCycles m = concatMap (detect []) (M.keys m)
40+
where
41+
detect :: [k] -> k -> [[k]]
42+
detect visited x = case M.lookup x m of
43+
Nothing -> []
44+
Just xs ->
45+
if x `elem` visited
46+
then [x : visited]
47+
else concatMap (detect (x : visited)) xs
48+
49+
detectSuperclassCycles :: forall a. [ClassDef] -> Maybe (Doc a)
50+
detectSuperclassCycles cds = case detectSuperclassCycles' cds of
51+
[] -> Nothing
52+
xs ->
53+
Just $
54+
"Error: Superclass cycle(s) detected"
55+
<> line
56+
<> indent 2 (vcat $ map format xs)
57+
where
58+
format :: [Text] -> Doc a
59+
format = hcat . punctuate " => " . map pretty

lambda-buffers-compiler/test/Test.hs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@ module Main (main) where
22

33
import Test.KindCheck qualified as KC
44
import Test.Tasty (defaultMain, testGroup)
5+
import Test.TypeClassCheck qualified as TC
56

67
main :: IO ()
7-
main = defaultMain $ testGroup "Compiler tests" [KC.test]
8+
main =
9+
defaultMain $
10+
testGroup
11+
"Compiler tests"
12+
[ KC.test
13+
, TC.test
14+
]
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
module Test.TypeClassCheck (test) where
2+
3+
import Control.Lens ((.~))
4+
import Data.Function ((&))
5+
import Data.ProtoLens (Message (defMessage))
6+
import Data.Text (Text)
7+
import Data.Traversable (for)
8+
import LambdaBuffers.Compiler.ProtoCompat (IsMessage (fromProto))
9+
import LambdaBuffers.Compiler.ProtoCompat.Types qualified as ProtoCompat
10+
import LambdaBuffers.Compiler.TypeClassCheck (detectSuperclassCycles')
11+
import Proto.Compiler (ClassDef, Constraint, Kind, Kind'KindRef (Kind'KIND_REF_TYPE))
12+
import Proto.Compiler_Fields (argKind, argName, arguments, classArgs, className, kindRef, name, supers, tyVar, varName)
13+
import Test.Tasty (TestTree, testGroup)
14+
import Test.Tasty.HUnit (assertFailure, testCase, (@?=))
15+
16+
test :: TestTree
17+
test =
18+
testGroup
19+
"TypeClassCheck tests"
20+
[ noCycleDetected
21+
, cycleDetected
22+
]
23+
24+
noCycleDetected :: TestTree
25+
noCycleDetected =
26+
testCase "No cycle detected" $ do
27+
nocycles' <- classDefsFromProto nocycles
28+
detectSuperclassCycles' nocycles' @?= []
29+
30+
cycleDetected :: TestTree
31+
cycleDetected =
32+
testCase "Cycle detected" $ do
33+
cycles' <- classDefsFromProto cycles
34+
detectSuperclassCycles' cycles' @?= [["Bar", "Foo", "Bop", "Bar"], ["Bop", "Bar", "Foo", "Bop"], ["Foo", "Bop", "Bar", "Foo"]]
35+
36+
classDefsFromProto :: [ClassDef] -> IO [ProtoCompat.ClassDef]
37+
classDefsFromProto cds = for cds (either (\err -> assertFailure $ "FromProto failed with " <> show err) return . fromProto @ClassDef @ProtoCompat.ClassDef)
38+
39+
star :: Kind
40+
star = defMessage & kindRef .~ Kind'KIND_REF_TYPE
41+
42+
mkclass :: Text -> [Text] -> ClassDef
43+
mkclass nm sups =
44+
defMessage
45+
& className . name .~ nm
46+
& classArgs
47+
.~ [ defMessage
48+
& argName . name .~ "a"
49+
& argKind .~ star
50+
]
51+
& supers .~ map constraint sups
52+
53+
constraint :: Text -> Constraint
54+
constraint nm =
55+
defMessage
56+
& className . name .~ nm
57+
& arguments .~ [defMessage & tyVar . varName . name .~ "a"]
58+
59+
cycles :: [ClassDef]
60+
cycles =
61+
[ mkclass "Foo" ["Bar", "Baz", "Beep"]
62+
, mkclass "Bar" ["Bip", "Bop"]
63+
, mkclass "Bop" ["Foo"]
64+
]
65+
66+
nocycles :: [ClassDef]
67+
nocycles =
68+
[ mkclass "Functor" []
69+
, mkclass "Applicative" ["Functor"]
70+
, mkclass "Monad" ["Applicative"]
71+
, mkclass "Traversable" ["Foldable", "Functor"]
72+
]

0 commit comments

Comments
 (0)