@@ -210,6 +210,115 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType,
210210 result.push_back ({req, loc, /* wasInferred=*/ false });
211211}
212212
213+ namespace {
214+
215+ // / AST walker that infers requirements from type representations.
216+ struct InferRequirementsWalker : public TypeWalker {
217+ ModuleDecl *module ;
218+ SmallVector<Requirement, 2 > reqs;
219+
220+ explicit InferRequirementsWalker (ModuleDecl *module ) : module(module ) {}
221+
222+ Action walkToTypePre (Type ty) override {
223+ // Unbound generic types are the result of recovered-but-invalid code, and
224+ // don't have enough info to do any useful substitutions.
225+ if (ty->is <UnboundGenericType>())
226+ return Action::Stop;
227+
228+ return Action::Continue;
229+ }
230+
231+ Action walkToTypePost (Type ty) override {
232+ // Infer from generic typealiases.
233+ if (auto typeAlias = dyn_cast<TypeAliasType>(ty.getPointer ())) {
234+ auto decl = typeAlias->getDecl ();
235+ auto subMap = typeAlias->getSubstitutionMap ();
236+ for (const auto &rawReq : decl->getGenericSignature ().getRequirements ()) {
237+ if (auto req = rawReq.subst (subMap))
238+ desugarRequirement (*req, reqs);
239+ }
240+
241+ return Action::Continue;
242+ }
243+
244+ // Infer requirements from `@differentiable` function types.
245+ // For all non-`@noDerivative` parameter and result types:
246+ // - `@differentiable`, `@differentiable(_forward)`, or
247+ // `@differentiable(reverse)`: add `T: Differentiable` requirement.
248+ // - `@differentiable(_linear)`: add
249+ // `T: Differentiable`, `T == T.TangentVector` requirements.
250+ if (auto *fnTy = ty->getAs <AnyFunctionType>()) {
251+ auto &ctx = module ->getASTContext ();
252+ auto *differentiableProtocol =
253+ ctx.getProtocol (KnownProtocolKind::Differentiable);
254+ if (differentiableProtocol && fnTy->isDifferentiable ()) {
255+ auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) {
256+ Requirement req (RequirementKind::Conformance, type,
257+ protocol->getDeclaredInterfaceType ());
258+ desugarRequirement (req, reqs);
259+ };
260+ auto addSameTypeConstraint = [&](Type firstType,
261+ AssociatedTypeDecl *assocType) {
262+ auto *protocol = assocType->getProtocol ();
263+ auto *module = protocol->getParentModule ();
264+ auto conf = module ->lookupConformance (firstType, protocol);
265+ auto secondType = conf.getAssociatedType (
266+ firstType, assocType->getDeclaredInterfaceType ());
267+ Requirement req (RequirementKind::SameType, firstType, secondType);
268+ desugarRequirement (req, reqs);
269+ };
270+ auto *tangentVectorAssocType =
271+ differentiableProtocol->getAssociatedType (ctx.Id_TangentVector );
272+ auto addRequirements = [&](Type type, bool isLinear) {
273+ addConformanceConstraint (type, differentiableProtocol);
274+ if (isLinear)
275+ addSameTypeConstraint (type, tangentVectorAssocType);
276+ };
277+ auto constrainParametersAndResult = [&](bool isLinear) {
278+ for (auto ¶m : fnTy->getParams ())
279+ if (!param.isNoDerivative ())
280+ addRequirements (param.getPlainType (), isLinear);
281+ addRequirements (fnTy->getResult (), isLinear);
282+ };
283+ // Add requirements.
284+ constrainParametersAndResult (fnTy->getDifferentiabilityKind () ==
285+ DifferentiabilityKind::Linear);
286+ }
287+ }
288+
289+ if (!ty->isSpecialized ())
290+ return Action::Continue;
291+
292+ // Infer from generic nominal types.
293+ auto decl = ty->getAnyNominal ();
294+ if (!decl) return Action::Continue;
295+
296+ // FIXME: The GSB and the request evaluator both detect a cycle here if we
297+ // force a recursive generic signature. We should look into moving cycle
298+ // detection into the generic signature request(s) - see rdar://55263708
299+ if (!decl->hasComputedGenericSignature ())
300+ return Action::Continue;
301+
302+ auto genericSig = decl->getGenericSignature ();
303+ if (!genericSig)
304+ return Action::Continue;
305+
306+ // / Retrieve the substitution.
307+ auto subMap = ty->getContextSubstitutionMap (module , decl);
308+
309+ // Handle the requirements.
310+ // FIXME: Inaccurate TypeReprs.
311+ for (const auto &rawReq : genericSig.getRequirements ()) {
312+ if (auto req = rawReq.subst (subMap))
313+ desugarRequirement (*req, reqs);
314+ }
315+
316+ return Action::Continue;
317+ }
318+ };
319+
320+ }
321+
213322// / Infer requirements from applications of BoundGenericTypes to type
214323// / parameters. For example, given a function declaration
215324// /
@@ -220,7 +329,14 @@ static void realizeTypeRequirement(Type subjectType, Type constraintType,
220329void swift::rewriting::inferRequirements (
221330 Type type, SourceLoc loc, ModuleDecl *module ,
222331 SmallVectorImpl<StructuralRequirement> &result) {
223- // FIXME: Implement
332+ if (!type)
333+ return ;
334+
335+ InferRequirementsWalker walker (module );
336+ type.walk (walker);
337+
338+ for (const auto &req : walker.reqs )
339+ result.push_back ({req, loc, /* wasInferred=*/ true });
224340}
225341
226342// / Desugar a requirement and perform requirement inference if requested
0 commit comments