@@ -102,6 +102,80 @@ public extension Layer {
102102 }
103103}
104104
105+ /// Adds helpers for standard feed-forward, sequential models.
106+ public extension Differentiable {
107+ @differentiable ( wrt: ( self , l1, l2) )
108+ func sequenced< L1: Layer , L2: Layer > (
109+ in context: Context , through l1: L1 , _ l2: L2 )
110+ -> L2 . Output
111+ where L1. Input == Self ,
112+ L1. Output == L2 . Input {
113+ let o1 = l1. applied ( to: self , in: context)
114+ return l2. applied ( to: o1, in: context)
115+ }
116+
117+ @differentiable ( wrt: ( self , l1, l2, l3) )
118+ func sequenced< L1: Layer , L2: Layer , L3: Layer > (
119+ in context: Context , through l1: L1 , _ l2: L2 , _ l3: L3 )
120+ -> L3 . Output
121+ where L1. Input == Self ,
122+ L1. Output == L2 . Input ,
123+ L2. Output == L3 . Input {
124+ let o1 = l1. applied ( to: self , in: context)
125+ let o2 = l2. applied ( to: o1, in: context)
126+ return l3. applied ( to: o2, in: context)
127+ }
128+
129+ @differentiable ( wrt: ( self , l1, l2, l3, l4) )
130+ func sequenced< L1: Layer , L2: Layer , L3: Layer , L4: Layer > (
131+ in context: Context , through l1: L1 , _ l2: L2 , _ l3: L3 , _ l4: L4 )
132+ -> L4 . Output
133+ where L1. Input == Self ,
134+ L1. Output == L2 . Input ,
135+ L2. Output == L3 . Input ,
136+ L3. Output == L4 . Input {
137+ let o1 = l1. applied ( to: self , in: context)
138+ let o2 = l2. applied ( to: o1, in: context)
139+ let o3 = l3. applied ( to: o2, in: context)
140+ return l4. applied ( to: o3, in: context)
141+ }
142+
143+ @differentiable ( wrt: ( self , l1, l2, l3, l4, l5) )
144+ func sequenced< L1: Layer , L2: Layer , L3: Layer , L4: Layer , L5: Layer > (
145+ in context: Context , through l1: L1 , _ l2: L2 , _ l3: L3 , _ l4: L4 , _ l5: L5 )
146+ -> L5 . Output
147+ where L1. Input == Self ,
148+ L1. Output == L2 . Input ,
149+ L2. Output == L3 . Input ,
150+ L3. Output == L4 . Input ,
151+ L4. Output == L5 . Input {
152+ let o1 = l1. applied ( to: self , in: context)
153+ let o2 = l2. applied ( to: o1, in: context)
154+ let o3 = l3. applied ( to: o2, in: context)
155+ let o4 = l4. applied ( to: o3, in: context)
156+ return l5. applied ( to: o4, in: context)
157+ }
158+
159+ @differentiable ( wrt: ( self , l1, l2, l3, l4, l5, l6) )
160+ func sequenced< L1: Layer , L2: Layer , L3: Layer , L4: Layer , L5: Layer , L6: Layer > (
161+ in context: Context , through l1: L1 , _ l2: L2 , _ l3: L3 , _ l4: L4 , _ l5: L5 , _ l6: L6 )
162+ -> L6 . Output
163+ where L1. Input == Self ,
164+ L1. Output == L2 . Input ,
165+ L2. Output == L3 . Input ,
166+ L3. Output == L4 . Input ,
167+ L4. Output == L5 . Input ,
168+ L5. Output == L6 . Input {
169+ let o1 = l1. applied ( to: self , in: context)
170+ let o2 = l2. applied ( to: o1, in: context)
171+ let o3 = l3. applied ( to: o2, in: context)
172+ let o4 = l4. applied ( to: o3, in: context)
173+ let o5 = l5. applied ( to: o4, in: context)
174+ return l6. applied ( to: o5, in: context)
175+ }
176+ }
177+
178+
105179/// A mutable, shareable, owning reference to a tensor.
106180public final class Parameter < Scalar: TensorFlowScalar > {
107181 public var value : Tensor < Scalar >
@@ -137,7 +211,7 @@ public extension Dense where Scalar.RawSignificand: FixedWidthInteger {
137211 init < G: RandomNumberGenerator > (
138212 inputSize: Int ,
139213 outputSize: Int ,
140- activation: @escaping Activation ,
214+ activation: @escaping Activation = identity ,
141215 generator: inout G
142216 ) {
143217 self . init ( weight: Tensor ( glorotUniform: [ Int32 ( inputSize) , Int32 ( outputSize) ] ,
@@ -146,7 +220,7 @@ public extension Dense where Scalar.RawSignificand: FixedWidthInteger {
146220 activation: activation)
147221 }
148222
149- init ( inputSize: Int , outputSize: Int , activation: @escaping Activation ) {
223+ init ( inputSize: Int , outputSize: Int , activation: @escaping Activation = identity ) {
150224 self . init ( inputSize: inputSize, outputSize: outputSize, activation: activation,
151225 generator: & PhiloxRandomNumberGenerator. global)
152226 }
0 commit comments