@@ -1040,7 +1040,7 @@ object Types {
10401040 def safe_& (that : Type )(using Context ): Type = (this , that) match {
10411041 case (TypeBounds (lo1, hi1), TypeBounds (lo2, hi2)) =>
10421042 TypeBounds (
1043- OrType .makeHk(lo1.stripLazyRef, lo2.stripLazyRef),
1043+ OrType .makeHk(lo1.stripLazyRef, lo2.stripLazyRef),
10441044 AndType .makeHk(hi1.stripLazyRef, hi2.stripLazyRef))
10451045 case _ =>
10461046 this & that
@@ -1151,10 +1151,11 @@ object Types {
11511151 case _ => this
11521152 }
11531153
1154- /** Widen this type and if the result contains embedded union types, replace
1154+ /** Widen this type and if the result contains embedded soft union types, replace
11551155 * them by their joins.
1156- * "Embedded" means: inside type lambdas, intersections or recursive types, or in prefixes of refined types.
1157- * If an embedded union is found, we first try to simplify or eliminate it by
1156+ * "Embedded" means: inside type lambdas, intersections or recursive types,
1157+ * in prefixes of refined types, or in hard union types.
1158+ * If an embedded soft union is found, we first try to simplify or eliminate it by
11581159 * re-lubbing it while allowing type parameters to be constrained further.
11591160 * Any remaining union types are replaced by their joins.
11601161 *
@@ -1168,24 +1169,22 @@ object Types {
11681169 * Exception (if `-YexplicitNulls` is set): if this type is a nullable union (i.e. of the form `T | Null`),
11691170 * then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
11701171 */
1171- def widenUnion (using Context ): Type = widen match {
1172+ def widenUnion (using Context ): Type = widen match
11721173 case tp @ OrNull (tp1): OrType =>
11731174 // Don't widen `T|Null`, since otherwise we wouldn't be able to infer nullable unions.
11741175 val tp1Widen = tp1.widenUnionWithoutNull
11751176 if (tp1Widen.isRef(defn.AnyClass )) tp1Widen
11761177 else tp.derivedOrType(tp1Widen, defn.NullType )
11771178 case tp =>
11781179 tp.widenUnionWithoutNull
1179- }
11801180
1181- def widenUnionWithoutNull (using Context ): Type = widen match {
1182- case tp @ OrType (lhs, rhs) =>
1183- TypeComparer .lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true ) match {
1181+ def widenUnionWithoutNull (using Context ): Type = widen match
1182+ case tp @ OrType (lhs, rhs) if tp.isSoft =>
1183+ TypeComparer .lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true ) match
11841184 case union : OrType => union.join
11851185 case res => res
1186- }
1187- case tp @ AndType (tp1, tp2) =>
1188- tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1186+ case tp : AndOrType =>
1187+ tp.derivedAndOrType(tp.tp1.widenUnionWithoutNull, tp.tp2.widenUnionWithoutNull)
11891188 case tp : RefinedType =>
11901189 tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
11911190 case tp : RecType =>
@@ -1194,7 +1193,6 @@ object Types {
11941193 tp.derivedLambdaType(resType = tp.resType.widenUnion)
11951194 case tp =>
11961195 tp
1197- }
11981196
11991197 /** Widen all top-level singletons reachable by dealiasing
12001198 * and going to the operands of & and |.
@@ -2917,8 +2915,9 @@ object Types {
29172915
29182916 def derivedAndOrType (tp1 : Type , tp2 : Type )(using Context ) =
29192917 if ((tp1 eq this .tp1) && (tp2 eq this .tp2)) this
2920- else if (isAnd) AndType .make(tp1, tp2, checkValid = true )
2921- else OrType .make(tp1, tp2)
2918+ else this match
2919+ case tp : OrType => OrType .make(tp1, tp2, tp.isSoft)
2920+ case tp : AndType => AndType .make(tp1, tp2, checkValid = true )
29222921 }
29232922
29242923 abstract case class AndType (tp1 : Type , tp2 : Type ) extends AndOrType {
@@ -2992,6 +2991,7 @@ object Types {
29922991
29932992 abstract case class OrType (tp1 : Type , tp2 : Type ) extends AndOrType {
29942993 def isAnd : Boolean = false
2994+ def isSoft : Boolean
29952995 private var myBaseClassesPeriod : Period = Nowhere
29962996 private var myBaseClasses : List [ClassSymbol ] = _
29972997 /** Base classes of are the intersection of the operand base classes. */
@@ -3052,32 +3052,33 @@ object Types {
30523052 myWidened
30533053 }
30543054
3055- def derivedOrType (tp1 : Type , tp2 : Type )(using Context ): Type =
3056- if ((tp1 eq this .tp1) && (tp2 eq this .tp2)) this
3057- else OrType .make(tp1, tp2)
3055+ def derivedOrType (tp1 : Type , tp2 : Type , soft : Boolean = isSoft )(using Context ): Type =
3056+ if ((tp1 eq this .tp1) && (tp2 eq this .tp2) && soft == isSoft ) this
3057+ else OrType .make(tp1, tp2, soft )
30583058
3059- override def computeHash (bs : Binders ): Int = doHash(bs, tp1, tp2)
3059+ override def computeHash (bs : Binders ): Int =
3060+ doHash(bs, if isSoft then 0 else 1 , tp1, tp2)
30603061
30613062 override def eql (that : Type ): Boolean = that match {
3062- case that : OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2)
3063+ case that : OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2) && isSoft == that.isSoft
30633064 case _ => false
30643065 }
30653066 }
30663067
3067- final class CachedOrType (tp1 : Type , tp2 : Type ) extends OrType (tp1, tp2)
3068+ final class CachedOrType (tp1 : Type , tp2 : Type , override val isSoft : Boolean ) extends OrType (tp1, tp2)
30683069
30693070 object OrType {
3070- def apply (tp1 : Type , tp2 : Type )(using Context ): OrType = {
3071+ def apply (tp1 : Type , tp2 : Type , soft : Boolean )(using Context ): OrType = {
30713072 assertUnerased()
3072- unique(new CachedOrType (tp1, tp2))
3073+ unique(new CachedOrType (tp1, tp2, soft ))
30733074 }
3074- def make (tp1 : Type , tp2 : Type )(using Context ): Type =
3075+ def make (tp1 : Type , tp2 : Type , soft : Boolean )(using Context ): Type =
30753076 if (tp1 eq tp2) tp1
3076- else apply(tp1, tp2)
3077+ else apply(tp1, tp2, soft )
30773078
30783079 /** Like `make`, but also supports higher-kinded types as argument */
30793080 def makeHk (tp1 : Type , tp2 : Type )(using Context ): Type =
3080- TypeComparer .liftIfHK(tp1, tp2, OrType (_, _), makeHk, _ & _)
3081+ TypeComparer .liftIfHK(tp1, tp2, OrType (_, _, soft = true ), makeHk, _ & _)
30813082 }
30823083
30833084 /** An extractor object to pattern match against a nullable union.
@@ -3089,7 +3090,7 @@ object Types {
30893090 */
30903091 object OrNull {
30913092 def apply (tp : Type )(using Context ) =
3092- OrType (tp, defn.NullType )
3093+ OrType (tp, defn.NullType , soft = false )
30933094 def unapply (tp : Type )(using Context ): Option [Type ] =
30943095 if (ctx.explicitNulls) {
30953096 val tp1 = tp.stripNull()
@@ -3107,7 +3108,7 @@ object Types {
31073108 */
31083109 object OrUncheckedNull {
31093110 def apply (tp : Type )(using Context ) =
3110- OrType (tp, defn.UncheckedNullAliasType )
3111+ OrType (tp, defn.UncheckedNullAliasType , soft = false )
31113112 def unapply (tp : Type )(using Context ): Option [Type ] =
31123113 if (ctx.explicitNulls) {
31133114 val tp1 = tp.stripUncheckedNull
0 commit comments