Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c25e43f

Browse files
authored
Merge pull request #422 from urvashik/master
Completing the CNN/Dailymail summarization pipeline
2 parents c78abcd + 3f52fb9 commit c25e43f

File tree

4 files changed

+131
-11
lines changed

4 files changed

+131
-11
lines changed

tensor2tensor/data_generators/cnn_dailymail.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22-
import hashlib
22+
import io
2323
import os
2424
import tarfile
25+
import hashlib
2526

2627
# Dependency imports
2728

@@ -46,7 +47,7 @@
4647
# Train/Dev/Test Splits for summarization data
4748
_TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt"
4849
_DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
49-
_TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt"
50+
_TEST_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt"
5051

5152

5253
# End-of-sentence marker.
@@ -128,9 +129,7 @@ def generate_hash(inp):
128129

129130
return filelist
130131

131-
132-
def example_generator(tmp_dir, is_training, sum_token):
133-
"""Generate examples."""
132+
def example_generator(all_files, urls_path, sum_token):
134133
def fix_run_on_sents(line):
135134
if u"@highlight" in line:
136135
return line
@@ -140,7 +139,6 @@ def fix_run_on_sents(line):
140139
return line
141140
return line + u"."
142141

143-
all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
144142
filelist = example_splits(urls_path, all_files)
145143
story_summary_split_token = u" <summary> " if sum_token else " "
146144

@@ -170,13 +168,29 @@ def fix_run_on_sents(line):
170168

171169
yield " ".join(story) + story_summary_split_token + " ".join(summary)
172170

173-
174171
def _story_summary_split(story):
175172
split_str = u" <summary> "
176173
split_str_len = len(split_str)
177174
split_pos = story.find(split_str)
178175
return story[:split_pos], story[split_pos+split_str_len:] # story, summary
179176

177+
def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training):
  """Write the raw stories and summaries as parallel text files in data_dir.

  For every split written, `<name>.source` receives one story per line and
  `<name>.target` receives the matching summary on the same line number.
  The split is "cnndm.train" when is_training is true and "cnndm.dev"
  otherwise; when not training, "cnndm.test" is written as well, using the
  URL list downloaded from _TEST_URLS.

  Args:
    all_files: story files, passed through to example_generator.
    urls_path: path of the URL list selecting the train or dev split.
    data_dir: directory that receives the .source/.target files.
    tmp_dir: directory the test URL list is downloaded into.
    is_training: whether the train split (True) or dev + test (False) is
      written.
  """
  def write_to_file(all_files, urls_path, data_dir, filename):
    """Write one split to <filename>.source and <filename>.target."""
    source_path = os.path.join(data_dir, filename + ".source")
    target_path = os.path.join(data_dir, filename + ".target")
    # The encoding is explicit so the output does not depend on the locale
    # of the machine that generates the data.
    with io.open(source_path, "w", encoding="utf-8") as fstory, \
        io.open(target_path, "w", encoding="utf-8") as fsummary:
      for example in example_generator(all_files, urls_path, sum_token=True):
        story, summary = _story_summary_split(example)
        fstory.write(story + "\n")
        fsummary.write(summary + "\n")

  filename = "cnndm.train" if is_training else "cnndm.dev"
  tf.logging.info("Writing %s" % filename)
  write_to_file(all_files, urls_path, data_dir, filename)

  if not is_training:
    # The test split is only produced alongside the dev split.
    test_urls_path = generator_utils.maybe_download(
        tmp_dir, "all_test.txt", _TEST_URLS)
    filename = "cnndm.test"
    tf.logging.info("Writing %s" % filename)
    write_to_file(all_files, test_urls_path, data_dir, filename)
180194

181195
@registry.register_problem
182196
class SummarizeCnnDailymail32k(problem.Text2TextProblem):
@@ -219,10 +233,12 @@ def use_train_shards_for_dev(self):
219233
return False
220234

221235
def generator(self, data_dir, tmp_dir, is_training):
236+
all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
222237
encoder = generator_utils.get_or_generate_vocab_inner(
223238
data_dir, self.vocab_file, self.targeted_vocab_size,
224-
example_generator(tmp_dir, is_training, sum_token=False))
225-
for example in example_generator(tmp_dir, is_training, sum_token=True):
239+
example_generator(all_files, urls_path, sum_token=False))
240+
write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training)
241+
for example in example_generator(all_files, urls_path, sum_token=True):
226242
story, summary = _story_summary_split(example)
227243
encoded_summary = encoder.encode(summary) + [EOS]
228244
encoded_story = encoder.encode(story) + [EOS]

tensor2tensor/utils/decoding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ def log_decode_results(inputs,
8383

8484
decoded_targets = None
8585
if identity_output:
86-
decoded_outputs = " ".join(map(str, outputs.flatten()))
86+
decoded_outputs = "".join(map(str, outputs.flatten()))
8787
if targets is not None:
88-
decoded_targets = " ".join(map(str, targets.flatten()))
88+
decoded_targets = "".join(map(str, targets.flatten()))
8989
else:
9090
decoded_outputs = targets_vocab.decode(_save_until_eos(outputs, is_image))
9191
if targets is not None:
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
# Tokenize the gold and the model-generated summaries with the Moses
# tokenizer, then compute ROUGE scores on the tokenized files.
#
# Usage: <this script> <mosesdecoder_dir> <targets_file> <decodes_file>
# Writes <targets_file>.tok and <decodes_file>.tok next to the inputs.

# Stop at the first failing step instead of scoring stale or empty .tok files.
set -euo pipefail

if [ "$#" -lt 3 ]; then
  echo "Usage: $0 <mosesdecoder_dir> <targets_file> <decodes_file>" >&2
  exit 1
fi

# Path to moses dir
mosesdecoder=$1

# Path to file containing gold summaries, one per line
targets_file=$2
# Path to file containing model generated summaries, one per line
decodes_file=$3

# Tokenize. Expansions are quoted so paths containing spaces keep working.
perl "$mosesdecoder/scripts/tokenizer/tokenizer.perl" -l en < "$targets_file" > "$targets_file.tok"
perl "$mosesdecoder/scripts/tokenizer/tokenizer.perl" -l en < "$decodes_file" > "$decodes_file.tok"

# Get rouge scores
python get_rouge.py --decodes_filename "$decodes_file.tok" --targets_filename "$targets_file.tok"

tensor2tensor/utils/get_rouge.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# coding=utf-8
2+
# Copyright 2017 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Computing rouge scores using pyrouge."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
import logging
24+
import shutil
25+
from tempfile import mkdtemp
26+
from pprint import pprint
27+
28+
# Dependency imports
29+
from pyrouge import Rouge155
30+
31+
import numpy as np
32+
import tensorflow as tf
33+
34+
FLAGS = tf.flags.FLAGS

# Command-line flags naming the two input files; both default to None.
tf.flags.DEFINE_string(
    "decodes_filename", None,
    "File containing model generated summaries tokenized")
tf.flags.DEFINE_string(
    "targets_filename", None,
    "File containing model target summaries tokenized")
38+
39+
def write_to_file(filename, data):
  """Write one summary to `filename` with one sentence per line.

  Every occurrence of a period followed by a space is replaced by a period
  followed by a newline before the text is written.

  Args:
    filename: path of the file to create or overwrite.
    data: the summary, either as text or as UTF-8 encoded bytes.
  """
  import io  # Local import: only this function needs it.

  if isinstance(data, bytes):
    # prep_data reads its input files in binary mode. Decoding here keeps the
    # string operations below working on Python 3, where splitting bytes on a
    # str separator raises TypeError.
    data = data.decode("utf-8")
  data = ".\n".join(data.split(". "))
  with io.open(filename, "w", encoding="utf-8") as fp:
    fp.write(data)
43+
44+
def prep_data(decode_dir, target_dir):
  """Split the decodes and targets files into one file per example.

  Reads FLAGS.decodes_filename and FLAGS.targets_filename in lockstep, line
  by line, in binary mode. Line number N (1-based) of the decodes file is
  written to `decode_dir/rouge.<N>.txt` and the same line of the targets
  file to `target_dir/rouge.A.<N>.txt`, with N zero-padded to six digits.
  Iteration stops at the end of the shorter file.

  Args:
    decode_dir: directory receiving the model-generated summaries.
    target_dir: directory receiving the gold summaries.
  """
  with open(FLAGS.decodes_filename, "rb") as fdecodes, \
      open(FLAGS.targets_filename, "rb") as ftargets:
    # Count from 1 so the counter doubles as the file id and as the number
    # of examples written so far.
    for i, (d, t) in enumerate(zip(fdecodes, ftargets), start=1):
      write_to_file(os.path.join(decode_dir, "rouge.%06d.txt" % i), d)
      write_to_file(os.path.join(target_dir, "rouge.A.%06d.txt" % i), t)

      # The previous condition `(i+1 % 1000) == 0` parsed as
      # `i + (1 % 1000) == 0` and was never true, and the call inside it
      # used the non-existent `tf.logging.into`.
      if i % 1000 == 0:
        tf.logging.info("Written %d examples to file" % i)
52+
53+
def main(_):
  """Compute ROUGE for the files named by the flags and log the scores.

  Creates a scratch directory with `system` (decodes) and `model` (targets)
  sub-directories, fills them through prep_data, runs pyrouge on them and
  logs f_score, precision and recall for rouge_1, rouge_2 and rouge_l.

  Args:
    _: unused positional argument.
  """
  rouge = Rouge155()
  rouge.log.setLevel(logging.ERROR)
  # Raw string: "\d" in a normal string literal is an invalid escape sequence.
  rouge.system_filename_pattern = r"rouge.(\d+).txt"
  rouge.model_filename_pattern = "rouge.[A-Z].#ID#.txt"

  tf.logging.set_verbosity(tf.logging.INFO)

  tmpdir = mkdtemp()
  tf.logging.info("tmpdir: %s" % tmpdir)
  try:
    # system = decodes/predictions
    system_dir = os.path.join(tmpdir, 'system')
    # model = targets/gold
    model_dir = os.path.join(tmpdir, 'model')
    os.mkdir(system_dir)
    os.mkdir(model_dir)

    rouge.system_dir = system_dir
    rouge.model_dir = model_dir

    prep_data(rouge.system_dir, rouge.model_dir)

    rouge_scores = rouge.convert_and_evaluate()
    rouge_scores = rouge.output_to_dict(rouge_scores)
    for prefix in ["rouge_1", "rouge_2", "rouge_l"]:
      for suffix in ["f_score", "precision", "recall"]:
        key = "_".join([prefix, suffix])
        tf.logging.info("%s: %.4f" % (key, rouge_scores[key]))
  finally:
    # Remove the scratch directory even when data preparation or the
    # evaluation raises, so failed runs do not leave it behind.
    shutil.rmtree(tmpdir)

  # clean up after pyrouge
  # NOTE(review): relies on pyrouge private attributes; presumably these are
  # temporary directories pyrouge created during convert_and_evaluate --
  # confirm against the installed pyrouge version.
  shutil.rmtree(rouge._config_dir)
  shutil.rmtree(os.path.split(rouge._system_dir)[0])
86+
87+
# Script entry point: hand control to tf.app.run() when executed directly.
if __name__ == "__main__":
  tf.app.run()

0 commit comments

Comments
 (0)