Skip to content

Commit af13c38

Browse files
committed
Update TensorFlow text generation example
1 parent 9ee6b74 commit af13c38

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

examples/tensorflow/text-generator/encoder.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,10 @@
2 2

3 3
# This file includes code which was modified from https://github.com/openai/gpt-2
4 4

5-
import tensorflow as tf
6-
import os
5+
import boto3
7 6
import json
8-
import regex as re
7+
import regex
9 8
from functools import lru_cache
10-
import requests
11-
import boto3
12 9

13 10

14 11
@lru_cache()
@@ -47,7 +44,7 @@ def __init__(self, encoder, bpe_merges, errors="replace"):
47 44
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
48 45
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
49 46
self.cache = {}
50-
self.pat = re.compile(
47+
self.pat = regex.compile(
51 48
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
52 49
)
53 50

@@ -94,7 +91,7 @@ def bpe(self, token):
94 91

95 92
def encode(self, text):
96 93
bpe_tokens = []
97-
for token in re.findall(self.pat, text):
94+
for token in regex.findall(self.pat, text):
98 95
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
99 96
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
100 97
return bpe_tokens
@@ -108,10 +105,14 @@ def decode(self, tokens):
108 105
def get_encoder():
109 106
s3 = boto3.client("s3")
110 107
encoder = json.load(
111-
s3.get_object(Bucket="cortex-examples", Key="text-generator/gpt-2/encoder.json")["Body"]
108+
s3.get_object(Bucket="cortex-examples", Key="tensorflow/text-generator/gpt-2/encoder.json")[
109+
"Body"
110+
]
112 111
)
113 112
bpe_data = (
114-
s3.get_object(Bucket="cortex-examples", Key="text-generator/gpt-2/vocab.bpe")["Body"]
113+
s3.get_object(Bucket="cortex-examples", Key="tensorflow/text-generator/gpt-2/vocab.bpe")[
114+
"Body"
115+
]
115 116
.read()
116 117
.decode("utf-8")
117 118
)

0 commit comments

Comments
 (0)