Skip to content

Commit c278717

Browse files
committed
[CSSimplify] Avoid simplifying dependent members until pack expansions in the based are bound
Dependent members cannot be simplified if base type contains unresolved pack expansion type variables because they don't give enough information to substitution logic to form a correct type. For example: ``` protocol P { associatedtype V } struct S<each T> : P { typealias V = (repeat (each T)?) } ``` If pack expansion is represented as `$T1` and its pattern is `$T2`, a reference to `V` would get a type `S<Pack{$T}>.V` and simplified version would be `Optional<Pack{$T1}>` instead of `Pack{repeat Optional<$T2>}` because `$T1` is treated as a substitution for `each T` until bound. Resolves: rdar://161207705
1 parent 9953b1e commit c278717

File tree

3 files changed

+65
-6
lines changed

3 files changed

+65
-6
lines changed

include/swift/AST/Types.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -850,12 +850,11 @@ class alignas(1 << TypeAlignInBits) TypeBase
850850
/// type variables referenced by this type.
851851
void getTypeVariables(SmallPtrSetImpl<TypeVariableType *> &typeVariables);
852852

853-
private:
853+
public:
854854
/// If the receiver is a `DependentMemberType`, returns its root. Otherwise,
855855
/// returns the receiver.
856856
Type getDependentMemberRoot();
857857

858-
public:
859858
/// Determine whether this type is a type parameter, which is either a
860859
/// GenericTypeParamType or a DependentMemberType.
861860
///

lib/Sema/CSSimplify.cpp

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7099,6 +7099,20 @@ static bool isTupleWithUnresolvedPackExpansion(Type type) {
70997099
return false;
71007100
}
71017101

7102+
static bool isDependentMemberTypeWithBaseThatContainsUnresolvedPackExpansions(
7103+
ConstraintSystem &cs, Type type) {
7104+
if (!type->is<DependentMemberType>())
7105+
return false;
7106+
7107+
auto baseTy = cs.getFixedTypeRecursive(type->getDependentMemberRoot(),
7108+
/*wantRValue=*/true);
7109+
llvm::SmallPtrSet<TypeVariableType *, 2> typeVars;
7110+
baseTy->getTypeVariables(typeVars);
7111+
return llvm::any_of(typeVars, [](const TypeVariableType *typeVar) {
7112+
return typeVar->getImpl().isPackExpansion();
7113+
});
7114+
}
7115+
71027116
ConstraintSystem::TypeMatchResult
71037117
ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
71047118
TypeMatchOptions flags,
@@ -7135,7 +7149,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
71357149
//
71367150
// along any unsolved path. No other returns should produce
71377151
// SolutionKind::Unsolved or inspect TMF_GenerateConstraints.
7138-
auto formUnsolvedResult = [&] {
7152+
auto formUnsolvedResult = [&](bool useOriginalTypes = false) {
71397153
// If we're supposed to generate constraints (i.e., this is a
71407154
// newly-generated constraint), do so now.
71417155
if (flags.contains(TMF_GenerateConstraints)) {
@@ -7144,8 +7158,13 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
71447158
// this new constraint will be solved at a later point.
71457159
// Obviously, this must not happen at the top level, or the
71467160
// algorithm would not terminate.
7147-
addUnsolvedConstraint(Constraint::create(*this, kind, type1, type2,
7148-
getConstraintLocator(locator)));
7161+
if (useOriginalTypes) {
7162+
addUnsolvedConstraint(Constraint::create(
7163+
*this, kind, origType1, origType2, getConstraintLocator(locator)));
7164+
} else {
7165+
addUnsolvedConstraint(Constraint::create(
7166+
*this, kind, type1, type2, getConstraintLocator(locator)));
7167+
}
71497168
return getTypeMatchSuccess();
71507169
}
71517170

@@ -7396,6 +7415,29 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
73967415
}
73977416
}
73987417

7418+
// Dependent members cannot be simplified if base type contains unresolved
7419+
// pack expansion type variables because they don't give enough information
7420+
// to substitution logic to form a correct type. For example:
7421+
//
7422+
// ```
7423+
// protocol P { associatedtype V }
7424+
// struct S<each T> : P { typealias V = (repeat (each T)?) }
7425+
// ```
7426+
//
7427+
// If pack expansion is represented as `$T1` and its pattern is `$T2`, a
7428+
// reference to `V` would get a type `S<Pack{$T}>.V` and simplified version
7429+
// would be `Optional<Pack{$T1}>` instead of `Pack{repeat Optional<$T2>}`
7430+
// because `$T1` is treated as a substitution for `each T` until bound.
7431+
if (isDependentMemberTypeWithBaseThatContainsUnresolvedPackExpansions(
7432+
*this, origType1) ||
7433+
isDependentMemberTypeWithBaseThatContainsUnresolvedPackExpansions(
7434+
*this, origType2)) {
7435+
// It's important to preserve the original types here because any attempt
7436+
// at simplification or canonicalization wouldn't produce a correct type
7437+
// util pack expansion type variables are bound.
7438+
return formUnsolvedResult(/*useOriginalTypes=*/true);
7439+
}
7440+
73997441
llvm::SmallVector<RestrictionOrFix, 4> conversionsOrFixes;
74007442

74017443
// Decompose parallel structure.

test/Constraints/pack-expansion-expressions.swift

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ func test_pack_expansion_specialization(tuple: (Int, String, Float)) {
360360
}
361361

362362
// rdar://107280056 - "Ambiguous without more context" with opaque return type + variadics
363-
protocol Q {
363+
protocol Q<B> {
364364
associatedtype B
365365
}
366366

@@ -815,3 +815,21 @@ func testPackToScalarShortFormConstructor() {
815815
S(repeat each xs) // expected-error {{cannot pass value pack expansion to non-pack parameter of type 'Int'}}
816816
}
817817
}
818+
819+
820+
func test_dependent_members() {
821+
struct Variadic<each T>: Q {
822+
typealias B = (repeat (each T)?)
823+
824+
init(_: repeat each T) {}
825+
static func f(_: repeat each T) -> Self {}
826+
}
827+
828+
func test_init<C1, C2>(_ c1: C1, _ c2: C2) -> some Q<(C1?, C2?)> {
829+
return Variadic(c1, c2) // Ok
830+
}
831+
832+
func test_static<C1, C2>(_ c1: C1, _ c2: C2) -> some Q<(C1?, C2?)> {
833+
return Variadic.f(c1, c2) // Ok
834+
}
835+
}

0 commit comments

Comments
 (0)