
Commit dca3348

add helper stuff

1 parent 3446920 commit dca3348

File tree

1 file changed: +376 −0 lines changed


notebooks/seq2seq_utils.py

Lines changed: 376 additions & 0 deletions
@@ -0,0 +1,376 @@
from matplotlib import pyplot as plt
import tensorflow as tf
from keras import backend as K
from keras.layers import Input
from keras.models import Model
from IPython.display import SVG, display
from keras.utils.vis_utils import model_to_dot
import logging
import numpy as np
import dill as dpickle
from annoy import AnnoyIndex
from tqdm import tqdm
from random import random


def load_text_processor(fname='title_pp.dpkl'):
    """
    Load preprocessors from disk.

    Parameters
    ----------
    fname: str
        file name of ktext.processor object

    Returns
    -------
    num_tokens : int
        size of vocabulary loaded into ktext.processor
    pp : ktext.processor
        the processor you are trying to load

    Typical Usage:
    -------------

    num_decoder_tokens, title_pp = load_text_processor(fname='title_pp.dpkl')
    num_encoder_tokens, body_pp = load_text_processor(fname='body_pp.dpkl')

    """
    # Load files from disk
    with open(fname, 'rb') as f:
        pp = dpickle.load(f)

    num_tokens = max(pp.id2token.keys()) + 1
    print(f'Size of vocabulary for {fname}: {num_tokens:,}')
    return num_tokens, pp


def load_decoder_inputs(decoder_np_vecs='train_title_vecs.npy'):
    """
    Load decoder inputs.

    Parameters
    ----------
    decoder_np_vecs : str
        filename of serialized numpy.array of decoder input (issue title)

    Returns
    -------
    decoder_input_data : numpy.array
        The data fed to the decoder as input during training for teacher forcing.
        This is the same as `decoder_np_vecs` except for the last position.
    decoder_target_data : numpy.array
        The data that the decoder is trained to generate (issue title).
        Calculated by sliding `decoder_np_vecs` one position forward.

    """
    vectorized_title = np.load(decoder_np_vecs)
    # For the decoder input we don't need the last word, since it is only a
    # prediction target when training with teacher forcing.
    decoder_input_data = vectorized_title[:, :-1]

    # Decoder target data is ahead of decoder input data by one time step (teacher forcing)
    decoder_target_data = vectorized_title[:, 1:]

    print(f'Shape of decoder input: {decoder_input_data.shape}')
    print(f'Shape of decoder target: {decoder_target_data.shape}')
    return decoder_input_data, decoder_target_data


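# Illustrative sketch (not part of the pipeline): the teacher-forcing shift
# above, applied to a hypothetical toy array of token ids.
#
#   toy = np.array([[2, 15, 27, 3, 0]])       # _start_, w1, w2, _end_, pad
#   toy_input, toy_target = toy[:, :-1], toy[:, 1:]
#   # toy_input  -> [[2, 15, 27, 3]]   (what the decoder sees at each step)
#   # toy_target -> [[15, 27, 3, 0]]   (what it must predict, one step ahead)

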
def load_encoder_inputs(encoder_np_vecs='train_body_vecs.npy'):
    """
    Load variables & data that are inputs to encoder.

    Parameters
    ----------
    encoder_np_vecs : str
        filename of serialized numpy.array of encoder input (issue body)

    Returns
    -------
    encoder_input_data : numpy.array
        The issue body
    doc_length : int
        The standard document length of the input for the encoder after padding;
        the shape of this array will be (num_examples, doc_length)

    """
    vectorized_body = np.load(encoder_np_vecs)
    # Encoder input is simply the body of the issue text
    encoder_input_data = vectorized_body
    doc_length = encoder_input_data.shape[1]
    print(f'Shape of encoder input: {encoder_input_data.shape}')
    return encoder_input_data, doc_length


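# A minimal sketch of how the loaders above fit together for training. The
# model definition lives in the notebooks, so `seq2seq_Model` is only an
# assumed name here; the file names match the defaults documented above.
#
#   encoder_input_data, doc_length = load_encoder_inputs('train_body_vecs.npy')
#   decoder_input_data, decoder_target_data = load_decoder_inputs('train_title_vecs.npy')
#   num_encoder_tokens, body_pp = load_text_processor('body_pp.dpkl')
#   num_decoder_tokens, title_pp = load_text_processor('title_pp.dpkl')
#   seq2seq_Model.fit([encoder_input_data, decoder_input_data],
#                     np.expand_dims(decoder_target_data, -1))

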
def viz_model_architecture(model):
    """Visualize model architecture in Jupyter notebook."""
    display(SVG(model_to_dot(model).create(prog='dot', format='svg')))


def free_gpu_mem():
    """Attempt to free gpu memory."""
    K.get_session().close()
    cfg = K.tf.ConfigProto()
    cfg.gpu_options.allow_growth = True
    K.set_session(K.tf.Session(config=cfg))


def test_gpu():
    """Run a toy computation task in tensorflow to test GPU."""
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    hello = tf.constant('Hello, TensorFlow!')
    print(session.run(hello))


def plot_model_training_history(history_object):
    """Plots model train vs. validation loss."""
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.plot(history_object.history['loss'])
    plt.plot(history_object.history['val_loss'])
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()


def extract_encoder_model(model):
    """
    Extract the encoder from the original Sequence to Sequence Model.

    Returns a keras model object that has one input (body of issue) and one
    output (encoding of issue, which is the last hidden state).

    Input:
    -----
    model: keras model object

    Returns:
    -----
    keras model object

    """
    encoder_model = model.get_layer('Encoder-Model')
    return encoder_model


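# Example (a sketch; assumes a trained `seq2seq_Model` containing a layer
# named 'Encoder-Model', as the lookup above requires):
#
#   encoder = extract_encoder_model(seq2seq_Model)
#   emb = encoder.predict(encoder_input_data[:10])
#   # emb.shape -> (10, latent_dim): one fixed-length encoding per issue body

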
def extract_decoder_model(model):
    """
    Extract the decoder from the original model.

    Inputs:
    ------
    model: keras model object

    Returns:
    -------
    A Keras model object with the following inputs and outputs:

    Inputs:
    1: the embedding index for the last predicted word, or the <Start> indicator
    2: the last hidden state, or in the case of the first word the hidden state from the encoder

    Outputs:
    1. Prediction (class probabilities) for the next word
    2. The hidden state of the decoder, to be fed back into the decoder at the next time step

    Implementation Notes:
    ----------------------
    Must extract relevant layers and reconstruct part of the computation graph
    to allow for different inputs, as we are not going to use teacher forcing at
    inference time.

    """
    # The latent dimension is the same throughout the architecture, so we grab
    # it from the embedding layer: it matches what the decoder outputs.
    latent_dim = model.get_layer('Decoder-Word-Embedding').output_shape[-1]

    # Reconstruct the input into the decoder
    decoder_inputs = model.get_layer('Decoder-Input').input
    dec_emb = model.get_layer('Decoder-Word-Embedding')(decoder_inputs)
    dec_bn = model.get_layer('Decoder-Batchnorm-1')(dec_emb)

    # Instead of setting the initial state from the encoder and forgetting about
    # it, at inference time we are not doing teacher forcing, so we need a
    # feedback loop from predictions back into the GRU. This input layer for the
    # state adds that capability.
    gru_inference_state_input = Input(shape=(latent_dim,), name='hidden_state_input')

    # We fetch the trained decoder GRU so its weights are reused. It takes two
    # tensors as input: (1) the embedding layer output used for teacher forcing
    # (which will now be the last step's prediction, and _start_ on the first
    # time step), and (2) the state, which we initialize with the encoder on the
    # first time step, then grab after each prediction and feed back in again.
    gru_out, gru_state_out = model.get_layer('Decoder-GRU')([dec_bn, gru_inference_state_input])

    # Reconstruct dense layers
    dec_bn2 = model.get_layer('Decoder-Batchnorm-2')(gru_out)
    dense_out = model.get_layer('Final-Output-Dense')(dec_bn2)
    decoder_model = Model([decoder_inputs, gru_inference_state_input],
                          [dense_out, gru_state_out])
    return decoder_model


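# Sketch of one greedy decoding step with the two extracted models; this is
# what Seq2Seq_Inference.generate_issue_title below runs in a loop.
# `seq2seq_Model`, `body_pp`, and `title_pp` are assumed names for the trained
# model and fitted ktext processors from the notebooks.
#
#   encoder = extract_encoder_model(seq2seq_Model)
#   decoder = extract_decoder_model(seq2seq_Model)
#   state = encoder.predict(body_pp.transform(['some issue body']))
#   token = np.array(title_pp.token2id['_start_']).reshape(1, 1)
#   preds, state = decoder.predict([token, state])
#   next_word = title_pp.id2token[np.argmax(preds[:, :, 2:]) + 2]

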
class Seq2Seq_Inference(object):
    def __init__(self,
                 encoder_preprocessor,
                 decoder_preprocessor,
                 seq2seq_model):

        self.pp_body = encoder_preprocessor
        self.pp_title = decoder_preprocessor
        self.seq2seq_model = seq2seq_model
        self.encoder_model = extract_encoder_model(seq2seq_model)
        self.decoder_model = extract_decoder_model(seq2seq_model)
        self.default_max_len_title = self.pp_title.padding_maxlen
        self.nn = None
        self.rec_df = None

    def generate_issue_title(self,
                             raw_input_text,
                             max_len_title=None):
        """
        Use the seq2seq model to generate a title given the body of an issue.

        Inputs
        ------
        raw_input_text: str
            The body of the issue text as an input string

        max_len_title: int (optional)
            The maximum length of the title the model will generate

        """
        if max_len_title is None:
            max_len_title = self.default_max_len_title
        # encode the issue body
        raw_tokenized = self.pp_body.transform([raw_input_text])
        body_encoding = self.encoder_model.predict(raw_tokenized)
        # We want to save the encoder's embedding before it is updated by the
        # decoder, because we can use it as an embedding for the encoder's
        # input (the issue body).
        original_body_encoding = body_encoding
        # Seed the decoder with the _start_ token
        state_value = np.array(self.pp_title.token2id['_start_']).reshape(1, 1)

        decoded_sentence = []
        stop_condition = False
        while not stop_condition:
            preds, st = self.decoder_model.predict([state_value, body_encoding])

            # We are going to ignore index 0 (padding) and index 1 (unknown).
            # Argmax will return the integer index corresponding to the
            # prediction + 2 b/c we chopped off the first two
            pred_idx = np.argmax(preds[:, :, 2:]) + 2

            # retrieve word from index prediction
            pred_word_str = self.pp_title.id2token[pred_idx]

            if pred_word_str == '_end_' or len(decoded_sentence) >= max_len_title:
                stop_condition = True
                break
            decoded_sentence.append(pred_word_str)

            # update the decoder state for the next word
            body_encoding = st
            state_value = np.array(pred_idx).reshape(1, 1)

        return original_body_encoding, ' '.join(decoded_sentence)

    def print_example(self,
                      i,
                      body_text,
                      title_text,
                      url,
                      threshold):
        """
        Prints an example of the model's prediction for manual inspection.
        """
        print('\n\n==============================================')
        print(f'============== Example # {i} =================\n')
        print(url)
        print(f"Issue Body:\n {body_text} \n")

        print(f"Original Title:\n {title_text}")

        emb, gen_title = self.generate_issue_title(body_text)
        print(f"\n****** Machine Generated Title (Prediction) ******:\n {gen_title}")

        if self.nn:
            # return nearest neighbors and distances
            n, d = self.nn.get_nns_by_vector(emb.flatten(), n=3,
                                             include_distances=True)
            neighbors = n[1:]
            dist = d[1:]

            if min(dist) <= threshold:
                cols = ['issue_url', 'issue_title', 'body']
                dfcopy = self.rec_df.iloc[neighbors][cols].copy(deep=True)
                dfcopy['dist'] = dist
                similar_issues_df = dfcopy.query(f'dist <= {threshold}')

                print("\n**** Similar Issues (using encoder embedding) ****:\n")
                display(similar_issues_df)

    def demo_model_predictions(self,
                               n,
                               issue_df,
                               threshold=1):
        """
        Pick n random Issues and display predictions.

        Input:
        ------
        n : int
            Number of issues to display from issue_df
        issue_df : pandas DataFrame
            DataFrame that contains two columns: `body` and `issue_title`.
        threshold : float
            distance threshold for recommendation of similar issues.

        Returns:
        --------
        None
            Prints the original issue body and the model's prediction.
        """
        # Extract body and title from DF
        body_text = issue_df.body.tolist()
        title_text = issue_df.issue_title.tolist()
        url = issue_df.issue_url.tolist()

        demo_list = np.random.randint(low=1, high=len(body_text), size=n)
        for i in demo_list:
            self.print_example(i,
                               body_text=body_text[i],
                               title_text=title_text[i],
                               url=url[i],
                               threshold=threshold)

    def prepare_recommender(self, vectorized_array, original_df):
        # Stubbed out: everything after the raise is unreachable until the
        # raise is removed.
        raise NotImplementedError
        # TODO: verify vectorized_array == original_df
        self.rec_df = original_df
        emb = self.encoder_model.predict(x=vectorized_array,
                                         batch_size=vectorized_array.shape[0]//100)

        f = emb.shape[1]
        self.nn = AnnoyIndex(f)
        logging.warning('Adding embeddings')
        for i in tqdm(range(len(emb))):
            self.nn.add_item(i, emb[i])
        logging.warning('Building trees for similarity lookup.')
        self.nn.build(80)
        return self.nn

    def set_recsys_data(self, original_df):
        raise NotImplementedError
        self.rec_df = original_df

    def set_recsys_annoyobj(self, annoyobj):
        raise NotImplementedError
        self.nn = annoyobj
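

# Typical usage of Seq2Seq_Inference (a sketch; `seq2seq_Model`, the fitted
# processors `body_pp`/`title_pp`, and a `testdf` DataFrame with `body`,
# `issue_title`, and `issue_url` columns are assumed to come from the notebooks):
#
#   seq2seq_inf = Seq2Seq_Inference(encoder_preprocessor=body_pp,
#                                   decoder_preprocessor=title_pp,
#                                   seq2seq_model=seq2seq_Model)
#   seq2seq_inf.demo_model_predictions(n=5, issue_df=testdf)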
