Skip to content

Commit 988ea17

Browse files
committed
update: refactor complete
1 parent 621a263 commit 988ea17

File tree

13 files changed

+348
-499
lines changed

13 files changed

+348
-499
lines changed

lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck.hs

Lines changed: 43 additions & 328 deletions
Large diffs are not rendered by default.

lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck/Derivation.hs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
module LambdaBuffers.Compiler.KindCheck.Derivation (
2-
Derivation (Axiom, Abstraction, Application),
2+
Derivation (Axiom, Abstraction, Application, Implication),
3+
dType,
4+
dTopKind,
35
) where
46

5-
import LambdaBuffers.Compiler.KindCheck.Judgement (Judgement)
7+
import Control.Lens (Lens', lens, (&), (.~), (^.))
8+
import LambdaBuffers.Compiler.KindCheck.Judgement (Judgement, jKind, jType)
9+
import LambdaBuffers.Compiler.KindCheck.Kind (Kind)
10+
import LambdaBuffers.Compiler.KindCheck.Type (Type)
611
import Prettyprinter (
712
Doc,
813
Pretty (pretty),
@@ -18,13 +23,45 @@ data Derivation
1823
= Axiom Judgement
1924
| Abstraction Judgement Derivation
2025
| Application Judgement Derivation Derivation
26+
| Implication Judgement Derivation
2127
deriving stock (Show, Eq)
2228

2329
instance Pretty Derivation where
2430
pretty x = case x of
2531
Axiom j -> hang 2 $ pretty j
2632
Abstraction j d -> dNest j [d]
2733
Application j d1 d2 -> dNest j [d1, d2]
34+
Implication j d -> dNest j [d]
2835
where
2936
dNest :: forall a b c. (Pretty a, Pretty b) => a -> [b] -> Doc c
3037
dNest j ds = pretty j <> line <> hang 2 (encloseSep (lbracket <> space) rbracket (space <> "" <> space) (pretty <$> ds))
38+
39+
dType :: Lens' Derivation Type
40+
dType = lens from to
41+
where
42+
from = \case
43+
Axiom j -> j ^. jType
44+
Abstraction j _ -> j ^. jType
45+
Application j _ _ -> j ^. jType
46+
Implication j _ -> j ^. jType
47+
48+
to drv t = case drv of
49+
Axiom j -> Axiom $ j & jType .~ t
50+
Abstraction j d -> Abstraction (j & jType .~ t) d
51+
Application j d1 d2 -> Application (j & jType .~ t) d1 d2
52+
Implication j d -> Implication (j & jType .~ t) d
53+
54+
dTopKind :: Lens' Derivation Kind
55+
dTopKind = lens from to
56+
where
57+
from = \case
58+
Axiom j -> j ^. jKind
59+
Abstraction j _ -> j ^. jKind
60+
Application j _ _ -> j ^. jKind
61+
Implication j _ -> j ^. jKind
62+
63+
to der t = case der of
64+
Axiom j -> Axiom $ j & jKind .~ t
65+
Abstraction j d -> Abstraction (j & jKind .~ t) d
66+
Application j d1 d2 -> Application (j & jKind .~ t) d1 d2
67+
Implication j d -> Abstraction (j & jKind .~ t) d

lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck/Inference.hs

Lines changed: 151 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# OPTIONS_GHC -Wno-missing-local-signatures #-}
12
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
23

34
-- This pragma^ is needed due to redundant constraint in Getter.
@@ -16,24 +17,27 @@ module LambdaBuffers.Compiler.KindCheck.Inference (
1617
) where
1718

1819
import Data.Bifunctor (Bifunctor (second))
20+
import Data.Foldable (foldrM)
1921

2022
import LambdaBuffers.Compiler.KindCheck.Context (Context (Context), addContext, context, getAllContext)
21-
import LambdaBuffers.Compiler.KindCheck.Derivation (Derivation (Abstraction, Application, Axiom))
23+
import LambdaBuffers.Compiler.KindCheck.Derivation (Derivation (Abstraction, Application, Axiom, Implication), dTopKind, dType)
2224
import LambdaBuffers.Compiler.KindCheck.Judgement (Judgement (Judgement))
23-
import LambdaBuffers.Compiler.KindCheck.Kind (Kind (KVar, Type, (:->:)))
24-
import LambdaBuffers.Compiler.KindCheck.Type (Type (Abs, App, Var), tyOpaque, tyProd, tySum, tyUnit, tyVoid)
25-
import LambdaBuffers.Compiler.KindCheck.Variable (Atom, Variable)
25+
import LambdaBuffers.Compiler.KindCheck.Kind (Kind (KType, KVar, (:->:)))
26+
import LambdaBuffers.Compiler.KindCheck.Type (Type (Abs, App, Constructor, Opaque, Product, Sum, Var, VoidT))
27+
import LambdaBuffers.Compiler.KindCheck.Variable (Atom, Variable (ForeignRef, LocalRef, TyVar))
2628

2729
import Control.Monad.Freer (Eff, Member, Members, run)
2830
import Control.Monad.Freer.Error (Error, runError, throwError)
29-
import Control.Monad.Freer.Reader (Reader, ask, asks, local, runReader)
31+
import Control.Monad.Freer.Reader (Reader, ask, asks, runReader)
3032
import Control.Monad.Freer.State (State, evalState, get, put)
3133
import Control.Monad.Freer.Writer (Writer, runWriter, tell)
3234

35+
import LambdaBuffers.Compiler.ProtoCompat qualified as PC
36+
3337
import Data.String (fromString)
3438
import Data.Text qualified as T
3539

36-
import Control.Lens (Getter, to, (&), (.~), (^.))
40+
import Control.Lens ((&), (.~), (^.))
3741
import Data.Map qualified as M
3842

3943
import Prettyprinter (
@@ -71,6 +75,7 @@ type Derive a =
7175
forall effs.
7276
Members
7377
'[ Reader Context
78+
, Reader Kind
7479
, State DerivationContext
7580
, Writer [Constraint]
7681
, Error InferErr
@@ -82,55 +87,26 @@ type Derive a =
8287
-- Runners
8388

8489
-- | Run derivation builder - not unified yet.
85-
runDerive :: Context -> Type -> Either InferErr (Derivation, [Constraint])
86-
runDerive ctx t = run $ runError $ runWriter $ evalState (DC atoms) $ runReader ctx (derive t)
90+
runDerive :: Context -> PC.TyAbs -> Kind -> Either InferErr (Derivation, [Constraint])
91+
runDerive ctx t k = run $ runError $ runWriter $ evalState (DC atoms) $ runReader ctx $ runReader k (derive t)
8792

88-
infer :: Context -> Type -> Either InferErr Kind
89-
infer ctx t = do
90-
(d, c) <- runDerive (defContext <> ctx) t
93+
infer :: Context -> PC.TyDef -> Kind -> Either InferErr Kind
94+
infer ctx t k = do
95+
(d, c) <- runDerive (defContext <> ctx) (t ^. #tyAbs) k
9196
s <- runUnify' c
9297
let res = foldl (flip substitute) d s
93-
pure $ res ^. topKind
98+
pure $ res ^. dTopKind
9499

95100
-- | Default KC Context.
96101
defContext :: Context
97-
defContext =
98-
mempty
99-
& context
100-
.~ M.fromList
101-
[ (tySum, Type :->: Type :->: Type)
102-
, (tyProd, Type :->: Type :->: Type)
103-
, (tyUnit, Type)
104-
, (tyVoid, Type)
105-
, (tyOpaque, Type)
106-
]
102+
defContext = mempty
107103

108104
--------------------------------------------------------------------------------
109105
-- Implementation
110106

111107
-- | Creates the derivation
112-
derive :: Type -> Derive Derivation
113-
derive x = do
114-
c <- ask
115-
case x of
116-
Var at -> do
117-
v <- getBinding at
118-
pure $ Axiom $ Judgement (c, x, v)
119-
App t1 t2 -> do
120-
d1 <- derive t1
121-
d2 <- derive t2
122-
let ty1 = d1 ^. topKind
123-
ty2 = d2 ^. topKind
124-
v <- KVar <$> fresh
125-
tell [Constraint (ty1, ty2 :->: v)]
126-
pure $ Application (Judgement (c, x, v)) d1 d2
127-
Abs v t -> do
128-
newTy <- getBinding v
129-
d <- local (\(Context ctx addC) -> Context ctx $ M.insert v newTy addC) (derive t)
130-
let ty = d ^. topKind
131-
freshT <- KVar <$> fresh
132-
tell [Constraint (freshT, newTy :->: ty)]
133-
pure $ Abstraction (Judgement (c, x, freshT)) d
108+
derive :: PC.TyAbs -> Derive Derivation
109+
derive x = deriveTyAbs x
134110
where
135111
fresh :: Derive Atom
136112
fresh = do
@@ -139,6 +115,130 @@ derive x = do
139115
a : as -> put (DC as) >> pure a
140116
[] -> throwError $ InferImpossibleErr "Reached end of infinite stream."
141117

118+
deriveTyAbs :: PC.TyAbs -> Derive Derivation
119+
deriveTyAbs tyabs =
120+
case M.toList (tyabs ^. #tyArgs) of
121+
[] -> deriveTyBody (x ^. #tyBody)
122+
a@(n, _) : as -> do
123+
vK <- getBinding (TyVar n)
124+
freshT <- KVar <$> fresh
125+
let newAbs = tyabs & #tyArgs .~ uncurry M.singleton a
126+
let restAbs = tyabs & #tyArgs .~ M.fromList as
127+
restF <- deriveTyAbs restAbs
128+
let uK = restF ^. dTopKind
129+
tell [Constraint (freshT, uK)]
130+
ctx <- ask
131+
pure $ Abstraction (Judgement (ctx, Abs newAbs, vK :->: freshT)) restF
132+
133+
deriveTyBody :: PC.TyBody -> Derive Derivation
134+
deriveTyBody = \case
135+
PC.OpaqueI si -> do
136+
ctx <- ask
137+
pure $ Axiom $ Judgement (ctx, Opaque si, KType)
138+
PC.SumI s -> deriveSum s
139+
140+
deriveSum :: PC.Sum -> Derive Derivation
141+
deriveSum s = do
142+
case M.toList (s ^. #constructors) of
143+
[] -> voidDerivation
144+
c : cs -> do
145+
dc <- deriveConstructor $ snd c
146+
restDc <- deriveSum $ s & #constructors .~ M.fromList cs
147+
sumDerivation dc restDc
148+
149+
deriveConstructor :: PC.Constructor -> Derive Derivation
150+
deriveConstructor c = do
151+
ctx <- ask
152+
d <- deriveProduct (c ^. #product)
153+
tell $ Constraint <$> [(KType, d ^. dTopKind)]
154+
pure $ Implication (Judgement (ctx, Constructor c, d ^. dTopKind)) d
155+
156+
deriveProduct :: PC.Product -> Derive Derivation
157+
deriveProduct = \case
158+
PC.RecordI r -> deriveRecord r
159+
PC.TupleI t -> deriveTuple t
160+
161+
deriveRecord r = do
162+
case M.toList (r ^. #fields) of
163+
[] -> voidDerivation
164+
f : fs -> do
165+
d1 <- deriveField $ snd f
166+
d2 <- deriveRecord $ r & #fields .~ M.fromList fs
167+
productDerivation d1 d2
168+
169+
deriveField :: PC.Field -> Derive Derivation
170+
deriveField f = deriveTy $ f ^. #fieldTy
171+
172+
deriveTy :: PC.Ty -> Derive Derivation
173+
deriveTy = \case
174+
PC.TyVarI tv -> deriveTyVar tv
175+
PC.TyAppI ta -> deriveTyApp ta
176+
PC.TyRefI tr -> deriveTyRef tr
177+
178+
deriveTyRef :: PC.TyRef -> Derive Derivation
179+
deriveTyRef = \case
180+
PC.LocalI r -> do
181+
let ty = LocalRef r
182+
v <- getBinding ty
183+
c <- ask
184+
pure . Axiom . Judgement $ (c, Var ty, v)
185+
PC.ForeignI r -> do
186+
let ty = ForeignRef r
187+
v <- getBinding ty
188+
c <- ask
189+
pure . Axiom . Judgement $ (c, Var ty, v)
190+
191+
deriveTyVar :: PC.TyVar -> Derive Derivation
192+
deriveTyVar tv = do
193+
let varName = tv ^. #varName
194+
v <- getBinding $ TyVar varName
195+
c <- ask
196+
pure . Axiom . Judgement $ (c, Var $ TyVar varName, v)
197+
198+
deriveTyApp :: PC.TyApp -> Derive Derivation
199+
deriveTyApp ap = do
200+
f <- deriveTy (ap ^. #tyFunc)
201+
args <- deriveTy `traverse` (ap ^. #tyArgs)
202+
applyDerivation $ f : args
203+
204+
deriveTuple :: PC.Tuple -> Derive Derivation
205+
deriveTuple t = do
206+
voidD <- voidDerivation
207+
ds <- deriveTy `traverse` (t ^. #fields)
208+
foldrM productDerivation voidD ds
209+
210+
voidDerivation :: Derive Derivation
211+
voidDerivation = do
212+
ctx <- ask
213+
pure $ Axiom $ Judgement (ctx, VoidT, KType)
214+
215+
productDerivation :: Derivation -> Derivation -> Derive Derivation
216+
productDerivation d1 d2 = do
217+
ctx <- ask
218+
let t1 = d1 ^. dType
219+
let t2 = d2 ^. dType
220+
tell $ Constraint <$> [(d1 ^. dTopKind, KType), (d2 ^. dTopKind, KType)]
221+
pure $ Application (Judgement (ctx, Product t1 t2, KType)) d1 d2
222+
223+
sumDerivation :: Derivation -> Derivation -> Derive Derivation
224+
sumDerivation d1 d2 = do
225+
ctx <- ask
226+
let t1 = d1 ^. dType
227+
let t2 = d2 ^. dType
228+
tell $ Constraint <$> [(d1 ^. dTopKind, KType), (d2 ^. dTopKind, KType)]
229+
pure $ Application (Judgement (ctx, Sum t1 t2, KType)) d1 d2
230+
231+
applyDerivation :: [Derivation] -> Derive Derivation
232+
applyDerivation = \case
233+
[] -> error "Impossible"
234+
[y] -> pure y
235+
d1 : ys -> do
236+
c <- ask
237+
d2 <- applyDerivation ys
238+
v <- KVar <$> fresh
239+
tell [Constraint ((d2 ^. dTopKind) :->: v, d1 ^. dTopKind)]
240+
pure $ Application (Judgement (c, App (d1 ^. dType) (d2 ^. dType), v)) d1 d2
241+
142242
{- | Gets the binding from the context - if the variable is not bound throw an
143243
error.
144244
-}
@@ -149,22 +249,13 @@ getBinding t = do
149249
Just x -> pure x
150250
Nothing -> throwError $ InferUnboundTermErr t
151251

152-
-- | Gets kind from a derivation.
153-
topKind :: Getter Derivation Kind
154-
topKind = to f
155-
where
156-
f = \case
157-
Axiom (Judgement (_, _, k)) -> k
158-
Abstraction (Judgement (_, _, k)) _ -> k
159-
Application (Judgement (_, _, k)) _ _ -> k
160-
161252
-- | Unification monad.
162253
type Unifier a = forall effs. Member (Error InferErr) effs => Eff effs a
163254

164255
-- | Gets the variables of a type.
165256
getVariables :: Kind -> [Atom]
166257
getVariables = \case
167-
Type -> mempty
258+
KType -> mempty
168259
x :->: y -> getVariables x <> getVariables y
169260
KVar x -> [x]
170261

@@ -175,14 +266,14 @@ getVariables = \case
175266
unify :: [Constraint] -> Unifier [Substitution]
176267
unify [] = pure []
177268
unify (constraint@(Constraint (l, r)) : xs) = case l of
178-
Type -> case r of
179-
Type -> unify xs
269+
KType -> case r of
270+
KType -> unify xs
180271
(_ :->: _) -> nope constraint
181272
KVar v ->
182-
let sub = Substitution (v, Type)
273+
let sub = Substitution (v, KType)
183274
in (sub :) <$> unify (sub `substituteIn` xs)
184275
x :->: y -> case r of
185-
Type -> nope constraint
276+
KType -> nope constraint
186277
KVar v ->
187278
if v `appearsIn` l
188279
then appearsErr v l
@@ -229,7 +320,7 @@ unify (constraint@(Constraint (l, r)) : xs) = case l of
229320
-- | Applies substitutions to a kind.
230321
applySubstitution :: Substitution -> Kind -> Kind
231322
applySubstitution s@(Substitution (a, t)) k = case k of
232-
Type -> Type
323+
KType -> KType
233324
l :->: r -> applySubstitution s l :->: applySubstitution s r
234325
KVar v -> if v == a then t else k
235326

@@ -245,6 +336,7 @@ substitute s d = case d of
245336
Axiom j -> Axiom (applySubsToJudgement s j)
246337
Abstraction j dc -> Abstraction (applySubsToJudgement s j) (substitute s dc)
247338
Application j d1 d2 -> Application (applySubsToJudgement s j) (substitute s d1) (substitute s d2)
339+
Implication j dc -> Implication (applySubsToJudgement s j) (substitute s dc)
248340
where
249341
applySubsToJudgement sub (Judgement (ctx, t, k)) = Judgement (applySubstitutionCtx s ctx, t, applySubstitution sub k)
250342

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
module LambdaBuffers.Compiler.KindCheck.Judgement (Judgement (Judgement), getJudgement) where
1+
module LambdaBuffers.Compiler.KindCheck.Judgement (Judgement (Judgement), getJudgement, jType, jKind) where
22

3+
import Control.Lens (Lens', lens)
34
import LambdaBuffers.Compiler.KindCheck.Context (Context)
45
import LambdaBuffers.Compiler.KindCheck.Kind (Kind)
56
import LambdaBuffers.Compiler.KindCheck.Type (Type)
@@ -8,5 +9,17 @@ import Prettyprinter (Pretty (pretty), (<+>))
89
newtype Judgement = Judgement {getJudgement :: (Context, Type, Kind)}
910
deriving stock (Show, Eq)
1011

12+
jType :: Lens' Judgement Type
13+
jType = lens from to
14+
where
15+
from = (\(_, x, _) -> x) . getJudgement
16+
to (Judgement (c, _, k)) t = Judgement (c, t, k)
17+
18+
jKind :: Lens' Judgement Kind
19+
jKind = lens from to
20+
where
21+
from = (\(_, _, k) -> k) . getJudgement
22+
to (Judgement (c, t, _)) k = Judgement (c, t, k)
23+
1124
instance Pretty Judgement where
1225
pretty (Judgement (c, t, k)) = pretty c <> "" <> pretty t <+> ":" <+> pretty k

0 commit comments

Comments
 (0)