@@ -18,6 +18,7 @@ module nf_activation
1818 public :: softplus
1919 public :: step
2020 public :: tanhf
21+ public :: celu
2122
2223 type, abstract :: activation_function
2324 contains
@@ -140,6 +141,15 @@ end function eval_3d_i
140141 procedure :: eval_3d_prime = > eval_3d_tanh_prime
141142 end type tanhf
142143
144+ type, extends(activation_function) :: celu
145+ real :: alpha = 1.0 ! Pytorch default
146+ contains
147+ procedure :: eval_1d = > eval_1d_celu
148+ procedure :: eval_1d_prime = > eval_1d_celu_prime
149+ procedure :: eval_3d = > eval_3d_celu
150+ procedure :: eval_3d_prime = > eval_3d_celu_prime
151+ end type celu
152+
143153contains
144154
145155 pure function eval_1d_elu (self , x ) result(res)
@@ -522,6 +532,54 @@ pure function eval_3d_tanh_prime(self, x) result(res)
522532 res = 1 - tanh (x)** 2
523533 end function eval_3d_tanh_prime
524534
535+ pure function eval_1d_celu (self , x ) result(res)
536+ ! Celu activation function.
537+ class(celu), intent (in ) :: self
538+ real , intent (in ) :: x(:)
539+ real :: res(size (x))
540+ where (x >= 0.0 )
541+ res = x
542+ else where
543+ res = self % alpha * (exp (x / self % alpha) - 1.0 )
544+ end where
545+ end function
546+
547+ pure function eval_1d_celu_prime (self , x ) result(res)
548+ ! Celu activation function.
549+ class(celu), intent (in ) :: self
550+ real , intent (in ) :: x(:)
551+ real :: res(size (x))
552+ where (x >= 0.0 )
553+ res = 1.0
554+ else where
555+ res = exp (x / self % alpha)
556+ end where
557+ end function
558+
559+ pure function eval_3d_celu (self , x ) result(res)
560+ ! Celu activation function.
561+ class(celu), intent (in ) :: self
562+ real , intent (in ) :: x(:,:,:)
563+ real :: res(size (x,1 ),size (x,2 ),size (x,3 ))
564+ where (x >= 0.0 )
565+ res = x
566+ else where
567+ res = self % alpha * (exp (x / self % alpha) - 1.0 )
568+ end where
569+ end function
570+
571+ pure function eval_3d_celu_prime (self , x ) result(res)
572+ ! Celu activation function.
573+ class(celu), intent (in ) :: self
574+ real , intent (in ) :: x(:,:,:)
575+ real :: res(size (x,1 ),size (x,2 ),size (x,3 ))
576+ where (x >= 0.0 )
577+ res = 1.0
578+ else where
579+ res = exp (x / self % alpha)
580+ end where
581+ end function
582+
525583 pure function get_name (self ) result(name)
526584 ! ! Return the name of the activation function.
527585 ! !
@@ -556,6 +614,8 @@ pure function get_name(self) result(name)
556614 name = ' step'
557615 class is (tanhf)
558616 name = ' tanh'
617+ class is (celu)
618+ name = ' celu'
559619 class default
560620 error stop ' Unknown activation function type.'
561621 end select
0 commit comments