@@ -1151,10 +1151,11 @@ object Types {
11511151 case _ => this
11521152 }
11531153
1154- /** Widen this type and if the result contains embedded union types, replace
1154+ /** Widen this type and if the result contains embedded soft union types, replace
11551155 * them by their joins.
1156- * "Embedded" means: inside type lambdas, intersections or recursive types, or in prefixes of refined types.
1157- * If an embedded union is found, we first try to simplify or eliminate it by
1156+ * "Embedded" means: inside type lambdas, intersections or recursive types,
1157+ * in prefixes of refined types, or in hard union types.
1158+ * If an embedded soft union is found, we first try to simplify or eliminate it by
11581159 * re-lubbing it while allowing type parameters to be constrained further.
11591160 * Any remaining union types are replaced by their joins.
11601161 *
@@ -1165,36 +1166,78 @@ object Types {
11651166 * is approximated by constraining `A` to be =:= to `Int` and returning `ArrayBuffer[Int]`
11661167 * instead of `ArrayBuffer[? >: Int | A <: Int & A]`
11671168 *
1169+ * Hard unions inside soft ones are treated specially. For illustration assume we
1170+ * want to widen the type `(A | C) \/ (B | C)` where `\/` means soft union and `|`
1171+ * means hard union. In that case, the hard unions `A | C` and `B | C` are treated
1172+ * in an asymmetric way. Only the first parts `A` and `B` are joined and the rest
1173+ * is added again with a hard union to the result. So
1174+ *
1175+ * widenUnion[ (A | C) \/ (B | C) ]
1176+ * = widenUnion[ A \/ B ] | C | C
1177+ * = D | C | C
1178+ * = D | C
1179+ *
1180+ * In general, If a hard union A | B_1 | ... | B_n is part of of a soft union,
1181+ * only A forms part of the join, and B_1, ..., B_n are pushed out, just `C` is
1182+ * pushed out above. All types that are pushed out are recombined with the result
1183+ * of the join with a lub, but that lub yields again a hard union, not a soft one.
1184+ *
11681185 * Exception (if `-YexplicitNulls` is set): if this type is a nullable union (i.e. of the form `T | Null`),
11691186 * then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
11701187 */
1171- def widenUnion (using Context ): Type = widen match {
1188+ def widenUnion (using Context ): Type = widen. match {
11721189 case tp @ OrNull (tp1): OrType =>
11731190 // Don't widen `T|Null`, since otherwise we wouldn't be able to infer nullable unions.
11741191 val tp1Widen = tp1.widenUnionWithoutNull
11751192 if (tp1Widen.isRef(defn.AnyClass )) tp1Widen
11761193 else tp.derivedOrType(tp1Widen, defn.NullType )
11771194 case tp =>
11781195 tp.widenUnionWithoutNull
1179- }
1196+ }.reporting( i " widenUnion( $this ) = $result " )
11801197
1181- def widenUnionWithoutNull (using Context ): Type = widen match {
1182- case tp @ OrType (lhs, rhs) =>
1183- TypeComparer .lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true ) match {
1184- case union : OrType => union.join
1185- case res => res
1186- }
1187- case tp @ AndType (tp1, tp2) =>
1188- tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1189- case tp : RefinedType =>
1190- tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
1191- case tp : RecType =>
1192- tp.rebind(tp.parent.widenUnion)
1193- case tp : HKTypeLambda =>
1194- tp.derivedLambdaType(resType = tp.resType.widenUnion)
1195- case tp =>
1196- tp
1197- }
1198+ def widenUnionWithoutNull (using Context ): Type =
1199+
1200+ // Split hard union `A | B1 | ... | Bn` into leftmost part `A` and list of
1201+ // pushed out parts `B1, ..., Bn`.
1202+ def splitAlts (tp : Type , follow : List [Type ]): (Type , List [Type ]) = tp match
1203+ case tp as OrType (lhs, rhs) if ! tp.isSoft =>
1204+ splitAlts(lhs, rhs :: follow)
1205+ case _ =>
1206+ (tp, follow)
1207+
1208+ // Convert any soft unions in result of lub to hard ones */
1209+ def harden (tp : Type ): Type = tp match
1210+ case tp as OrType (tp1, tp2) if tp.isSoft =>
1211+ OrType (harden(tp1), harden(tp2), soft = false )
1212+ case _ =>
1213+ tp
1214+
1215+ def recombine (tp1 : Type , tp2 : Type ) = harden(TypeComparer .lub(tp1, tp2))
1216+
1217+ widen match
1218+ case tp @ OrType (lhs, rhs) =>
1219+ if tp.isSoft then
1220+ val (lhsCore, lhsExtras) = splitAlts(lhs.widenUnionWithoutNull, Nil )
1221+ val (rhsCore, rhsExtras) = splitAlts(rhs.widenUnionWithoutNull, Nil )
1222+ val core = TypeComparer .lub(lhsCore, rhsCore, canConstrain = true ) match
1223+ case union : OrType => union.join
1224+ case res => res
1225+ rhsExtras.foldLeft(lhsExtras.foldLeft(core)(recombine))(recombine)
1226+ else
1227+ val lhs1 = lhs.widenUnionWithoutNull
1228+ val rhs1 = rhs.widenUnionWithoutNull
1229+ if (lhs1 eq lhs) && (rhs1 eq rhs) then tp else recombine(lhs1, rhs1)
1230+ case tp @ AndType (tp1, tp2) =>
1231+ tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1232+ case tp : RefinedType =>
1233+ tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
1234+ case tp : RecType =>
1235+ tp.rebind(tp.parent.widenUnion)
1236+ case tp : HKTypeLambda =>
1237+ tp.derivedLambdaType(resType = tp.resType.widenUnion)
1238+ case tp =>
1239+ tp
1240+ end widenUnionWithoutNull
11981241
11991242 /** Widen all top-level singletons reachable by dealiasing
12001243 * and going to the operands of & and |.
@@ -3054,9 +3097,9 @@ object Types {
30543097 myWidened
30553098 }
30563099
3057- def derivedOrType (tp1 : Type , tp2 : Type )(using Context ): Type =
3058- if ((tp1 eq this .tp1) && (tp2 eq this .tp2)) this
3059- else OrType .make(tp1, tp2, isSoft )
3100+ def derivedOrType (tp1 : Type , tp2 : Type , soft : Boolean = isSoft )(using Context ): Type =
3101+ if ((tp1 eq this .tp1) && (tp2 eq this .tp2) && soft == isSoft ) this
3102+ else OrType .make(tp1, tp2, soft )
30603103
30613104 override def computeHash (bs : Binders ): Int =
30623105 doHash(bs, if isSoft then 0 else 1 , tp1, tp2)
0 commit comments