This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit ce67f8d

Remove explicit differentiation parameters. They are no longer required! (#33)
We've recently changed the type checker to improve the ergonomics of the `@differentiable` attribute.

* swiftlang/swift#22915: Explicit differentiation parameters are no longer required in a `@differentiable` attribute when the function has arguments that do not conform to `Differentiable`. Non-differentiable parameters are skipped, and the function is differentiated with respect to the remaining parameters.
* swiftlang/swift#22877: On an instance method, when `wrt:` is not specified, `self` is implicitly included as a differentiation parameter.
* swiftlang/swift#22877: When a `@differentiable` requirement is not met, the fix-it now shows the attribute exactly as written in the original declaration rather than in its most verbose, canonical form: for instance, `@differentiable` instead of `@differentiable(wrt: (x))`.

This greatly simplifies libraries and applications that use automatic differentiation. The protocol requirement `Layer.applied(to:in:)` becomes as simple as this:

```swift
@differentiable
func applied(to input: Input, in context: Context) -> Output
```

This PR updates the deep learning APIs to use the simplest form of `@differentiable` possible. Hooray!
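Concretely, the attribute shrinks as shown below. This before/after is excerpted from the README change in this commit; the `wrt:` parameter list is now inferred rather than spelled out.

```swift
// Before: every differentiation parameter had to be listed explicitly.
@differentiable(wrt: (self, input))
func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float>

// After: `context` is skipped because `Context` does not conform to
// `Differentiable`, and `self` is implicitly a differentiation parameter.
@differentiable
func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float>
```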
1 parent 3e8d86e commit ce67f8d

File tree: 4 files changed (+22, -22 lines)

README.md
Sources/DeepLearning/Layer.swift
Tests/DeepLearningTests/SequentialTests.swift
Tests/DeepLearningTests/TrivialModelTests.swift


README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -28,7 +28,7 @@ struct Model: Layer {
     var layer2 = Dense<Float>(inputSize: hiddenSize, outputSize: hiddenSize, activation: relu)
     var layer3 = Dense<Float>(inputSize: hiddenSize, outputSize: 3, activation: identity)
 
-    @differentiable(wrt: (self, input))
+    @differentiable
     func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
         return input.sequenced(in: context, through: layer1, layer2, layer3)
     }
```

Sources/DeepLearning/Layer.swift

Lines changed: 19 additions & 19 deletions
```diff
@@ -63,7 +63,7 @@ public protocol Layer: Differentiable & KeyPathIterable
     /// - context: The contextual informance for the layer application, e.g. the current learning
     ///   phase.
     /// - Returns: The output.
-    @differentiable(wrt: (self, input))
+    @differentiable
     func applied(to input: Input, in context: Context) -> Output
 }
 
@@ -78,7 +78,7 @@ public extension Layer {
     ///
     /// - Parameter input: The input to the layer.
     /// - Returns: The inference output.
-    @differentiable(wrt: (self, input))
+    @differentiable
     func inferring(from input: Input) -> Output {
         let context = Context(learningPhase: .inference)
         return applied(to: input, in: context)
@@ -104,7 +104,7 @@ public extension Layer {
 
 /// Adds helpers for standard feed-forward, sequential models.
 public extension Differentiable {
-    @differentiable(wrt: (self, l1, l2))
+    @differentiable
     func sequenced<L1: Layer, L2: Layer>(
         in context: Context, through l1: L1, _ l2: L2)
         -> L2.Output
@@ -114,7 +114,7 @@ public extension Differentiable {
         return l2.applied(to: o1, in: context)
     }
 
-    @differentiable(wrt: (self, l1, l2, l3))
+    @differentiable
     func sequenced<L1: Layer, L2: Layer, L3: Layer>(
         in context: Context, through l1: L1, _ l2: L2, _ l3: L3)
         -> L3.Output
@@ -126,7 +126,7 @@ public extension Differentiable {
         return l3.applied(to: o2, in: context)
     }
 
-    @differentiable(wrt: (self, l1, l2, l3, l4))
+    @differentiable
     func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer>(
         in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4)
         -> L4.Output
@@ -140,7 +140,7 @@ public extension Differentiable {
         return l4.applied(to: o3, in: context)
     }
 
-    @differentiable(wrt: (self, l1, l2, l3, l4, l5))
+    @differentiable
     func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer>(
         in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5)
         -> L5.Output
@@ -156,7 +156,7 @@ public extension Differentiable {
         return l5.applied(to: o4, in: context)
     }
 
-    @differentiable(wrt: (self, l1, l2, l3, l4, l5, l6))
+    @differentiable
     func sequenced<L1: Layer, L2: Layer, L3: Layer, L4: Layer, L5: Layer, L6: Layer>(
         in context: Context, through l1: L1, _ l2: L2, _ l3: L3, _ l4: L4, _ l5: L5, _ l6: L6)
         -> L6.Output
@@ -196,7 +196,7 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
     @noDerivative public let activation: Activation
 
-    @differentiable(wrt: (self, input))
+    @differentiable
     public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
         return activation(matmul(input, weight) + bias)
     }
@@ -230,7 +230,7 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
     @noDerivative public let strides: (Int32, Int32)
     @noDerivative public let padding: Padding
 
-    @differentiable(wrt: (self, input))
+    @differentiable
     public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
         return activation(input.convolved2D(withFilter: filter,
                                             strides: (1, strides.0, strides.1, 1),
@@ -286,7 +286,7 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The running variance.
     @noDerivative public let runningVariance: Parameter<Scalar>
 
-    @differentiable(wrt: (self, input))
+    @differentiable
     private func applyingTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         let positiveAxis = (input.rank + axis) % input.rank
         let mean = input.mean(alongAxes: [0, positiveAxis])
@@ -298,13 +298,13 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
         return (input - mean) * inv + offset
     }
 
-    @differentiable(wrt: (self, input))
+    @differentiable
     private func applyingInference(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         let inv = rsqrt(runningVariance.value + epsilon) * scale
         return (input - runningMean.value) * inv + offset
     }
 
-    @differentiable(wrt: (self, input), vjp: _vjpApplied(to:in:))
+    @differentiable(vjp: _vjpApplied(to:in:))
     public func applied(to input: Tensor<Scalar>, in context: Context) -> Tensor<Scalar> {
         switch context.learningPhase {
         case .training:
@@ -360,7 +360,7 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
         self.padding = padding
     }
 
-    @differentiable(wrt: (self, input))
+    @differentiable
     public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
         return input.maxPooled(
             kernelSize: poolSize, strides: strides, padding: padding)
@@ -383,7 +383,7 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
         self.padding = padding
     }
 
-    @differentiable(wrt: (self, input))
+    @differentiable
    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
         return input.averagePooled(
             kernelSize: poolSize, strides: strides, padding: padding)
@@ -410,7 +410,7 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
         self.epsilon = epsilon
     }
 
-    @differentiable(wrt: (self, input))
+    @differentiable
     public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
         let mean = input.mean(alongAxes: axis)
         let variance = input.variance(alongAxes: axis)
@@ -439,17 +439,17 @@ public struct Dropout<Scalar: TensorFlowFloatingPoint>: Layer
         self.probability = probability
     }
 
-    @differentiable(wrt: (self, input))
+    @differentiable
     private func applyingTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input.droppingOut(probability: probability)
     }
 
-    @differentiable(wrt: (self, input))
+    @differentiable
     private func applyingInference(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         return input
     }
 
-    @differentiable(wrt: (self, input), vjp: _vjpApplied(to:in:))
+    @differentiable(vjp: _vjpApplied(to:in:))
     public func applied(to input: Tensor<Scalar>, in context: Context) -> Tensor<Scalar> {
         switch context.learningPhase {
         case .training:
@@ -484,7 +484,7 @@ public struct UpSampling2D<Scalar: TensorFlowFloatingPoint>: Layer {
         self.size = size
     }
 
-    @differentiable(wrt: (self, input))
+    @differentiable
     public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
         let shape = input.shape
         let (batchSize, height, width, channels) = (shape[0], shape[1], shape[2], shape[3])
```
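For reference, call sites are untouched by this commit; only the declarations above lose their `wrt:` lists. Below is a minimal usage sketch against the updated `Layer` API. The import names, the layer sizes, and the zero-filled input are assumptions chosen for illustration; the `Dense` initializer, `Context(learningPhase:)`, `applied(to:in:)`, and `inferring(from:)` all appear in the diff above.

```swift
import TensorFlow
import DeepLearning  // assumed module name for this package

// Build a layer and a placeholder input (values are arbitrary).
let dense = Dense<Float>(inputSize: 4, outputSize: 1, activation: relu)
let input = Tensor<Float>(zeros: [1, 4])

// `applied(to:in:)` is still the differentiable entry point; its declaration
// simply no longer needs an explicit `wrt:` list.
let context = Context(learningPhase: .inference)
let output = dense.applied(to: input, in: context)

// `inferring(from:)` (declared above) constructs the inference context itself.
let sameOutput = dense.inferring(from: input)
```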

Tests/DeepLearningTests/SequentialTests.swift

Lines changed: 1 addition & 1 deletion
```diff
@@ -21,7 +21,7 @@ final class SequentialTests: XCTestCase {
             var dense1 = Dense<Float>(inputSize: 2, outputSize: 4, activation: relu)
             var dense2 = Dense<Float>(inputSize: 4, outputSize: 1, activation: relu)
 
-            @differentiable(wrt: (self, input))
+            @differentiable
             func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
                 return input.sequenced(in: context, through: dense1, dense2)
             }
```

Tests/DeepLearningTests/TrivialModelTests.swift

Lines changed: 1 addition & 1 deletion
```diff
@@ -34,7 +34,7 @@ final class TrivialModelTests: XCTestCase {
                     generator: &Classifier.generator
                 )
             }
-            @differentiable(wrt: (self, input))
+            @differentiable
             func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
                 let h1 = l1.applied(to: input, in: context)
                 return l2.applied(to: h1, in: context)
```
