@@ -12,7 +12,7 @@
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_loss, only: quadratic_derivative
-  use nf_optimizers, only: sgd
+  use nf_optimizers, only: optimizer_base_type, sgd
   use nf_parallel, only: tile_indices

   implicit none
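
The `class(optimizer_base_type)` dummy argument introduced below relies on an abstract base type in nf_optimizers. A minimal sketch of the assumed type hierarchy (only optimizer_base_type, sgd, and a learning_rate component are confirmed by this diff; the module name and the placement of learning_rate on the parent type are illustrative assumptions):

module nf_optimizers_sketch  ! hypothetical module name
  implicit none

  ! Abstract parent so train() can accept any optimizer polymorphically
  type, abstract :: optimizer_base_type
    real :: learning_rate = 0.01  ! assumed here; train() reads it after select type
  end type optimizer_base_type

  ! Concrete SGD optimizer; extends(...) makes it conform to
  ! a class(optimizer_base_type) dummy argument
  type, extends(optimizer_base_type) :: sgd
  end type sgd

end module nf_optimizers_sketch
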
@@ -426,7 +426,7 @@ module subroutine train(self, input_data, output_data, batch_size, &
     real, intent(in) :: output_data(:,:)
     integer, intent(in) :: batch_size
     integer, intent(in) :: epochs
-    type(sgd), intent(in) :: optimizer
+    class(optimizer_base_type), intent(in) :: optimizer

     real :: pos
     integer :: dataset_size
@@ -439,26 +439,31 @@ module subroutine train(self, input_data, output_data, batch_size, &
     epoch_loop: do n = 1, epochs
       batch_loop: do i = 1, dataset_size / batch_size

-        ! Pull a random mini-batch from the dataset
-        call random_number(pos)
-        batch_start = int(pos * (dataset_size - batch_size + 1)) + 1
-        batch_end = batch_start + batch_size - 1
-
-        ! FIXME shuffle in a way that doesn't require co_broadcast
-        call co_broadcast(batch_start, 1)
-        call co_broadcast(batch_end, 1)
-
-        ! Distribute the batch in nearly equal pieces to all images
-        indices = tile_indices(batch_size)
-        istart = indices(1) + batch_start - 1
-        iend = indices(2) + batch_start - 1
-
-        do concurrent(j = istart:iend)
-          call self % forward(input_data(:,j))
-          call self % backward(output_data(:,j))
-        end do
-
-        call self % update(optimizer % learning_rate / batch_size)
+        ! Pull a random mini-batch from the dataset
+        call random_number(pos)
+        batch_start = int(pos * (dataset_size - batch_size + 1)) + 1
+        batch_end = batch_start + batch_size - 1
+
+        ! FIXME shuffle in a way that doesn't require co_broadcast
+        call co_broadcast(batch_start, 1)
+        call co_broadcast(batch_end, 1)
+
+        ! Distribute the batch in nearly equal pieces to all images
+        indices = tile_indices(batch_size)
+        istart = indices(1) + batch_start - 1
+        iend = indices(2) + batch_start - 1
+
+        do concurrent(j = istart:iend)
+          call self % forward(input_data(:,j))
+          call self % backward(output_data(:,j))
+        end do
+
+        select type (optimizer)
+          type is (sgd)
+            call self % update(optimizer % learning_rate / batch_size)
+          class default
+            error stop 'Unsupported optimizer'
+        end select

       end do batch_loop
     end do epoch_loop
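
For reference, the training loop assumes that tile_indices(batch_size) returns the [start, end] sub-range of the mini-batch owned by the calling coarray image. A hedged sketch of those semantics (the real implementation lives in nf_parallel and may distribute the remainder differently):

function tile_indices(n) result(indices)
  ! Split the range 1..n into num_images() nearly equal contiguous
  ! chunks and return this image's [start, end] pair.
  integer, intent(in) :: n
  integer :: indices(2)
  integer :: chunk, remainder, me
  me = this_image()
  chunk = n / num_images()
  remainder = mod(n, num_images())
  ! In this sketch, images 1..remainder each take one extra element
  indices(1) = (me - 1) * chunk + min(me - 1, remainder) + 1
  indices(2) = indices(1) + chunk - 1
  if (me <= remainder) indices(2) = indices(2) + 1
end function tile_indices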
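
With this change, callers pass any concrete optimizer through the polymorphic dummy argument. A usage sketch (the top-level nf module contents and the sgd structure-constructor arguments are assumptions inferred from the signature in this diff):

program train_example
  use nf, only: dense, input, network, sgd
  implicit none
  type(network) :: net
  real :: x(3, 100), y(2, 100)

  call random_number(x)
  call random_number(y)

  net = network([input(3), dense(5), dense(2)])

  ! The optimizer dummy is class(optimizer_base_type), so any extension
  ! conforms; anything other than sgd currently hits the error stop
  ! in the select type block of train().
  call net % train(x, y, batch_size=10, epochs=5, optimizer=sgd(learning_rate=0.1))
end program train_example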