Skip to content

Commit 93fbce0

Browse files
committed
[Strict memory safety] Nested types are safe/unsafe independent of their enclosing type
When determining whether a nested type is safe, don't consider whether its enclosing type is safe. They're independent. (cherry picked from commit 8ec52c8)
1 parent 469b303 commit 93fbce0

File tree

3 files changed

+124
-16
lines changed

3 files changed

+124
-16
lines changed

lib/AST/ASTContext.cpp

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4190,6 +4190,55 @@ void UnboundGenericType::Profile(llvm::FoldingSetNodeID &ID,
41904190
ID.AddPointer(Parent.getPointer());
41914191
}
41924192

4193+
/// The safety of a parent type does not have an impact on a nested type within
4194+
/// it. This produces the recursive properties of a given type that should
4195+
/// be propagated to a nested type, which won't include any "IsUnsafe" bit
4196+
/// determined based on the declaration itself.
4197+
static RecursiveTypeProperties getRecursivePropertiesAsParent(Type type) {
4198+
if (!type)
4199+
return RecursiveTypeProperties();
4200+
4201+
// We only need to do anything interesting at all for unsafe types.
4202+
auto properties = type->getRecursiveProperties();
4203+
if (!properties.isUnsafe())
4204+
return properties;
4205+
4206+
if (auto nominal = type->getAnyNominal()) {
4207+
// If the nominal wasn't itself unsafe, then we got the unsafety from
4208+
// something else (e.g., a generic argument), so it won't change.
4209+
if (nominal->getExplicitSafety() != ExplicitSafety::Unsafe)
4210+
return properties;
4211+
}
4212+
4213+
// Drop the "unsafe" bit. We have to recompute it without considering the
4214+
// enclosing nominal type.
4215+
properties = RecursiveTypeProperties(
4216+
properties.getBits() & ~static_cast<unsigned>(RecursiveTypeProperties::IsUnsafe));
4217+
4218+
// Check generic arguments of parent types.
4219+
while (type) {
4220+
// Merge from the generic arguments.
4221+
if (auto boundGeneric = type->getAs<BoundGenericType>()) {
4222+
for (auto genericArg : boundGeneric->getGenericArgs())
4223+
properties |= genericArg->getRecursiveProperties();
4224+
}
4225+
4226+
if (auto nominalOrBound = type->getAs<NominalOrBoundGenericNominalType>()) {
4227+
type = nominalOrBound->getParent();
4228+
continue;
4229+
}
4230+
4231+
if (auto unbound = type->getAs<UnboundGenericType>()) {
4232+
type = unbound->getParent();
4233+
continue;
4234+
}
4235+
4236+
break;
4237+
};
4238+
4239+
return properties;
4240+
}
4241+
41934242
UnboundGenericType *UnboundGenericType::
41944243
get(GenericTypeDecl *TheDecl, Type Parent, const ASTContext &C) {
41954244
llvm::FoldingSetNodeID ID;
@@ -4198,7 +4247,7 @@ get(GenericTypeDecl *TheDecl, Type Parent, const ASTContext &C) {
41984247
RecursiveTypeProperties properties;
41994248
if (TheDecl->getExplicitSafety() == ExplicitSafety::Unsafe)
42004249
properties |= RecursiveTypeProperties::IsUnsafe;
4201-
if (Parent) properties |= Parent->getRecursiveProperties();
4250+
properties |= getRecursivePropertiesAsParent(Parent);
42024251

42034252
auto arena = getArena(properties);
42044253

@@ -4252,7 +4301,7 @@ BoundGenericType *BoundGenericType::get(NominalTypeDecl *TheDecl,
42524301
RecursiveTypeProperties properties;
42534302
if (TheDecl->getExplicitSafety() == ExplicitSafety::Unsafe)
42544303
properties |= RecursiveTypeProperties::IsUnsafe;
4255-
if (Parent) properties |= Parent->getRecursiveProperties();
4304+
properties |= getRecursivePropertiesAsParent(Parent);
42564305
for (Type Arg : GenericArgs) {
42574306
properties |= Arg->getRecursiveProperties();
42584307
}
@@ -4335,7 +4384,7 @@ EnumType *EnumType::get(EnumDecl *D, Type Parent, const ASTContext &C) {
43354384
RecursiveTypeProperties properties;
43364385
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43374386
properties |= RecursiveTypeProperties::IsUnsafe;
4338-
if (Parent) properties |= Parent->getRecursiveProperties();
4387+
properties |= getRecursivePropertiesAsParent(Parent);
43394388
auto arena = getArena(properties);
43404389

43414390
auto *&known = C.getImpl().getArena(arena).EnumTypes[{D, Parent}];
@@ -4353,7 +4402,7 @@ StructType *StructType::get(StructDecl *D, Type Parent, const ASTContext &C) {
43534402
RecursiveTypeProperties properties;
43544403
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43554404
properties |= RecursiveTypeProperties::IsUnsafe;
4356-
if (Parent) properties |= Parent->getRecursiveProperties();
4405+
properties |= getRecursivePropertiesAsParent(Parent);
43574406
auto arena = getArena(properties);
43584407

43594408
auto *&known = C.getImpl().getArena(arena).StructTypes[{D, Parent}];
@@ -4371,7 +4420,7 @@ ClassType *ClassType::get(ClassDecl *D, Type Parent, const ASTContext &C) {
43714420
RecursiveTypeProperties properties;
43724421
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
43734422
properties |= RecursiveTypeProperties::IsUnsafe;
4374-
if (Parent) properties |= Parent->getRecursiveProperties();
4423+
properties |= getRecursivePropertiesAsParent(Parent);
43754424
auto arena = getArena(properties);
43764425

43774426
auto *&known = C.getImpl().getArena(arena).ClassTypes[{D, Parent}];
@@ -5538,7 +5587,7 @@ ProtocolType *ProtocolType::get(ProtocolDecl *D, Type Parent,
55385587
RecursiveTypeProperties properties;
55395588
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
55405589
properties |= RecursiveTypeProperties::IsUnsafe;
5541-
if (Parent) properties |= Parent->getRecursiveProperties();
5590+
properties |= getRecursivePropertiesAsParent(Parent);
55425591
auto arena = getArena(properties);
55435592

55445593
auto *&known = C.getImpl().getArena(arena).ProtocolTypes[{D, Parent}];

lib/Sema/TypeCheckUnsafe.cpp

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,12 @@ bool swift::enumerateUnsafeUses(ArrayRef<ProtocolConformanceRef> conformances,
330330
bool swift::enumerateUnsafeUses(SubstitutionMap subs,
331331
SourceLoc loc,
332332
llvm::function_ref<bool(UnsafeUse)> fn) {
333-
// FIXME: Check replacement types?
333+
// Replacement types.
334+
for (auto replacementType : subs.getReplacementTypes()) {
335+
if (replacementType->isUnsafe() &&
336+
fn(UnsafeUse::forReferenceToUnsafe(nullptr, false, replacementType, loc)))
337+
return true;
338+
}
334339

335340
// Check conformances.
336341
if (enumerateUnsafeUses(subs.getConformances(), loc, fn))
@@ -379,24 +384,69 @@ void swift::diagnoseUnsafeType(ASTContext &ctx, SourceLoc loc, Type type,
379384
return;
380385

381386
// Look for a specific @unsafe nominal type along the way.
382-
auto findSpecificUnsafeType = [](Type type) {
387+
class Walker : public TypeWalker {
388+
public:
383389
Type specificType;
384-
(void)type.findIf([&specificType](Type type) {
390+
391+
Action walkToTypePre(Type type) override {
392+
if (specificType)
393+
return Action::Stop;
394+
395+
// If this refers to a nominal type that is @unsafe, store that.
385396
if (auto typeDecl = type->getAnyNominal()) {
386397
if (typeDecl->getExplicitSafety() == ExplicitSafety::Unsafe) {
387398
specificType = type;
388-
return false;
399+
return Action::Stop;
389400
}
390401
}
391402

392-
return false;
393-
});
394-
return specificType;
403+
// Do not recurse into nominal types, because we do not want to visit
404+
// their "parent" types.
405+
if (isa<NominalOrBoundGenericNominalType>(type.getPointer()) ||
406+
isa<UnboundGenericType>(type.getPointer())) {
407+
// Recurse into the generic arguments. This operation is recursive,
408+
// because we also need to see the generic arguments of parent types.
409+
walkGenericArguments(type);
410+
411+
return Action::SkipNode;
412+
}
413+
414+
return Action::Continue;
415+
}
416+
417+
private:
418+
/// Recursively walk the generic arguments of this type and its parent
419+
/// types.
420+
void walkGenericArguments(Type type) {
421+
if (!type)
422+
return;
423+
424+
// Walk the generic arguments.
425+
if (auto boundGeneric = type->getAs<BoundGenericType>()) {
426+
for (auto genericArg : boundGeneric->getGenericArgs())
427+
genericArg.walk(*this);
428+
}
429+
430+
if (auto nominalOrBound = type->getAs<NominalOrBoundGenericNominalType>())
431+
return walkGenericArguments(nominalOrBound->getParent());
432+
433+
if (auto unbound = type->getAs<UnboundGenericType>())
434+
return walkGenericArguments(unbound->getParent());
435+
}
395436
};
396437

397-
Type specificType = findSpecificUnsafeType(type);
398-
if (!specificType)
399-
specificType = findSpecificUnsafeType(type->getCanonicalType());
438+
// Look for a canonical unsafe type.
439+
Walker walker;
440+
type->getCanonicalType().walk(walker);
441+
Type specificType = walker.specificType;
442+
443+
// Look for an unsafe type in the non-canonical type, which is a better answer
444+
// if we can find it.
445+
walker.specificType = Type();
446+
type.walk(walker);
447+
if (specificType && walker.specificType &&
448+
specificType->isEqual(walker.specificType))
449+
specificType = walker.specificType;
400450

401451
diagnose(specificType ? specificType : type);
402452
}

test/Unsafe/safe.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,15 @@ struct UnsafeContainingUnspecified {
277277
typealias A = Int
278278

279279
func getA() -> A { 0 }
280+
281+
@safe
282+
struct Y {
283+
var value: Int
284+
}
285+
286+
func f() {
287+
_ = Y(value: 5)
288+
}
280289
}
281290

282291

0 commit comments

Comments
 (0)