2424from test_all import set_up
2525from torch import nn , optim
2626from trainers import (
27- TrainEvents ,
2827 create_trainers ,
2928 evaluate_function ,
30- train_events_to_attr ,
3129 train_function ,
3230)
3331from utils import (
@@ -92,88 +90,6 @@ def test_get_logger(tmp_path):
9290 assert isinstance (logger_handler , types ), "Should be Ignite provided loggers or None"
9391
9492
95- def test_train_fn ():
96- model , optimizer , device , loss_fn , batch = set_up ()
97- engine = Engine (lambda e , b : 1 )
98- engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
99- backward = MagicMock ()
100- optim = MagicMock ()
101- engine .add_event_handler (TrainEvents .BACKWARD_COMPLETED , backward )
102- engine .add_event_handler (TrainEvents .OPTIM_STEP_COMPLETED , optim )
103- config = Namespace (use_amp = False )
104- output = train_function (config , engine , batch , model , loss_fn , optimizer , device )
105- assert isinstance (output , dict )
106- assert hasattr (engine .state , "backward_completed" )
107- assert hasattr (engine .state , "optim_step_completed" )
108- assert engine .state .backward_completed == 1
109- assert engine .state .optim_step_completed == 1
110- assert backward .call_count == 1
111- assert optim .call_count == 1
112- assert backward .called
113- assert optim .called
114-
115-
116- def test_train_fn_event_filter ():
117- model , optimizer , device , loss_fn , batch = set_up ()
118- config = Namespace (use_amp = False )
119- engine = Engine (lambda e , b : train_function (config , e , b , model , loss_fn , optimizer , device ))
120- engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
121- backward = MagicMock ()
122- optim = MagicMock ()
123- engine .add_event_handler (TrainEvents .BACKWARD_COMPLETED (event_filter = lambda _ , x : (x % 2 == 0 ) or x == 3 ), backward )
124- engine .add_event_handler (TrainEvents .OPTIM_STEP_COMPLETED (event_filter = lambda _ , x : (x % 2 == 0 ) or x == 3 ), optim )
125- engine .run ([batch ] * 5 )
126- assert hasattr (engine .state , "backward_completed" )
127- assert hasattr (engine .state , "optim_step_completed" )
128- assert engine .state .backward_completed == 5
129- assert engine .state .optim_step_completed == 5
130- assert backward .call_count == 3
131- assert optim .call_count == 3
132- assert backward .called
133- assert optim .called
134-
135-
136- def test_train_fn_every ():
137- model , optimizer , device , loss_fn , batch = set_up ()
138-
139- config = Namespace (use_amp = False )
140- engine = Engine (lambda e , b : train_function (config , e , b , model , loss_fn , optimizer , device ))
141- engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
142- backward = MagicMock ()
143- optim = MagicMock ()
144- engine .add_event_handler (TrainEvents .BACKWARD_COMPLETED (every = 2 ), backward )
145- engine .add_event_handler (TrainEvents .OPTIM_STEP_COMPLETED (every = 2 ), optim )
146- engine .run ([batch ] * 5 )
147- assert hasattr (engine .state , "backward_completed" )
148- assert hasattr (engine .state , "optim_step_completed" )
149- assert engine .state .backward_completed == 5
150- assert engine .state .optim_step_completed == 5
151- assert backward .call_count == 2
152- assert optim .call_count == 2
153- assert backward .called
154- assert optim .called
155-
156-
157- def test_train_fn_once ():
158- model , optimizer , device , loss_fn , batch = set_up ()
159- config = Namespace (use_amp = False )
160- engine = Engine (lambda e , b : train_function (config , e , b , model , loss_fn , optimizer , device ))
161- engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
162- backward = MagicMock ()
163- optim = MagicMock ()
164- engine .add_event_handler (TrainEvents .BACKWARD_COMPLETED (once = 3 ), backward )
165- engine .add_event_handler (TrainEvents .OPTIM_STEP_COMPLETED (once = 3 ), optim )
166- engine .run ([batch ] * 5 )
167- assert hasattr (engine .state , "backward_completed" )
168- assert hasattr (engine .state , "optim_step_completed" )
169- assert engine .state .backward_completed == 5
170- assert engine .state .optim_step_completed == 5
171- assert backward .call_count == 1
172- assert optim .call_count == 1
173- assert backward .called
174- assert optim .called
175-
176-
17793def test_evaluate_fn ():
17894 model , optimizer , device , loss_fn , batch = set_up ()
17995 engine = Engine (lambda e , b : 1 )
@@ -193,8 +109,6 @@ def test_create_trainers():
193109 )
194110 assert isinstance (trainer , Engine )
195111 assert isinstance (evaluator , Engine )
196- assert hasattr (trainer .state , "backward_completed" )
197- assert hasattr (trainer .state , "optim_step_completed" )
198112
199113
200114def test_get_default_parser ():
0 commit comments