|
43 | 43 | ) |
44 | 44 | _NUM_CORRUPT_IMAGES = 1738 |
45 | 45 | _DESCRIPTION = ( |
46 | | - "A large set of images of cats and dogs. " |
47 | | - "There are %d corrupted images that are dropped." % _NUM_CORRUPT_IMAGES |
| 46 | + "A large set of images of cats and dogs. " |
| 47 | + "There are %d corrupted images that are dropped." % _NUM_CORRUPT_IMAGES |
48 | 48 | ) |
49 | 49 |
|
50 | 50 | _NAME_RE = re.compile(r"^PetImages[\\/](Cat|Dog)[\\/]\d+\.jpg$") |
51 | 51 |
|
52 | 52 |
|
class CatsVsDogs(tfds.core.GeneratorBasedBuilder):
  """Cats vs Dogs."""

  VERSION = tfds.core.Version("4.0.1")
  RELEASE_NOTES = {
      "4.0.0": "New split API (https://tensorflow.org/datasets/splits)",
      "4.0.1": (
          "Recoding images in generator to fix corrupt JPEG data warnings"
          " (https://github.com/tensorflow/datasets/issues/2188)"
      ),
  }

  def _info(self):
    """Returns the dataset metadata: features, supervised keys and homepage."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            "image": tfds.features.Image(),
            "image/filename": tfds.features.Text(),  # eg 'PetImages/Dog/0.jpg'
            "label": tfds.features.ClassLabel(names=["cat", "dog"]),
        }),
        supervised_keys=("image", "label"),
        homepage=(
            "https://www.microsoft.com/en-us/download/details.aspx?id=54765"
        ),
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Downloads the archive and declares a single TRAIN split.

    Args:
      dl_manager: `tfds.download.DownloadManager` used to fetch `_URL` and to
        iterate the downloaded archive lazily (the archive is never extracted
        to disk).

    Returns:
      A list with one `SplitGenerator` for the TRAIN split.
    """
    path = dl_manager.download(_URL)

    # There is no predefined train/val/test split for this dataset.
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs={
                "archive": dl_manager.iter_archive(path),
            },
        ),
    ]

  def _generate_examples(self, archive):
    """Generate Cats vs Dogs images and labels given a directory path.

    Args:
      archive: iterable of `(filename, file_object)` pairs as produced by
        `dl_manager.iter_archive`.

    Yields:
      `(key, example)` pairs where the key is the normalized archive path and
      the example dict matches the features declared in `_info`.

    Raises:
      ValueError: if the number of skipped (corrupt) images differs from the
        expected `_NUM_CORRUPT_IMAGES`, which would mean the source archive
        changed.
    """
    num_skipped = 0
    for fname, fobj in archive:
      # Normalize so the same regex works for '/'- and '\'-separated paths.
      norm_fname = os.path.normpath(fname)
      res = _NAME_RE.match(norm_fname)
      if not res:  # README file, ...
        continue
      label = res.group(1).lower()
      # Files whose first bytes lack the JFIF marker are the known-corrupt
      # images; count them so the total can be validated below.
      if tf.compat.as_bytes("JFIF") not in fobj.peek(10):
        num_skipped += 1
        continue

      # Some images caused 'Corrupt JPEG data...' messages during training or
      # any other iteration recoding them once fixes the issue (discussion:
      # https://github.com/tensorflow/datasets/issues/2188).
      # Those messages are now displayed when generating the dataset instead.
      img_data = fobj.read()
      img_tensor = tf.image.decode_image(img_data)
      img_recoded = tf.io.encode_jpeg(img_tensor)

      # Converting the recoded image back into a zip file container.
      # NOTE(review): the read-side ZipFile below is never explicitly closed;
      # it is kept alive via `new_fobj` and reclaimed by GC — confirm this is
      # intentional rather than switching to a plain io.BytesIO.
      buffer = io.BytesIO()
      with zipfile.ZipFile(buffer, "w") as new_zip:
        new_zip.writestr(norm_fname, img_recoded.numpy())
      new_fobj = zipfile.ZipFile(buffer).open(norm_fname)

      record = {
          "image": new_fobj,
          "image/filename": norm_fname,
          "label": label,
      }
      yield norm_fname, record

    # Guard against silent upstream changes: the corrupt-image count must
    # match the documented constant exactly.
    if num_skipped != _NUM_CORRUPT_IMAGES:
      raise ValueError(
          "Expected %d corrupt images, but found %d"
          % (_NUM_CORRUPT_IMAGES, num_skipped)
      )
    logging.warning("%d images were corrupted and were skipped", num_skipped)
0 commit comments