|
26 | 26 | import warnings |
27 | 27 | from functools import partial |
28 | 28 | from itertools import product |
| 29 | +import io |
29 | 30 | import pathlib |
30 | 31 |
|
31 | 32 | import numpy as np |
@@ -523,34 +524,41 @@ def validate_affine_deprecated(self, imaker, params): |
523 | 524 | img.get_affine() |
524 | 525 |
|
525 | 526 |
|
526 | | -class SerializeMixin(object): |
527 | | - def validate_to_bytes(self, imaker, params): |
| 527 | +class SerializeMixin: |
| 528 | + def validate_to_from_stream(self, imaker, params): |
528 | 529 | img = imaker() |
529 | | - serialized = img.to_bytes() |
530 | | - with InTemporaryDirectory(): |
531 | | - fname = 'img' + self.standard_extension |
532 | | - img.to_filename(fname) |
533 | | - with open(fname, 'rb') as fobj: |
534 | | - file_contents = fobj.read() |
535 | | - assert serialized == file_contents |
| 530 | + klass = getattr(self, 'klass', img.__class__) |
| 531 | + stream = io.BytesIO() |
| 532 | + img.to_stream(stream) |
| 533 | + |
| 534 | + rt_img = klass.from_stream(stream) |
| 535 | + assert self._header_eq(img.header, rt_img.header) |
| 536 | + assert np.array_equal(img.get_fdata(), rt_img.get_fdata()) |
536 | 537 |
|
537 | | - def validate_from_bytes(self, imaker, params): |
| 538 | + def validate_file_stream_equivalence(self, imaker, params): |
538 | 539 | img = imaker() |
539 | 540 | klass = getattr(self, 'klass', img.__class__) |
540 | 541 | with InTemporaryDirectory(): |
541 | 542 | fname = 'img' + self.standard_extension |
542 | 543 | img.to_filename(fname) |
543 | 544 |
|
544 | | - all_images = list(getattr(self, 'example_images', [])) + [{'fname': fname}] |
545 | | - for img_params in all_images: |
546 | | - img_a = klass.from_filename(img_params['fname']) |
547 | | - with open(img_params['fname'], 'rb') as fobj: |
548 | | - img_b = klass.from_bytes(fobj.read()) |
| 545 | + with open("stream", "wb") as fobj: |
| 546 | + img.to_stream(fobj) |
549 | 547 |
|
550 | | - assert self._header_eq(img_a.header, img_b.header) |
| 548 | + # Check that writing gets us the same thing |
| 549 | + contents1 = pathlib.Path(fname).read_bytes() |
| 550 | + contents2 = pathlib.Path("stream").read_bytes() |
| 551 | + assert contents1 == contents2 |
| 552 | + |
| 553 | + # Check that reading gets us the same thing |
| 554 | + img_a = klass.from_filename(fname) |
| 555 | + with open(fname, "rb") as fobj: |
| 556 | + img_b = klass.from_stream(fobj) |
| 557 | + # This needs to happen while the filehandle is open |
551 | 558 | assert np.array_equal(img_a.get_fdata(), img_b.get_fdata()) |
552 | | - del img_a |
553 | | - del img_b |
| 559 | + assert self._header_eq(img_a.header, img_b.header) |
| 560 | + del img_a |
| 561 | + del img_b |
554 | 562 |
|
555 | 563 | def validate_to_from_bytes(self, imaker, params): |
556 | 564 | img = imaker() |
|
0 commit comments