@@ -27,7 +27,11 @@ sealed abstract class GadtConstraint extends Showable {
2727 /** Is `sym1` ordered to be less than `sym2`? */
2828 def isLess (sym1 : Symbol , sym2 : Symbol )(implicit ctx : Context ): Boolean
2929
30- def addEmptyBounds (sym : Symbol )(implicit ctx : Context ): Unit
30+ /** Add symbols to constraint, preserving the underlying bounds and handling inter-dependencies. */
31+ def addToConstraint (syms : List [Symbol ])(implicit ctx : Context ): Boolean
32+ def addToConstraint (sym : Symbol )(implicit ctx : Context ): Boolean = addToConstraint(sym :: Nil )
33+
34+ /** Further constrain a symbol already present in the constraint. */
3135 def addBound (sym : Symbol , bound : Type , isUpper : Boolean )(implicit ctx : Context ): Boolean
3236
3337 /** Is the symbol registered in the constraint?
@@ -72,7 +76,54 @@ final class ProperGadtConstraint private(
7276 subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre))
7377 }
7478
75- override def addEmptyBounds (sym : Symbol )(implicit ctx : Context ): Unit = tvar(sym)
79+ override def addToConstraint (params : List [Symbol ])(implicit ctx : Context ): Boolean = {
80+ import NameKinds .DepParamName
81+
82+ val poly1 = PolyType (params.map { sym => DepParamName .fresh(sym.name.toTypeName) })(
83+ pt => params.map { param =>
84+ // replace the symbols in bound type `tp` which are in dependent positions
85+ // with their internal TypeParamRefs
86+ def substDependentSyms (tp : Type , isUpper : Boolean )(implicit ctx : Context ): Type = {
87+ def loop (tp : Type ) = substDependentSyms(tp, isUpper)
88+ tp match {
89+ case tp @ AndType (tp1, tp2) if ! isUpper =>
90+ tp.derivedAndType(loop(tp1), loop(tp2))
91+ case tp @ OrType (tp1, tp2) if isUpper =>
92+ tp.derivedOrType(loop(tp1), loop(tp2))
93+ case tp : NamedType =>
94+ params.indexOf(tp.symbol) match {
95+ case - 1 =>
96+ mapping(tp.symbol) match {
97+ case tv : TypeVar => tv.origin
98+ case null => tp
99+ }
100+ case i => pt.paramRefs(i)
101+ }
102+ case tp => tp
103+ }
104+ }
105+
106+ val tb = param.info.bounds
107+ tb.derivedTypeBounds(
108+ lo = substDependentSyms(tb.lo, isUpper = false ),
109+ hi = substDependentSyms(tb.hi, isUpper = true )
110+ )
111+ },
112+ pt => defn.AnyType
113+ )
114+
115+ val tvars = (params, poly1.paramRefs).zipped.map { (sym, paramRef) =>
116+ val tv = new TypeVar (paramRef, creatorState = null )
117+ mapping = mapping.updated(sym, tv)
118+ reverseMapping = reverseMapping.updated(tv.origin, sym)
119+ tv
120+ }
121+
122+ // the replaced symbols will be stripped off the bounds by `addToConstraint` and used as orderings
123+ addToConstraint(poly1, tvars).reporting({ _ =>
124+ i " added to constraint: $params%, % \n $debugBoundsDescription"
125+ }, gadts)
126+ }
76127
77128 override def addBound (sym : Symbol , bound : Type , isUpper : Boolean )(implicit ctx : Context ): Boolean = {
78129 @ annotation.tailrec def stripInternalTypeVar (tp : Type ): Type = tp match {
@@ -82,16 +133,17 @@ final class ProperGadtConstraint private(
82133 case _ => tp
83134 }
84135
85- val symTvar : TypeVar = stripInternalTypeVar(tvar (sym)) match {
136+ val symTvar : TypeVar = stripInternalTypeVar(tvarOrError (sym)) match {
86137 case tv : TypeVar => tv
87138 case inst =>
88139 gadts.println(i " instantiated: $sym -> $inst" )
89140 return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst)
90141 }
91142
92143 val internalizedBound = bound match {
93- case nt : NamedType if contains(nt.symbol) =>
94- stripInternalTypeVar(tvar(nt.symbol))
144+ case nt : NamedType =>
145+ val ntTvar = mapping(nt.symbol)
146+ if (ntTvar ne null ) stripInternalTypeVar(ntTvar) else bound
95147 case _ => bound
96148 }
97149 (
@@ -119,20 +171,22 @@ final class ProperGadtConstraint private(
119171 if (isUpper) addUpperBound(symTvar.origin, bound1)
120172 else addLowerBound(symTvar.origin, bound1)
121173 }
122- ).reporting({ res =>
174+ ).reporting({ res =>
123175 val descr = if (isUpper) " upper" else " lower"
124176 val op = if (isUpper) " <:" else " >:"
125- i " adding $descr bound $sym $op $bound = $res\t ( $symTvar $op $internalizedBound ) "
177+ i " adding $descr bound $sym $op $bound = $res"
126178 }, gadts)
127179 }
128180
129181 override def isLess (sym1 : Symbol , sym2 : Symbol )(implicit ctx : Context ): Boolean =
130- constraint.isLess(tvar (sym1).origin, tvar (sym2).origin)
182+ constraint.isLess(tvarOrError (sym1).origin, tvarOrError (sym2).origin)
131183
132184 override def fullBounds (sym : Symbol )(implicit ctx : Context ): TypeBounds =
133185 mapping(sym) match {
134186 case null => null
135- case tv => fullBounds(tv.origin)
187+ case tv =>
188+ fullBounds(tv.origin)
189+ .ensuring(containsNoInternalTypes(_))
136190 }
137191
138192 override def bounds (sym : Symbol )(implicit ctx : Context ): TypeBounds = {
@@ -145,14 +199,16 @@ final class ProperGadtConstraint private(
145199 TypeAlias (reverseMapping(tpr).typeRef)
146200 case tb => tb
147201 }
148- retrieveBounds// .reporting({ res => i"gadt bounds $sym: $res" }, gadts)
202+ retrieveBounds
203+ // .reporting({ res => i"gadt bounds $sym: $res" }, gadts)
204+ .ensuring(containsNoInternalTypes(_))
149205 }
150206 }
151207
152208 override def contains (sym : Symbol )(implicit ctx : Context ): Boolean = mapping(sym) ne null
153209
154210 override def approximation (sym : Symbol , fromBelow : Boolean )(implicit ctx : Context ): Type = {
155- val res = approximation(tvar (sym).origin, fromBelow = fromBelow)
211+ val res = approximation(tvarOrError (sym).origin, fromBelow = fromBelow)
156212 gadts.println(i " approximating $sym ~> $res" )
157213 res
158214 }
@@ -207,36 +263,21 @@ final class ProperGadtConstraint private(
207263 case null => param
208264 }
209265
210- private [this ] def tvar (sym : Symbol )(implicit ctx : Context ): TypeVar = {
211- mapping(sym) match {
212- case tv : TypeVar =>
213- tv
214- case null =>
215- val res = {
216- import NameKinds .DepParamName
217- // For symbols standing for HK types, we need to preserve the kind information
218- // (see also usage of adaptHKvariances above)
219- // Ideally we'd always preserve the bounds,
220- // but first we need an equivalent of ConstraintHandling#addConstraint
221- // TODO: implement the above
222- val initialBounds = sym.info match {
223- case tb @ TypeBounds (_, hi) if hi.isLambdaSub => tb
224- case _ => TypeBounds .empty
225- }
226- // avoid registering the TypeVar with TyperState / TyperState#constraint
227- // - we don't want TyperState instantiating these TypeVars
228- // - we don't want TypeComparer constraining these TypeVars
229- val poly = PolyType (DepParamName .fresh(sym.name.toTypeName) :: Nil )(
230- pt => initialBounds :: Nil ,
231- pt => defn.AnyType )
232- new TypeVar (poly.paramRefs.head, creatorState = null )
233- }
234- gadts.println(i " GADTMap: created tvar $sym -> $res" )
235- constraint = constraint.add(res.origin.binder, res :: Nil )
236- mapping = mapping.updated(sym, res)
237- reverseMapping = reverseMapping.updated(res.origin, sym)
238- res
239- }
266+ private [this ] def tvarOrError (sym : Symbol )(implicit ctx : Context ): TypeVar =
267+ mapping(sym).ensuring(_ ne null , i " not a constrainable symbol: $sym" )
268+
269+ private [this ] def containsNoInternalTypes (
270+ tp : Type ,
271+ acc : TypeAccumulator [Boolean ] = null
272+ )(implicit ctx : Context ): Boolean = tp match {
273+ case tpr : TypeParamRef => ! reverseMapping.contains(tpr)
274+ case tv : TypeVar => ! reverseMapping.contains(tv.origin)
275+ case tp =>
276+ (if (acc ne null ) acc else new ContainsNoInternalTypesAccumulator ()).foldOver(true , tp)
277+ }
278+
279+ private [this ] class ContainsNoInternalTypesAccumulator (implicit ctx : Context ) extends TypeAccumulator [Boolean ] {
280+ override def apply (x : Boolean , tp : Type ): Boolean = x && containsNoInternalTypes(tp)
240281 }
241282
242283 // ---- Debug ------------------------------------------------------------
@@ -266,7 +307,7 @@ final class ProperGadtConstraint private(
266307
267308 override def contains (sym : Symbol )(implicit ctx : Context ) = false
268309
269- override def addEmptyBounds ( sym : Symbol )(implicit ctx : Context ): Unit = unsupported(" EmptyGadtConstraint.addEmptyBounds " )
310+ override def addToConstraint ( params : List [ Symbol ] )(implicit ctx : Context ): Boolean = unsupported(" EmptyGadtConstraint.addToConstraint " )
270311 override def addBound (sym : Symbol , bound : Type , isUpper : Boolean )(implicit ctx : Context ): Boolean = unsupported(" EmptyGadtConstraint.addBound" )
271312
272313 override def approximation (sym : Symbol , fromBelow : Boolean )(implicit ctx : Context ): Type = unsupported(" EmptyGadtConstraint.approximation" )
0 commit comments