Skip to content

Commit cf79591

Browse files
committed
Make conv2d layer API consistent with Keras (filters, kernel_size)
1 parent 00e9802 commit cf79591

File tree

6 files changed

+34
-34
lines changed

6 files changed

+34
-34
lines changed

src/nf_conv2d_layer.f90

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ module nf_conv2d_layer
1313
integer :: width
1414
integer :: height
1515
integer :: channels
16-
integer :: window_size
16+
integer :: kernel_size
1717
integer :: filters
1818

1919
real, allocatable :: biases(:) ! size(filters)
@@ -29,10 +29,11 @@ module nf_conv2d_layer
2929
end type conv2d_layer
3030

3131
interface conv2d_layer
32-
pure module function conv2d_layer_cons(window_size, filters, activation) result(res)
32+
pure module function conv2d_layer_cons(filters, kernel_size, activation) &
33+
result(res)
3334
!! `conv2d_layer` constructor function
34-
integer, intent(in) :: window_size
3535
integer, intent(in) :: filters
36+
integer, intent(in) :: kernel_size
3637
character(*), intent(in) :: activation
3738
type(conv2d_layer) :: res
3839
end function conv2d_layer_cons

src/nf_conv2d_layer_submodule.f90

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66

77
contains
88

9-
pure module function conv2d_layer_cons(window_size, filters, activation) result(res)
9+
pure module function conv2d_layer_cons(filters, kernel_size, activation) result(res)
1010
implicit none
11-
integer, intent(in) :: window_size
1211
integer, intent(in) :: filters
12+
integer, intent(in) :: kernel_size
1313
character(*), intent(in) :: activation
1414
type(conv2d_layer) :: res
15-
res % window_size = window_size
15+
res % kernel_size = kernel_size
1616
res % filters = filters
1717
call res % set_activation(activation)
1818
end function conv2d_layer_cons
@@ -24,21 +24,21 @@ module subroutine init(self, input_shape)
2424
integer, intent(in) :: input_shape(:)
2525

2626
self % channels = input_shape(1)
27-
self % width = input_shape(2) - self % window_size + 1
28-
self % height = input_shape(3) - self % window_size + 1
27+
self % width = input_shape(2) - self % kernel_size + 1
28+
self % height = input_shape(3) - self % kernel_size + 1
2929

3030
! Output of shape filters x width x height
3131
allocate(self % output(self % filters, self % width, self % height))
3232
self % output = 0
3333

3434
! Kernel of shape filters x channels x width x height
3535
allocate(self % kernel(self % filters, self % channels, &
36-
self % window_size, self % window_size))
36+
self % kernel_size, self % kernel_size))
3737

3838
! Initialize the kernel with random values with a normal distribution.
3939
self % kernel = randn(self % filters, self % channels, &
40-
self % window_size, self % window_size) &
41-
/ self % window_size**2 !TODO window_width * window_height
40+
self % kernel_size, self % kernel_size) &
41+
/ self % kernel_size**2 !TODO kernel_width * kernel_height
4242

4343
allocate(self % biases(self % filters))
4444
self % biases = 0
@@ -54,7 +54,6 @@ pure module subroutine forward(self, input)
5454
integer :: istart, iend
5555
integer :: jstart, jend
5656
integer :: i, j, n
57-
integer :: ii, jj
5857
integer :: iws, iwe, jws, jwe
5958
integer :: half_window
6059

@@ -64,23 +63,23 @@ pure module subroutine forward(self, input)
6463
input_height = size(input, dim=3)
6564

6665
! Half-window is 1 for window size 3; 2 for window size 5; etc.
67-
half_window = self % window_size / 2
66+
half_window = self % kernel_size / 2
6867

6968
! Determine the start and end indices for the width and height dimensions
7069
! of the input that correspond to the center of each window.
71-
istart = half_window + 1 ! TODO window_width
72-
jstart = half_window + 1 ! TODO window_height
70+
istart = half_window + 1 ! TODO kernel_width
71+
jstart = half_window + 1 ! TODO kernel_height
7372
iend = input_width - istart + 1
7473
jend = input_height - jstart + 1
7574

7675
convolution: do concurrent(i = istart:iend, j = jstart:jend)
7776

7877
! Start and end indices of the input data on the filter window
7978
! iws and jws are also coincidentally the indices of the output matrix
80-
iws = i - half_window ! TODO window_width
81-
iwe = i + half_window ! TODO window_width
82-
jws = j - half_window ! TODO window_height
83-
jwe = j + half_window ! TODO window_height
79+
iws = i - half_window ! TODO kernel_width
80+
iwe = i + half_window ! TODO kernel_width
81+
jws = j - half_window ! TODO kernel_height
82+
jwe = j + half_window ! TODO kernel_height
8483

8584
! This computes the inner tensor product, sum(w_ij * x_ij), for each
8685
! filter, and we add bias b_n to it.

src/nf_layer_constructors.f90

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ pure module function dense(layer_size, activation) result(res)
8484
!! Resulting layer instance
8585
end function dense
8686

87-
pure module function conv2d(window_size, filters, activation) result(res)
87+
pure module function conv2d(filters, kernel_size, activation) result(res)
8888
!! 2-d convolutional layer constructor.
8989
!!
9090
!! This layer is for building 2-d convolutional network.
@@ -98,13 +98,13 @@ pure module function conv2d(window_size, filters, activation) result(res)
9898
!! ```
9999
!! use nf, only :: conv2d, layer
100100
!! type(layer) :: conv2d_layer
101-
!! conv2d_layer = dense(window_size=3, filters=32)
102-
!! conv2d_layer = dense(window_size=3, filters=32, activation='relu')
101+
!! conv2d_layer = dense(filters=32, kernel_size=3)
102+
!! conv2d_layer = dense(filters=32, kernel_size=3, activation='relu')
103103
!! ```
104-
integer, intent(in) :: window_size
105-
!! Width of the convolution window, commonly 3 or 5
106104
integer, intent(in) :: filters
107105
!! Number of filters in the output of the layer
106+
integer, intent(in) :: kernel_size
107+
!! Width of the convolution window, commonly 3 or 5
108108
character(*), intent(in), optional :: activation
109109
!! Activation function (default 'sigmoid')
110110
type(layer) :: res

src/nf_layer_constructors_submodule.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ pure module function dense(layer_size, activation) result(res)
5151
end function dense
5252

5353

54-
pure module function conv2d(window_size, filters, activation) result(res)
55-
integer, intent(in) :: window_size
54+
pure module function conv2d(filters, kernel_size, activation) result(res)
5655
integer, intent(in) :: filters
56+
integer, intent(in) :: kernel_size
5757
character(*), intent(in), optional :: activation
5858
type(layer) :: res
5959

@@ -67,7 +67,7 @@ pure module function conv2d(window_size, filters, activation) result(res)
6767

6868
allocate( &
6969
res % p, &
70-
source=conv2d_layer(window_size, filters, res % activation) &
70+
source=conv2d_layer(filters, kernel_size, res % activation) &
7171
)
7272

7373
end function conv2d

test/test_conv2d_layer.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ program test_conv2d_layer
77
implicit none
88

99
type(layer) :: conv_layer, input_layer
10-
integer, parameter :: window_size = 3, filters = 32
10+
integer, parameter :: filters = 32, kernel_size=3
1111
real, allocatable :: sample_input(:,:,:), output(:,:,:)
1212
real, parameter :: tolerance = 1e-7
1313
logical :: ok = .true.
1414

15-
conv_layer = conv2d(window_size, filters)
15+
conv_layer = conv2d(filters, kernel_size)
1616

1717
if (.not. conv_layer % name == 'conv2d') then
1818
ok = .false.
@@ -52,7 +52,7 @@ program test_conv2d_layer
5252
sample_input = 0
5353

5454
input_layer = input([1, 3, 3])
55-
conv_layer = conv2d(window_size, filters)
55+
conv_layer = conv2d(filters, kernel_size)
5656
call conv_layer % init(input_layer)
5757

5858
select type(this_layer => input_layer % p); type is(input3d_layer)

test/test_conv2d_network.f90

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ program test_conv2d_network
1212
! 3-layer convolutional network
1313
net = network([ &
1414
input([3, 32, 32]), &
15-
conv2d(window_size=3, filters=16), &
16-
conv2d(window_size=3, filters=32) &
15+
conv2d(filters=16, kernel_size=3), &
16+
conv2d(filters=32, kernel_size=3) &
1717
])
1818

1919
if (.not. size(net % layers) == 3) then
@@ -33,9 +33,9 @@ program test_conv2d_network
3333
end if
3434

3535
if (ok) then
36-
print '(a)', 'test_dense_network: All tests passed.'
36+
print '(a)', 'test_conv2d_network: All tests passed.'
3737
else
38-
write(stderr, '(a)') 'test_dense_network: One or more tests failed.'
38+
write(stderr, '(a)') 'test_conv2d_network: One or more tests failed.'
3939
stop 1
4040
end if
4141

0 commit comments

Comments
 (0)