@@ -19,9 +19,10 @@
 from __future__ import division
 from __future__ import print_function
 
 import hashlib
+import io
 import os
 import tarfile
 
 # Dependency imports
 
@@ -46,7 +47,7 @@
 # Train/Dev/Test Splits for summarization data
 _TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt"
 _DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
-_TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt"
+_TEST_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt"
 
 
 # End-of-sentence marker.
@@ -128,9 +129,7 @@ def generate_hash(inp):
 
   return filelist
 
-
-def example_generator(tmp_dir, is_training, sum_token):
-  """Generate examples."""
+def example_generator(all_files, urls_path, sum_token):
   def fix_run_on_sents(line):
     if u"@highlight" in line:
       return line
@@ -140,7 +139,6 @@ def fix_run_on_sents(line):
       return line
     return line + u"."
 
-  all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
   filelist = example_splits(urls_path, all_files)
   story_summary_split_token = u" <summary> " if sum_token else " "
 
@@ -170,13 +168,29 @@ def fix_run_on_sents(line):
 
     yield " ".join(story) + story_summary_split_token + " ".join(summary)
 
-
 def _story_summary_split(story):
   split_str = u" <summary> "
   split_str_len = len(split_str)
   split_pos = story.find(split_str)
   return story[:split_pos], story[split_pos + split_str_len:]  # story, summary
 
+def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training):
+  def write_to_file(all_files, urls_path, data_dir, filename):
+    with io.open(os.path.join(data_dir, filename + ".source"), "w") as fstory, io.open(os.path.join(data_dir, filename + ".target"), "w") as fsummary:
+      for example in example_generator(all_files, urls_path, sum_token=True):
+        story, summary = _story_summary_split(example)
+        fstory.write(story + "\n")
+        fsummary.write(summary + "\n")
+
+  filename = "cnndm.train" if is_training else "cnndm.dev"
+  tf.logging.info("Writing %s" % filename)
+  write_to_file(all_files, urls_path, data_dir, filename)
+
+  if not is_training:
+    test_urls_path = generator_utils.maybe_download(tmp_dir, "all_test.txt", _TEST_URLS)
+    filename = "cnndm.test"
+    tf.logging.info("Writing %s" % filename)
+    write_to_file(all_files, test_urls_path, data_dir, filename)
 
 @registry.register_problem
 class SummarizeCnnDailymail32k(problem.Text2TextProblem):
@@ -219,10 +233,12 @@ def use_train_shards_for_dev(self):
     return False
 
   def generator(self, data_dir, tmp_dir, is_training):
+    all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
     encoder = generator_utils.get_or_generate_vocab_inner(
         data_dir, self.vocab_file, self.targeted_vocab_size,
-        example_generator(tmp_dir, is_training, sum_token=False))
-    for example in example_generator(tmp_dir, is_training, sum_token=True):
+        example_generator(all_files, urls_path, sum_token=False))
+    write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training)
+    for example in example_generator(all_files, urls_path, sum_token=True):
       story, summary = _story_summary_split(example)
       encoded_summary = encoder.encode(summary) + [EOS]
       encoded_story = encoder.encode(story) + [EOS]
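
For reference, a minimal sketch of how the refactored helpers compose after this change. The `tmp_dir` and `data_dir` values stand in for whatever directories the caller already uses; everything else is defined in the diff above.

# Download the CNN/DailyMail corpora and the URL split file once, then
# share both handles across vocab generation, raw-text dumps, and example
# generation (previously example_generator resolved these itself).
all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training=True)

# Dump plain-text pairs to cnndm.train.source / cnndm.train.target.
write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir,
                        is_training=True)

# Iterate joined "story <summary> summary" strings and split them back.
for example in example_generator(all_files, urls_path, sum_token=True):
  story, summary = _story_summary_split(example)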