@@ -527,15 +527,19 @@ def test_influences_lissa(
527527 influence_factors , x_train , y_train , mode = test_case .mode
528528 )
529529
530+ atol = 1e-5
531+ rtol = 1e-4
530532 assert torch .allclose (
531- influences_from_factors , approx_influences , atol = 1e-5 , rtol = 1e-4
533+ influences_from_factors , approx_influences , atol = atol , rtol = rtol
532534 )
533535
534536 approx_influences = approx_influences .cpu ().numpy ()
535537
536538 assert not np .any (np .isnan (approx_influences ))
537539
538- np .testing .assert_allclose (approx_influences , direct_influences , rtol = 1e-1 )
540+ np .testing .assert_allclose (
541+ approx_influences , direct_influences , atol = atol , rtol = rtol
542+ )
539543
540544 if test_case .mode == InfluenceMode .Up :
541545 assert approx_influences .shape == (
@@ -553,7 +557,9 @@ def test_influences_lissa(
553557 # check that influences are not all constant
554558 assert not np .all (approx_influences == approx_influences .item (0 ))
555559
556- np .testing .assert_allclose (approx_influences , direct_influences , rtol = 1e-1 )
560+ np .testing .assert_allclose (
561+ approx_influences , direct_influences , atol = atol , rtol = rtol
562+ )
557563
558564
559565@pytest .mark .parametrize (
@@ -674,6 +680,9 @@ def test_influences_ekfac(
674680 direct_sym_influences ,
675681 device : torch .device ,
676682):
683+ atol = 1e-6
684+ rtol = 1e-4
685+
677686 model , loss , x_train , y_train , x_test , y_test = model_and_data
678687
679688 train_dataloader = DataLoader (
@@ -731,8 +740,12 @@ def test_influences_ekfac(
731740 .numpy ()
732741 )
733742
734- np .testing .assert_allclose (ekfac_influence_values , influence_from_factors )
735- np .testing .assert_allclose (ekfac_influence_values , accumulated_inf_by_layer )
743+ np .testing .assert_allclose (
744+ ekfac_influence_values , influence_from_factors , atol = atol , rtol = rtol
745+ )
746+ np .testing .assert_allclose (
747+ ekfac_influence_values , accumulated_inf_by_layer , atol = atol , rtol = rtol
748+ )
736749 check_influence_correlations (
737750 direct_influences .numpy (), ekfac_influence_values , threshold = 0.94
738751 )
@@ -832,7 +845,7 @@ def test_influences_cg(
832845 .numpy ()
833846 )
834847 np .testing .assert_allclose (
835- single_influence , direct_factors [0 ], atol = 1e-6 , rtol = 1e-4
848+ single_influence [ 0 ] , direct_factors [0 ], atol = 1e-6 , rtol = 1e-4
836849 )
837850
838851
0 commit comments