Skip to content

Commit 9014ec4

Browse files
committed
Add tests
1 parent a28eda5 commit 9014ec4

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

diffsims/generators/simulation_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def calculate_diffraction2d(
206206
rotate_iter = rotate
207207
if show_progressbar:
208208
rotate_iter = tqdm(rotate_iter, desc=p.name, total=rotate.size)
209-
209+
210210
for rot in rotate_iter:
211211
# Calculate the reciprocal lattice vectors that intersect the Ewald sphere.
212212
(

diffsims/tests/generators/test_simulation_generator.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,63 @@ def test_same_simulation_results():
341341
)
342342
old_data = np.load(FILE1)
343343
np.testing.assert_allclose(new_data, old_data, atol=1e-8)
344+
345+
346+
def test_calculate_diffraction2d_progressbar_single_phase(capsys):
347+
gen = SimulationGenerator()
348+
phase = make_phase()
349+
phase.name = "test phase"
350+
rots = Rotation.random(10)
351+
352+
# no pbar
353+
sims = gen.calculate_diffraction2d(phase, rots, show_progressbar=False)
354+
355+
# Note: tqdm outputs to stderr by default
356+
captured = capsys.readouterr()
357+
assert captured.err == ""
358+
359+
# with pbar
360+
sims = gen.calculate_diffraction2d(phase, rots, show_progressbar=True)
361+
362+
captured = capsys.readouterr()
363+
expected = "test phase: 100%|██████████| 10/10" # also some more, but that is compute-time dependent
364+
# ignore possible flushing
365+
captured = captured.err.split("\r")[-1]
366+
assert captured[: len(expected)] == expected
367+
368+
369+
def test_calculate_diffraction2d_progressbar_multi_phase(capsys):
370+
gen = SimulationGenerator()
371+
phase1 = make_phase()
372+
phase1.name = "A"
373+
phase2 = make_phase()
374+
phase2.name = "B"
375+
rots = Rotation.random(10)
376+
377+
# no pbar
378+
sims = gen.calculate_diffraction2d(
379+
[phase1, phase2], [rots, rots], show_progressbar=False
380+
)
381+
382+
# Note: tqdm outputs to stderr by default
383+
captured = capsys.readouterr()
384+
assert captured.err == ""
385+
386+
# with pbar
387+
sims = gen.calculate_diffraction2d(
388+
[phase1, phase2], [rots, rots], show_progressbar=True
389+
)
390+
391+
captured = capsys.readouterr()
392+
expected1 = "A: 100%|██████████| 10/10 "
393+
expected2 = "B: 100%|██████████| 10/10 "
394+
# Find the correct output in the stream, i.e. final line containing the name of the phase
395+
captured1 = ""
396+
captured2 = ""
397+
for line in captured.err.split("\r"):
398+
if "A" in line:
399+
captured1 = line
400+
if "B" in line:
401+
captured2 = line
402+
assert captured1[: len(expected1)] == expected1
403+
assert captured2[: len(expected2)] == expected2

0 commit comments

Comments
 (0)