
Commit c7e682b

Merge remote-tracking branch 'upstream/main'
2 parents: a4d5cc8 + ed8b340

26 files changed: +2230, -41 lines

CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -20,6 +20,7 @@ add_library(neural-fortran
   src/nf/nf_base_layer.f90
   src/nf/nf_conv2d_layer.f90
   src/nf/nf_conv2d_layer_submodule.f90
+  src/nf/nf_cross_attention_layer.f90
   src/nf/nf_datasets.f90
   src/nf/nf_datasets_submodule.f90
   src/nf/nf_datasets_mnist.f90
@@ -40,13 +41,17 @@ add_library(neural-fortran
   src/nf/nf_layer_submodule.f90
   src/nf/nf_locally_connected_1d_submodule.f90
   src/nf/nf_locally_connected_1d.f90
+  src/nf/nf_linear2d_layer.f90
+  src/nf/nf_linear2d_layer_submodule.f90
   src/nf/nf_loss.f90
   src/nf/nf_loss_submodule.f90
   src/nf/nf_maxpool1d_layer.f90
   src/nf/nf_maxpool1d_layer_submodule.f90
   src/nf/nf_maxpool2d_layer.f90
   src/nf/nf_maxpool2d_layer_submodule.f90
   src/nf/nf_metrics.f90
+  src/nf/nf_multihead_attention.f90
+  src/nf/nf_multihead_attention_submodule.f90
   src/nf/nf_network.f90
   src/nf/nf_network_submodule.f90
   src/nf/nf_optimizers.f90
@@ -57,8 +62,11 @@ add_library(neural-fortran
   src/nf/nf_reshape_layer_submodule.f90
   src/nf/nf_reshape2d_layer.f90
   src/nf/nf_reshape2d_layer_submodule.f90
+  src/nf/nf_self_attention_layer.f90
   src/nf/io/nf_io_binary.f90
   src/nf/io/nf_io_binary_submodule.f90
+  src/nf/nf_dropout_layer.f90
+  src/nf/nf_dropout_layer_submodule.f90
 )
 
 target_link_libraries(neural-fortran PRIVATE)

README.md

Lines changed: 4 additions & 1 deletion
@@ -30,9 +30,12 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
 | Layer type | Constructor name | Supported input layers | Rank of output array | Forward pass | Backward pass |
 |------------|------------------|------------------------|----------------------|--------------|---------------|
 | Input | `input` | n/a | 1, 2, 3 | n/a | n/a |
-| Dense (fully-connected) | `dense` | `input1d`, `flatten` | 1 | ✅ | ✅ |
+| Dense (fully-connected) | `dense` | `input1d`, `dense`, `dropout`, `flatten` | 1 | ✅ | ✅ |
+| Dropout | `dropout` | `dense`, `flatten`, `input1d` | 1 | ✅ | ✅ |
 | Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) |
 | Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ |
+| Linear (2-d) | `linear2d` | `input2d`, `linear2d`, `self_attention` | 2 | ✅ | ✅ |
+| Self-attention | `self_attention` | `input2d`, `linear2d`, `self_attention` | 2 | ✅ | ✅ |
 | Flatten | `flatten` | `input2d`, `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ |
 | Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 | ✅ | ✅ |
 
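A minimal sketch (not part of this commit) of how the new constructors from the table compose into a network; it assumes only the `nf` user API exercised by the examples in this commit (`dropout`, `self_attention`, `flatten`, `relu`, `softmax`):

```fortran
program new_layers_sketch
  ! Hypothetical composition of the layers added in this merge:
  ! a rank-2 input feeds self-attention, which is flattened into a
  ! dense head with dropout regularization. Shapes are illustrative only.
  use nf, only: network, input, self_attention, flatten, dense, dropout, relu, softmax
  implicit none
  type(network) :: net

  net = network([ &
    input(3, 8), &          ! sequence of 3 steps, 8 features each (rank-2 input)
    self_attention(4), &    ! 4 attention heads, rank-2 output
    flatten(), &            ! collapse to rank 1 for the dense head
    dense(32, relu()), &
    dropout(0.2), &         ! drop roughly 20% of activations during training
    dense(10, softmax()) &
  ])

  call net % print_info()
end program new_layers_sketch
```

The `self_attention` output keeps rank 2, so `flatten()` is needed before the dense head, as the table's "Supported input layers" column indicates.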

example/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ foreach(execid
   simple
   sine
   quadratic
+  mha_simple
 )
   add_executable(${execid} ${execid}.f90)
   target_link_libraries(${execid} PRIVATE

example/dense_mnist.f90

Lines changed: 5 additions & 4 deletions
@@ -1,6 +1,6 @@
 program dense_mnist
 
-  use nf, only: dense, input, network, sgd, label_digits, load_mnist, corr
+  use nf, only: dense, input, network, sgd, label_digits, load_mnist, corr, relu, softmax, dropout
 
   implicit none
 
@@ -17,8 +17,9 @@ program dense_mnist
 
   net = network([ &
     input(784), &
-    dense(30), &
-    dense(10) &
+    dense(64, relu()), &
+    dropout(0.2), &
+    dense(10, softmax()) &
   ])
   num_epochs = 10
 
@@ -32,7 +33,7 @@ program dense_mnist
     call net % train( &
       training_images, &
      label_digits(training_labels), &
-      batch_size=100, &
+      batch_size=128, &
       epochs=1, &
       optimizer=sgd(learning_rate=3.) &
     )

example/mha_simple.f90

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+program mha_simple
+  use nf, only: dense, input, network, sgd, self_attention, flatten
+  implicit none
+  type(network) :: net
+  real, allocatable :: x(:, :), y(:)
+  integer, parameter :: num_iterations = 500
+  integer :: n
+
+  print '("Simple")'
+  print '(60("="))'
+
+  net = network([ &
+    input(3, 8), &
+    self_attention(4), &
+    flatten(), &
+    dense(2) &
+  ])
+
+  call net % print_info()
+
+  allocate(x(3, 8))
+  call random_number(x)
+
+  y = [0.123456, 0.246802]
+
+  do n = 0, num_iterations
+
+    call net % forward(x)
+    call net % backward(y)
+    call net % update(optimizer=sgd(learning_rate=1.))
+
+    if (mod(n, 50) == 0) &
+      print '(i4,2(3x,f8.6))', n, net % predict(x)
+
+  end do
+
+end program mha_simple

src/nf.f90

Lines changed: 11 additions & 0 deletions
@@ -4,6 +4,15 @@ module nf
   use nf_layer, only: layer
   use nf_layer_constructors, only: &
     conv2d, dense, flatten, input, maxpool1d, maxpool2d, reshape, reshape2d, locally_connected_1d
+    conv2d, &
+    dense, &
+    dropout, &
+    flatten, &
+    input, &
+    linear2d, &
+    maxpool2d, &
+    reshape, &
+    self_attention
   use nf_loss, only: mse, quadratic
   use nf_metrics, only: corr, maxabs
   use nf_network, only: network
@@ -12,4 +21,6 @@ module nf
     gaussian, linear, relu, leaky_relu, &
     sigmoid, softmax, softplus, step, tanhf, &
     celu
+  use nf_linear2d_layer, only: linear2d_layer
+  use nf_multihead_attention_layer, only: multihead_attention_layer
 end module nf

src/nf/nf_cross_attention_layer.f90

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+module nf_cross_attention_layer
+  use iso_fortran_env, only: stderr => error_unit
+  use nf_activation, only: softmax
+  use nf_linear2d_layer, only: linear2d_layer
+  use nf_multihead_attention_layer, only: multihead_attention_layer
+
+  implicit none
+
+  type, extends(multihead_attention_layer) :: cross_attention_layer
+    !! Cross Attention Layer
+    !! Source:
+    !! Bahdanau, D. (2014)
+    !! Neural machine translation by jointly learning to align and translate.
+    !! https://arxiv.org/pdf/1409.0473
+    real, allocatable :: gradient(:, :, :)
+  contains
+    procedure :: forward
+    procedure :: backward
+    procedure :: init
+  end type cross_attention_layer
+
+  interface cross_attention_layer
+    module function cross_attention_layer_cons(n_heads) result(res)
+      !! This function returns the `cross_attention_layer` instance.
+      integer, intent(in) :: sequence_length, model_dimension, n_heads
+      type(cross_attention_layer) :: res
+    end function cross_attention_layer_cons
+  end interface cross_attention_layer
+
+contains
+  module function cross_attention_layer_cons(n_heads) result(res)
+    !! This function returns the `cross_attention_layer` instance.
+    integer, intent(in) :: n_heads
+    type(cross_attention_layer) :: res
+    res % n_heads = n_heads
+  end function cross_attention_layer_cons
+
+  pure module subroutine backward(self, input, gradient)
+    !! Cross Attention Back propagation
+    class(cross_attention_layer), intent(in out) :: self
+    real, intent(in) :: input(:, :, :)
+    real, intent(in) :: gradient(:, :)
+
+    call self % common_backward(input(1, :, :), gradient)
+    self % gradient(1, :, :) = self % query_layer % gradient
+    self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient
+  end subroutine backward
+
+  pure module subroutine forward(self, input)
+    !! Cross Attention Forward propagation
+    !! Input Shape (kind, sequence_length, model_dimension)
+    !! where kind is 1 for Query and 2 for Key-Value
+    class(cross_attention_layer), intent(in out) :: self
+    real, intent(in) :: input(:, :, :)
+
+    call self % common_forward(input(1, :, :), input(2, :, :), input(2, :, :))
+  end subroutine forward
+
+  module subroutine init(self, input_shape)
+    class(cross_attention_layer), intent(in out) :: self
+    integer, intent(in) :: input_shape(:)
+
+    call self % init_base(input_shape)
+    allocate(self % gradient(2, self % sequence_length, self % model_dimension))
+  end subroutine init
+end module nf_cross_attention_layer
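For readers new to attention layers, here is a standalone sketch, in plain Fortran and independent of neural-fortran's internals, of single-head scaled dot-product attention, the core operation that the `common_forward` call above builds on with learned projections and multiple heads:

```fortran
program scaled_dot_product_attention_sketch
  ! Illustration only: softmax(Q K^T / sqrt(d)) V for a single head,
  ! with random Q, K, V. It does not reflect neural-fortran's data layout.
  implicit none
  integer, parameter :: seq_len = 3, d_model = 4
  real :: q(seq_len, d_model), k(seq_len, d_model), v(seq_len, d_model)
  real :: scores(seq_len, seq_len), attn_out(seq_len, d_model)
  integer :: i

  call random_number(q)
  call random_number(k)
  call random_number(v)

  ! Similarity of every query position with every key position
  scores = matmul(q, transpose(k)) / sqrt(real(d_model))

  ! Row-wise softmax turns similarities into attention weights
  do i = 1, seq_len
    scores(i, :) = exp(scores(i, :) - maxval(scores(i, :)))
    scores(i, :) = scores(i, :) / sum(scores(i, :))
  end do

  ! Each output position is a weighted average of the value vectors
  attn_out = matmul(scores, v)

  print '("attention output: ", i0, " x ", i0)', size(attn_out, 1), size(attn_out, 2)
end program scaled_dot_product_attention_sketch
```

In the cross-attention case above, `q` would come from slot 1 of the packed input and `k`/`v` from slot 2, per the comment on the `forward` routine.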

src/nf/nf_dropout_layer.f90

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+module nf_dropout_layer
+
+  !! Dropout layer by Srivastava et al. (2014).
+  !!
+  !! Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I. and
+  !! Salakhutdinov, R., 2014. Dropout: a simple way to prevent neural networks
+  !! from overfitting. The Journal of Machine Learning Research, 15(1),
+  !! pp.1929-1958.
+
+  use nf_base_layer, only: base_layer
+
+  implicit none
+
+  private
+  public :: dropout_layer
+
+  type, extends(base_layer) :: dropout_layer
+    !! Concrete implementation of a dropout layer type
+
+    integer :: input_size = 0
+
+    real, allocatable :: output(:)
+    real, allocatable :: gradient(:)
+    real, allocatable :: mask(:) ! binary mask for dropout
+
+    real :: dropout_rate ! probability of dropping a neuron
+    real :: scale ! scale factor to preserve the input sum
+    logical :: training = .true. ! set to .false. for inference
+
+  contains
+
+    procedure :: backward
+    procedure :: forward
+    procedure :: init
+
+  end type dropout_layer
+
+  interface dropout_layer
+    module function dropout_layer_cons(rate) &
+      result(res)
+      !! This function returns the `dropout_layer` instance.
+      real, intent(in) :: rate
+        !! Dropout rate
+      type(dropout_layer) :: res
+        !! dropout_layer instance
+    end function dropout_layer_cons
+  end interface dropout_layer
+
+  interface
+
+    pure module subroutine backward(self, gradient)
+      !! Apply the backward gradient descent pass.
+      !! Only weight and bias gradients are updated in this subroutine,
+      !! while the weights and biases themselves are untouched.
+      class(dropout_layer), intent(in out) :: self
+        !! Dropout layer instance
+      real, intent(in) :: gradient(:)
+        !! Gradient from the next layer
+    end subroutine backward
+
+    module subroutine forward(self, input)
+      !! Propagate forward the layer.
+      !! Calling this subroutine updates the values of a few data components
+      !! of `dropout_layer` that are needed for the backward pass.
+      class(dropout_layer), intent(in out) :: self
+        !! Dropout layer instance
+      real, intent(in) :: input(:)
+        !! Input from the previous layer
+    end subroutine forward
+
+    module subroutine init(self, input_shape)
+      !! Initialize the layer data structures.
+      !!
+      !! This is a deferred procedure from the `base_layer` abstract type.
+      class(dropout_layer), intent(in out) :: self
+        !! Dropout layer instance
+      integer, intent(in) :: input_shape(:)
+        !! Shape of the input layer
+    end subroutine init
+
+  end interface
+
+end module nf_dropout_layer

src/nf/nf_dropout_layer_submodule.f90

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+submodule (nf_dropout_layer) nf_dropout_layer_submodule
+  use nf_random, only: shuffle
+  !! This submodule implements the procedures defined in the
+  !! nf_dropout_layer module.
+
+contains
+
+  module function dropout_layer_cons(rate) result(res)
+    real, intent(in) :: rate
+    type(dropout_layer) :: res
+    res % dropout_rate = rate
+    res % scale = 1 / (1 - rate)
+  end function dropout_layer_cons
+
+
+  module subroutine init(self, input_shape)
+    class(dropout_layer), intent(in out) :: self
+    integer, intent(in) :: input_shape(:)
+
+    self % input_size = input_shape(1)
+
+    ! Allocate arrays
+    allocate(self % output(self % input_size))
+    allocate(self % gradient(self % input_size))
+    allocate(self % mask(self % input_size))
+
+    ! Initialize arrays
+    self % output = 0
+    self % gradient = 0
+    self % mask = 1 ! Default mask is all ones (no dropout)
+
+  end subroutine init
+
+
+  module subroutine forward(self, input)
+    class(dropout_layer), intent(in out) :: self
+    real, intent(in) :: input(:)
+
+    ! Generate random mask for dropout, training mode only
+    if (self % training) then
+
+      ! Set the first dropout_rate number of elements to 0, the rest to 1,
+      ! and shuffle. Note that the selection of the elements rounds down to
+      ! the nearest integer, so in cases where size(input) * dropout_rate is
+      ! not an integer, the actual dropout rate will be slightly lower.
+      self % mask = 1
+      self % mask(:int(size(self % mask) * self % dropout_rate)) = 0
+      call shuffle(self % mask)
+
+      ! Apply dropout mask
+      self % output = input * self % mask * self % scale
+
+    else
+      ! In inference mode, we don't apply dropout; simply pass through the input
+      self % output = input
+
+    end if
+
+  end subroutine forward
+
+
+  pure module subroutine backward(self, gradient)
+    class(dropout_layer), intent(in out) :: self
+    real, intent(in) :: gradient(:)
+    self % gradient = gradient * self % mask * self % scale
+  end subroutine backward
+
+end submodule nf_dropout_layer_submodule
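The mask logic above (zero out the first `floor(n * rate)` elements, shuffle, and scale the survivors by `1 / (1 - rate)`) can be seen in isolation in this standalone sketch, which does not use neural-fortran and approximates its `shuffle` helper with a simple Fisher-Yates pass:

```fortran
program dropout_mask_sketch
  ! Standalone illustration of the inverted-dropout scheme: drop roughly
  ! rate * n elements and rescale the rest so the expected sum is preserved.
  implicit none
  integer, parameter :: n = 10
  real, parameter :: rate = 0.3
  real :: x(n), mask(n), scale, u, tmp
  integer :: i, j

  call random_number(x)
  scale = 1 / (1 - rate)

  mask = 1
  mask(:int(n * rate)) = 0   ! roughly rate * n elements dropped (rounds down)

  ! Fisher-Yates shuffle of the mask (stand-in for nf's shuffle helper)
  do i = n, 2, -1
    call random_number(u)
    j = 1 + int(u * i)
    tmp = mask(i); mask(i) = mask(j); mask(j) = tmp
  end do

  print '("input sum:   ", f8.4)', sum(x)
  print '("dropout sum: ", f8.4)', sum(x * mask * scale)
end program dropout_mask_sketch
```

Because the surviving activations are rescaled at training time (inverted dropout), no adjustment is needed at inference, which is why the `else` branch in the layer's `forward` passes the input through unchanged.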

src/nf/nf_layer.f90

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ end subroutine backward_3d
 
   interface
 
-    pure module subroutine forward(self, input)
+    module subroutine forward(self, input)
      !! Apply a forward pass on the layer.
      !! This changes the internal state of the layer.
      !! This is normally called internally by the `network % forward`
