@@ -1297,7 +1297,7 @@ extension Tensor {
12971297 }
12981298
12991299 @inlinable
1300- @differentiable ( reverse, wrt: self where Scalar: TensorFlowFloatingPoint)
1300+ // @differentiable(reverse, wrt: self where Scalar: TensorFlowFloatingPoint)
13011301 internal subscript( _ indexPath: IndexPath ) -> Tensor {
13021302 get {
13031303 let device = self . device
@@ -1323,7 +1323,7 @@ extension Tensor {
13231323 }
13241324
13251325 @inlinable
1326- @differentiable ( reverse, wrt: self where Scalar: TensorFlowFloatingPoint)
1326+ // @differentiable(reverse, wrt: self where Scalar: TensorFlowFloatingPoint)
13271327 public subscript( _ ranges: TensorRangeExpression ... ) -> Tensor {
13281328 get {
13291329 return self [ { IndexPath ( { ranges. map { $0. tensorRange } } ( ) ) } ( ) ]
@@ -1334,27 +1334,27 @@ extension Tensor {
13341334 }
13351335}
13361336
1337- extension Tensor where Scalar : TensorFlowFloatingPoint {
1338- @usableFromInline
1339- @derivative ( of: subscript)
1340- internal func _vjpSubscript(
1341- _ indexPath: IndexPath
1342- ) -> ( value: Tensor , pullback: ( Tensor ) -> Tensor ) {
1343- return (
1344- self [ indexPath] ,
1345- { [ shape = shapeTensor] v in
1346- _Raw. stridedSliceGrad (
1347- shape: shape, begin: Tensor < Int32 > ( indexPath. begin, on: device) ,
1348- end: Tensor < Int32 > ( indexPath. end, on: device) ,
1349- strides: Tensor < Int32 > ( indexPath. strides, on: device) , dy: v,
1350- beginMask: indexPath. beginMask,
1351- endMask: indexPath. endMask, ellipsisMask: indexPath. ellipsisMask,
1352- newAxisMask: indexPath. newAxisMask,
1353- shrinkAxisMask: indexPath. squeezeAxisMask)
1354- }
1355- )
1356- }
1357- }
1337+ // extension Tensor {
1338+ // @usableFromInline
1339+ // @derivative(of: subscript)
1340+ // internal func _vjpSubscript(
1341+ // _ indexPath: IndexPath
1342+ // ) -> (value: Tensor, pullback: (Tensor) -> Tensor) {
1343+ // return (
1344+ // self[indexPath],
1345+ // { [shape = shapeTensor] v in
1346+ // _Raw.stridedSliceGrad(
1347+ // shape: shape, begin: Tensor<Int32>(indexPath.begin, on: device),
1348+ // end: Tensor<Int32>(indexPath.end, on: device),
1349+ // strides: Tensor<Int32>(indexPath.strides, on: device), dy: v,
1350+ // beginMask: indexPath.beginMask,
1351+ // endMask: indexPath.endMask, ellipsisMask: indexPath.ellipsisMask,
1352+ // newAxisMask: indexPath.newAxisMask,
1353+ // shrinkAxisMask: indexPath.squeezeAxisMask)
1354+ // }
1355+ // )
1356+ // }
1357+ // }
13581358
13591359extension Tensor . IndexPath {
13601360 @inlinable
0 commit comments