@@ -421,3 +421,128 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
421421 }
422422 }
423423}
424+
425+ bool swift::operator ==(const TangentPropertyInfo::Error &lhs,
426+ const TangentPropertyInfo::Error &rhs) {
427+ if (lhs.kind != rhs.kind )
428+ return false ;
429+ switch (lhs.kind ) {
430+ case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
431+ case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
432+ case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
433+ case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
434+ case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
435+ case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
436+ return true ;
437+ case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
438+ return lhs.getType ()->isEqual (rhs.getType ());
439+ }
440+ }
441+
442+ void swift::simple_display (llvm::raw_ostream &os, TangentPropertyInfo info) {
443+ os << " { " ;
444+ os << " tangent property: "
445+ << (info.tangentProperty ? info.tangentProperty ->printRef () : " null" );
446+ if (info.error ) {
447+ os << " , error: " ;
448+ switch (info.error ->kind ) {
449+ case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
450+ os << " '@noDerivative' original property has no tangent property" ;
451+ break ;
452+ case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
453+ os << " nominal parent does not conform to 'Differentiable'" ;
454+ break ;
455+ case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
456+ os << " original property type does not conform to 'Differentiable'" ;
457+ break ;
458+ case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
459+ os << " 'TangentVector' type is not a struct" ;
460+ break ;
461+ case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
462+ os << " 'TangentVector' struct does not have stored property with the "
463+ " same name as the original property" ;
464+ break ;
465+ case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
466+ os << " tangent property's type is not equal to the original property's "
467+ " 'TangentVector' type" ;
468+ break ;
469+ case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
470+ os << " 'TangentVector' property '" << info.tangentProperty ->getName ()
471+ << " ' is not a stored property" ;
472+ break ;
473+ }
474+ }
475+ os << " }" ;
476+ }
477+
478+ TangentPropertyInfo
479+ TangentStoredPropertyRequest::evaluate (Evaluator &evaluator,
480+ VarDecl *originalField) const {
481+ assert (originalField->hasStorage () && originalField->isInstanceMember () &&
482+ " Expected stored property" );
483+ auto *parentDC = originalField->getDeclContext ();
484+ assert (parentDC->isTypeContext ());
485+ auto parentType = parentDC->getDeclaredTypeInContext ();
486+ auto *moduleDecl = originalField->getModuleContext ();
487+ auto parentTan = parentType->getAutoDiffTangentSpace (
488+ LookUpConformanceInModule (moduleDecl));
489+ // Error if parent nominal type does not conform to `Differentiable`.
490+ if (!parentTan) {
491+ return TangentPropertyInfo (
492+ TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable);
493+ }
494+ // Error if original stored property is `@noDerivative`.
495+ if (originalField->getAttrs ().hasAttribute <NoDerivativeAttr>()) {
496+ return TangentPropertyInfo (
497+ TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty);
498+ }
499+ // Error if original property's type does not conform to `Differentiable`.
500+ auto originalFieldTan = originalField->getType ()->getAutoDiffTangentSpace (
501+ LookUpConformanceInModule (moduleDecl));
502+ if (!originalFieldTan) {
503+ return TangentPropertyInfo (
504+ TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable);
505+ }
506+ auto parentTanType = parentTan->getType ();
507+ auto *parentTanStruct = parentTanType->getStructOrBoundGenericStruct ();
508+ // Error if parent `TangentVector` is not a struct.
509+ if (!parentTanStruct) {
510+ return TangentPropertyInfo (
511+ TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct);
512+ }
513+ // Find the corresponding field in the tangent space.
514+ VarDecl *tanField = nullptr ;
515+ // If `TangentVector` is the original struct, then the tangent property is the
516+ // original property.
517+ if (parentTanStruct == parentDC->getSelfStructDecl ()) {
518+ tanField = originalField;
519+ }
520+ // Otherwise, look up the field by name.
521+ else {
522+ auto tanFieldLookup =
523+ parentTanStruct->lookupDirect (originalField->getName ());
524+ llvm::erase_if (tanFieldLookup,
525+ [](ValueDecl *v) { return !isa<VarDecl>(v); });
526+ // Error if tangent property could not be found.
527+ if (tanFieldLookup.empty ()) {
528+ return TangentPropertyInfo (
529+ TangentPropertyInfo::Error::Kind::TangentPropertyNotFound);
530+ }
531+ tanField = cast<VarDecl>(tanFieldLookup.front ());
532+ }
533+ // Error if tangent property's type is not equal to the original property's
534+ // `TangentVector` type.
535+ auto originalFieldTanType = originalFieldTan->getType ();
536+ if (!originalFieldTanType->isEqual (tanField->getType ())) {
537+ return TangentPropertyInfo (
538+ TangentPropertyInfo::Error::Kind::TangentPropertyWrongType,
539+ originalFieldTanType);
540+ }
541+ // Error if tangent property is not a stored property.
542+ if (!tanField->hasStorage ()) {
543+ return TangentPropertyInfo (
544+ TangentPropertyInfo::Error::Kind::TangentPropertyNotStored);
545+ }
546+ // Otherwise, tangent property is valid.
547+ return TangentPropertyInfo (tanField);
548+ }
0 commit comments