@@ -4439,6 +4439,13 @@ ConstraintSystem::matchTypesBindTypeVar(
44394439 : getTypeMatchFailure(locator);
44404440 }
44414441
4442+ if (typeVar->getImpl().isKeyPathType()) {
4443+ if (flags.contains(TMF_BindingTypeVariable))
4444+ return resolveKeyPath(typeVar, type, locator)
4445+ ? getTypeMatchSuccess()
4446+ : getTypeMatchFailure(locator);
4447+ }
4448+
44424449 assignFixedType(typeVar, type, /*updateState=*/true,
44434450 /*notifyInference=*/!flags.contains(TMF_BindingTypeVariable));
44444451
@@ -4630,6 +4637,13 @@ repairViaBridgingCast(ConstraintSystem &cs, Type fromType, Type toType,
46304637 return true;
46314638}
46324639
4640+ /// Return tuple of type and number of optionals on that type.
4641+ static std::pair<Type, unsigned> getObjectTypeAndNumUnwraps(Type type) {
4642+ SmallVector<Type, 2> optionals;
4643+ Type objType = type->lookThroughAllOptionalTypes(optionals);
4644+ return std::make_pair(objType, optionals.size());
4645+ }
4646+
46334647static bool
46344648repairViaOptionalUnwrap(ConstraintSystem &cs, Type fromType, Type toType,
46354649 ConstraintKind matchKind,
@@ -4747,17 +4761,11 @@ repairViaOptionalUnwrap(ConstraintSystem &cs, Type fromType, Type toType,
47474761 }
47484762 }
47494763
4750- auto getObjectTypeAndUnwraps = [](Type type) -> std::pair<Type, unsigned> {
4751- SmallVector<Type, 2> optionals;
4752- Type objType = type->lookThroughAllOptionalTypes(optionals);
4753- return std::make_pair(objType, optionals.size());
4754- };
4755-
47564764 Type fromObjectType, toObjectType;
47574765 unsigned fromUnwraps, toUnwraps;
47584766
4759- std::tie(fromObjectType, fromUnwraps) = getObjectTypeAndUnwraps (fromType);
4760- std::tie(toObjectType, toUnwraps) = getObjectTypeAndUnwraps (toType);
4767+ std::tie(fromObjectType, fromUnwraps) = getObjectTypeAndNumUnwraps (fromType);
4768+ std::tie(toObjectType, toUnwraps) = getObjectTypeAndNumUnwraps (toType);
47614769
47624770 // Since equality is symmetric and it decays into a `Bind`, eagerly
47634771 // unwrapping optionals from either side might be incorrect since
@@ -6491,6 +6499,19 @@ bool ConstraintSystem::repairFailures(
64916499 if (!fromType || !toType)
64926500 break;
64936501
6502+ Type fromObjectType, toObjectType;
6503+ unsigned fromUnwraps, toUnwraps;
6504+
6505+ std::tie(fromObjectType, fromUnwraps) = getObjectTypeAndNumUnwraps(lhs);
6506+ std::tie(toObjectType, toUnwraps) = getObjectTypeAndNumUnwraps(rhs);
6507+
6508+ // If the bound contextual type is more optional than the binding type, then
6509+ // propogate binding type to contextual type and attempt to solve.
6510+ if (fromUnwraps < toUnwraps) {
6511+ (void)matchTypes(fromObjectType, toObjectType, ConstraintKind::Bind,
6512+ TMF_ApplyingFix, locator);
6513+ }
6514+
64946515 // Drop both `GenericType` elements.
64956516 path.pop_back();
64966517 path.pop_back();
@@ -6704,6 +6725,15 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
67046725 return getTypeMatchSuccess();
67056726 }
67066727
6728+ // If type variable represents a key path value type, defer binding it to
6729+ // contextual type in diagnostic mode. We want it to be bound from the
6730+ // last key path component to help with diagnostics.
6731+ if (shouldAttemptFixes()) {
6732+ if (typeVar1 && typeVar1->getImpl().isKeyPathValue() &&
6733+ !flags.contains(TMF_BindingTypeVariable))
6734+ return formUnsolvedResult();
6735+ }
6736+
67076737 assert((type1->is<TypeVariableType>() != type2->is<TypeVariableType>()) &&
67086738 "Expected a type variable and a non type variable!");
67096739
@@ -11542,6 +11572,31 @@ bool ConstraintSystem::resolveClosure(TypeVariableType *typeVar,
1154211572 return !generateConstraints(AnyFunctionRef{closure}, closure->getBody());
1154311573}
1154411574
11575+ bool ConstraintSystem::resolveKeyPath(TypeVariableType *typeVar,
11576+ Type contextualType,
11577+ ConstraintLocatorBuilder locator) {
11578+ auto *keyPathLocator = typeVar->getImpl().getLocator();
11579+ auto *keyPath = castToExpr<KeyPathExpr>(keyPathLocator->getAnchor());
11580+ if (keyPath->hasSingleInvalidComponent()) {
11581+ assignFixedType(typeVar, contextualType);
11582+ return true;
11583+ }
11584+ if (auto *BGT = contextualType->getAs<BoundGenericType>()) {
11585+ auto args = BGT->getGenericArgs();
11586+ if (isKnownKeyPathType(contextualType) && args.size() >= 1) {
11587+ auto root = BGT->getGenericArgs()[0];
11588+
11589+ auto *keyPathValueTV = getKeyPathValueType(keyPath);
11590+ contextualType = BoundGenericType::get(
11591+ args.size() == 1 ? getASTContext().getKeyPathDecl() : BGT->getDecl(),
11592+ /*parent=*/Type(), {root, keyPathValueTV});
11593+ }
11594+ }
11595+
11596+ assignFixedType(typeVar, contextualType);
11597+ return true;
11598+ }
11599+
1154511600bool ConstraintSystem::resolvePackExpansion(TypeVariableType *typeVar,
1154611601 Type contextualType) {
1154711602 auto *locator = typeVar->getImpl().getLocator();
0 commit comments