|
67 | 67 | except ImportError: # pragma: no cover |
68 | 68 | have_numpy = False |
69 | 69 |
|
70 | | -# tensorflow |
71 | | -try: |
72 | | - import tensorflow as tf |
73 | | - |
74 | | - have_tensorflow = True |
75 | | -except ImportError: # pragma: no cover |
76 | | - have_tensorflow = False |
77 | | - |
78 | 70 | # pytorch |
79 | 71 | try: |
80 | 72 | import torch |
|
83 | 75 | except ImportError: # pragma: no cover |
84 | 76 | have_pytorch = False |
85 | 77 |
|
| 78 | +# tensorflow |
| 79 | +try: |
| 80 | + import tensorflow as tf |
| 81 | + |
| 82 | + have_tensorflow = True |
| 83 | +except ImportError: # pragma: no cover |
| 84 | + have_tensorflow = False |
| 85 | + |
86 | 86 |
|
87 | 87 | default_seed = random.Random().getrandbits(32) |
88 | 88 |
|
@@ -196,17 +196,17 @@ def _reseed(config: Config, offset: int = 0) -> int: |
196 | 196 | else: |
197 | 197 | np_random.set_state(np_random_states[numpy_seed]) |
198 | 198 |
|
| 199 | + if have_pytorch: # pragma: no branch |
| 200 | + torch.manual_seed(seed) |
| 201 | + if torch.cuda.is_available(): # Also seed CUDA if available |
| 202 | + torch.cuda.manual_seed_all(seed) |
| 203 | + |
199 | 204 | if have_tensorflow: # pragma: no branch |
200 | 205 | tf.random.set_seed(seed) |
201 | 206 | # TensorFlow 1.x compatibility |
202 | 207 | if hasattr(tf, "compat"): |
203 | 208 | tf.compat.v1.set_random_seed(seed) |
204 | 209 |
|
205 | | - if have_pytorch: # pragma: no branch |
206 | | - torch.manual_seed(seed) |
207 | | - if torch.cuda.is_available(): # Also seed CUDA if available |
208 | | - torch.cuda.manual_seed_all(seed) |
209 | | - |
210 | 210 | if entrypoint_reseeds is None: |
211 | 211 | eps = entry_points(group="pytest_randomly.random_seeder") |
212 | 212 | entrypoint_reseeds = [e.load() for e in eps] |
|
0 commit comments