Skip to content

Commit e6824cb

Browse files
committed
Added @gnumonik's PR #23 into this fresh one and scaffolded tests
1 parent 2e2a213 commit e6824cb

File tree

4 files changed

+164
-5
lines changed

4 files changed

+164
-5
lines changed

lambda-buffers-compiler/lambda-buffers-compiler.cabal

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ library
8383
import: common-language
8484
build-depends:
8585
, base >=4.16
86+
, containers >=0.6
8687
, freer-simple >=1.2
8788
, generic-lens >=2.2
8889
, lambda-buffers-compiler-pb >=0.1.0.0
@@ -99,6 +100,7 @@ library
99100
LambdaBuffers.Compiler.NamingCheck
100101
LambdaBuffers.Compiler.ProtoCompat
101102
LambdaBuffers.Compiler.ProtoCompat.Types
103+
LambdaBuffers.Compiler.TypeClassCheck
102104

103105
hs-source-dirs: src
104106

@@ -124,9 +126,15 @@ test-suite tests
124126
hs-source-dirs: test
125127
main-is: Test.hs
126128
build-depends:
127-
, base >=4.16
129+
, base >=4.16
128130
, lambda-buffers-compiler
129-
, tasty >=1.4
130-
, tasty-hunit >=0.10
131+
, lambda-buffers-compiler-pb >=0.1
132+
, lens >=5.2
133+
, proto-lens >=0.7
134+
, tasty >=1.4
135+
, tasty-hunit >=0.10
136+
, text >=1.2
131137

132-
other-modules: Test.KindCheck
138+
other-modules:
139+
Test.KindCheck
140+
Test.TypeClassCheck
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
{-# LANGUAGE OverloadedLabels #-}
2+
{-# LANGUAGE OverloadedStrings #-}
3+
4+
module LambdaBuffers.Compiler.TypeClassCheck (detectSuperclassCycles, detectSuperclassCycles') where
5+
6+
import Control.Lens.Operators ((^.))
7+
import Data.Generics.Labels ()
8+
import Data.List (foldl')
9+
import Data.Map qualified as M
10+
import Data.Text (Text)
11+
import LambdaBuffers.Compiler.ProtoCompat.Types (
12+
ClassDef (ClassDef),
13+
ClassName (ClassName),
14+
Constraint (Constraint),
15+
Kind (Kind),
16+
KindRefType (KType),
17+
KindType (KindRef),
18+
LocalRef (LocalRef),
19+
SourceInfo (SourceInfo),
20+
SourcePosition (SourcePosition),
21+
Ty (TyRefI),
22+
TyArg (TyArg),
23+
TyName (TyName),
24+
TyRef (LocalI),
25+
VarName (VarName),
26+
)
27+
import Prettyprinter (
28+
Doc,
29+
Pretty (pretty),
30+
hcat,
31+
indent,
32+
line,
33+
punctuate,
34+
vcat,
35+
)
36+
37+
data ClassInfo = ClassInfo {ciName :: Text, ciSupers :: [Text]}
38+
deriving stock (Show, Eq, Ord)
39+
40+
detectSuperclassCycles' :: [ClassDef] -> [[Text]]
41+
detectSuperclassCycles' = detectCycles . mkClassGraph . map defToClassInfo
42+
where
43+
defToClassInfo :: ClassDef -> ClassInfo
44+
defToClassInfo cd =
45+
ClassInfo (cd ^. #className . #name) $
46+
map (\x -> x ^. #className . #name) (cd ^. #supers)
47+
48+
mkClassGraph :: [ClassInfo] -> M.Map Text [Text]
49+
mkClassGraph = foldl' (\acc (ClassInfo nm sups) -> M.insert nm sups acc) M.empty
50+
51+
detectCycles :: forall k. Ord k => M.Map k [k] -> [[k]]
52+
detectCycles m = concatMap (detect []) (M.keys m)
53+
where
54+
detect :: [k] -> k -> [[k]]
55+
detect visited x = case M.lookup x m of
56+
Nothing -> []
57+
Just xs ->
58+
if x `elem` visited
59+
then [x : visited]
60+
else concatMap (detect (x : visited)) xs
61+
62+
detectSuperclassCycles :: forall a. [ClassDef] -> Maybe (Doc a)
63+
detectSuperclassCycles cds = case detectSuperclassCycles' cds of
64+
[] -> Nothing
65+
xs ->
66+
Just $
67+
"Error: Superclass cycle(s) detected"
68+
<> line
69+
<> indent 2 (vcat $ map format xs)
70+
where
71+
format :: [Text] -> Doc a
72+
format = hcat . punctuate " => " . map pretty

lambda-buffers-compiler/test/Test.hs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@ module Main (main) where
22

33
import Test.KindCheck qualified as KC
44
import Test.Tasty (defaultMain, testGroup)
5+
import Test.TypeClassCheck qualified as TC
56

67
main :: IO ()
7-
main = defaultMain $ testGroup "Compiler tests" [KC.test]
8+
main =
9+
defaultMain $
10+
testGroup
11+
"Compiler tests"
12+
[ KC.test
13+
, TC.test
14+
]
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
module Test.TypeClassCheck (test) where
2+
3+
import Control.Lens ((.~))
4+
import Data.Function ((&))
5+
import Data.ProtoLens (Message (defMessage))
6+
import Data.Text (Text)
7+
import Data.Traversable (for)
8+
import LambdaBuffers.Compiler.ProtoCompat (IsMessage (fromProto))
9+
import LambdaBuffers.Compiler.ProtoCompat.Types qualified as ProtoCompat
10+
import LambdaBuffers.Compiler.TypeClassCheck (detectSuperclassCycles')
11+
import Proto.Compiler (ClassDef, Constraint, Kind, Kind'KindRef (Kind'KIND_REF_TYPE))
12+
import Proto.Compiler_Fields (argKind, argName, arguments, classArgs, className, kindRef, name, supers, tyVar, varName)
13+
import Test.Tasty (TestTree, testGroup)
14+
import Test.Tasty.HUnit (assertFailure, testCase, (@?=))
15+
16+
test :: TestTree
17+
test =
18+
testGroup
19+
"TypeClassCheck tests"
20+
[ noCycleDetected
21+
, cycleDetected
22+
]
23+
24+
noCycleDetected :: TestTree
25+
noCycleDetected =
26+
testCase "No cycle detected" $ do
27+
nocycles' <- classDefsFromProto nocycles
28+
detectSuperclassCycles' nocycles' @?= []
29+
30+
cycleDetected :: TestTree
31+
cycleDetected =
32+
testCase "Cycle detected" $ do
33+
cycles' <- classDefsFromProto cycles
34+
detectSuperclassCycles' cycles' @?= [["Bar", "Foo", "Bop", "Bar"], ["Bop", "Bar", "Foo", "Bop"], ["Foo", "Bop", "Bar", "Foo"]]
35+
36+
classDefsFromProto :: [ClassDef] -> IO [ProtoCompat.ClassDef]
37+
classDefsFromProto cds = for cds (either (\err -> assertFailure $ "FromProto failed with " <> show err) return . fromProto @ClassDef @ProtoCompat.ClassDef)
38+
39+
star :: Kind
40+
star = defMessage & kindRef .~ Kind'KIND_REF_TYPE
41+
42+
mkclass :: Text -> [Text] -> ClassDef
43+
mkclass nm sups =
44+
defMessage
45+
& className . name .~ nm
46+
& classArgs
47+
.~ [ defMessage
48+
& argName . name .~ "a"
49+
& argKind .~ star
50+
]
51+
& supers .~ map constraint sups
52+
53+
constraint :: Text -> Constraint
54+
constraint nm =
55+
defMessage
56+
& className . name .~ nm
57+
& arguments .~ [defMessage & tyVar . varName . name .~ "a"]
58+
59+
cycles :: [ClassDef]
60+
cycles =
61+
[ mkclass "Foo" ["Bar", "Baz", "Beep"]
62+
, mkclass "Bar" ["Bip", "Bop"]
63+
, mkclass "Bop" ["Foo"]
64+
]
65+
66+
nocycles :: [ClassDef]
67+
nocycles =
68+
[ mkclass "Functor" []
69+
, mkclass "Applicative" ["Functor"]
70+
, mkclass "Monad" ["Applicative"]
71+
, mkclass "Traversable" ["Foldable", "Functor"]
72+
]

0 commit comments

Comments
 (0)