@@ -338,3 +338,65 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
338338 }
339339 }
340340}
341+
342+ #if TENSORFLOW_USE_STANDARD_TOOLCHAIN
343+ @_spi ( Reflection) import Swift
344+
345+ func reflectionInit< T> ( type: T . Type , body: ( inout T , PartialKeyPath < T > ) -> Void ) -> T {
346+ let x = UnsafeMutablePointer< T> . allocate( capacity: 1 )
347+ defer { x. deallocate ( ) }
348+ if !_forEachFieldWithKeyPath( of: type) { name, kp in
349+ body ( & x. pointee, kp)
350+ return true
351+ } {
352+ fatalError ( " Cannot initialize \( T . self) because of unknown fields. " )
353+ }
354+ return x. move ( )
355+ }
356+
357+ extension TensorGroup {
358+ public static var _typeList : [ TensorDataType ] {
359+ var out = [ TensorDataType] ( )
360+ if !( _forEachFieldWithKeyPath ( of: Self . self) { name, kp in
361+ guard let valueType = type ( of: kp) . valueType as? TensorGroup . Type else { return false }
362+ out += valueType. _typeList
363+ return true
364+ } ) {
365+ fatalError ( " \( Self . self) does not have children that conform to TensorGroup. " )
366+ }
367+ return out
368+ }
369+ public static func initialize< Root> (
370+ _ base: inout Root , _ kp: PartialKeyPath < Root > ,
371+ _owning tensorHandles: UnsafePointer < CTensorHandle > ?
372+ ) {
373+ guard let kp = kp as? WritableKeyPath < Root , Self > else {
374+ fatalError ( " \( kp) is not \( WritableKeyPath < Root , Self > . self) " )
375+ }
376+ withUnsafeMutablePointer ( to: & base[ keyPath: kp] ) { v in
377+ v. initialize ( to: . init( _owning: tensorHandles) )
378+ }
379+ }
380+ public init ( _owning tensorHandles: UnsafePointer < CTensorHandle > ? ) {
381+ var i = 0
382+ self = reflectionInit ( type: Self . self) { base, kp in
383+ guard let valueType = type ( of: kp) . valueType as? TensorGroup . Type else {
384+ fatalError ( " \( type ( of: kp) . valueType) does not conform to TensorGroup " )
385+ }
386+ valueType. initialize ( & base, kp, _owning: tensorHandles? . advanced ( by: i) )
387+ i += Int ( valueType. _tensorHandleCount)
388+ }
389+ }
390+ public func _unpackTensorHandles( into address: UnsafeMutablePointer < CTensorHandle > ? ) {
391+ var i = 0
392+ if !_forEachFieldWithKeyPath( of: Self . self) { name, kp in
393+ guard let x = self [ keyPath: kp] as? TensorGroup else { return false }
394+ x. _unpackTensorHandles ( into: address? . advanced ( by: i) )
395+ i += Int ( type ( of: x) . _tensorHandleCount)
396+ return true
397+ } {
398+ fatalError ( " Cannot unpack \( Self . self) because of non-TensorGroup fields. " )
399+ }
400+ }
401+ }
402+ #endif
0 commit comments