# Patch summary (text before the "diff --git" line is ignored by patch tools):
#   1. maybe_download_and_read_file: pass an explicit User-Agent header to
#      session.get() when downloading the archive.
#   2. maybe_download_and_read_file: drop the reassignment of `filename` to
#      zipf.namelist(); the archive member 'fra.txt' is opened by name.
#   3. Module level: unpack three fields per row from zip(*raw_data) and
#      discard the third, instead of unpacking two.
# NOTE(review): line breaks below follow the hunk headers' line counts; the
#   leading whitespace of context lines is reconstructed from Python block
#   structure - confirm against the target file (or apply with
#   --ignore-whitespace).
# NOTE(review): the ' ' literals on the raw_data_fr_in / raw_data_fr_out
#   context lines are kept exactly as received; verify against the repository
#   that no marker text was lost from them in extraction.
diff --git a/machine_translation/train_transformer_tf2.py b/machine_translation/train_transformer_tf2.py
index b53f371..5b3d9e5 100644
--- a/machine_translation/train_transformer_tf2.py
+++ b/machine_translation/train_transformer_tf2.py
@@ -28,7 +28,8 @@ def maybe_download_and_read_file(url, filename):
     """
     if not os.path.exists(filename):
         session = requests.Session()
-        response = session.get(url, stream=True)
+        response = session.get(url, stream=True,
+                               headers={'User-Agent': 'Chrome/91.0.4472.106'})
 
         CHUNK_SIZE = 32768
         with open(filename, "wb") as f:
@@ -37,7 +38,6 @@ def maybe_download_and_read_file(url, filename):
                     f.write(chunk)
 
     zipf = ZipFile(filename)
-    filename = zipf.namelist()
     with zipf.open('fra.txt') as f:
         lines = f.read()
 
@@ -73,7 +73,7 @@ def normalize_string(s):
     return s
 
 
-raw_data_en, raw_data_fr = list(zip(*raw_data))
+raw_data_en, raw_data_fr, _ = list(zip(*raw_data))
 raw_data_en = [normalize_string(data) for data in raw_data_en]
 raw_data_fr_in = [' ' + normalize_string(data) for data in raw_data_fr]
 raw_data_fr_out = [normalize_string(data) + ' ' for data in raw_data_fr]