55def f_2726 (X , y , n_splits , batch_size , epochs ):
66 """
77 Trains a simple neural network on provided data using k-fold cross-validation.
8- The network has one hidden layer with 50 neurons and ReLU activation, and
8+ The network has one hidden layer with 20 neurons and ReLU activation, and
99 an output layer with sigmoid activation for binary classification.
1010
1111 Parameters:
1212 X (numpy.array): The input data.
1313 y (numpy.array): The target data.
1414 n_splits (int): The number of splits for k-fold cross-validation. Default is 5.
1515 batch_size (int): The size of the batch used during training. Default is 32.
16- epochs (int): The number of epochs for training the model. Default is 10 .
16+ epochs (int): The number of epochs for training the model. Default is 1 .
1717
1818 Returns:
1919 list: A list containing the training history of the model for each fold. Each history
@@ -47,7 +47,7 @@ def f_2726(X, y, n_splits, batch_size, epochs):
4747 y_train , y_test = y [train_index ], y [test_index ]
4848
4949 model = tf .keras .models .Sequential ([
50- tf .keras .layers .Dense (50 , activation = 'relu' ),
50+ tf .keras .layers .Dense (20 , activation = 'relu' ),
5151 tf .keras .layers .Dense (1 , activation = 'sigmoid' )
5252 ])
5353
@@ -70,7 +70,7 @@ def setUp(self):
7070 self .y = np .random .randint (0 , 2 , 100 )
7171 self .n_splits = 5
7272 self .batch_size = 32
73- self .epochs = 10
73+ self .epochs = 1
7474
7575 def test_return_type (self ):
7676 """Test that the function returns a list."""
@@ -101,9 +101,9 @@ def test_effect_of_different_batch_sizes(self):
101101
102102 def test_effect_of_different_epochs (self ):
103103 """Test function behavior with different epochs."""
104- for epochs in [ 5 , 20 ]:
105- result = f_2726 (self .X , self .y , self .n_splits , self .batch_size , epochs )
106- self .assertEqual (len (result ), self .n_splits ) # Validating function execution
104+ epochs = 5
105+ result = f_2726 (self .X , self .y , self .n_splits , self .batch_size , epochs )
106+ self .assertEqual (len (result ), self .n_splits ) # Validating function execution
107107
108108
109109def run_tests ():
0 commit comments