104 commits
a37e5ed
create data loader
strutive07 Aug 26, 2019
75c0ba4
add word2idx, idx2word, tokenizer typo fix
strutive07 Aug 27, 2019
956fcc5
transformer model implementation. not tested
strutive07 Aug 31, 2019
a65d7e6
fix typo
strutive07 Aug 31, 2019
54248f1
fix typo: scaled dot product attention shape typo, Embedding paramete…
strutive07 Sep 2, 2019
5aca0b5
Update greetings.yml
strutive07 Sep 5, 2019
77292e7
fix typo, add training mode to dropout layer
strutive07 Sep 8, 2019
8e7eae7
add sequence save/load to pickle
strutive07 Sep 8, 2019
c691957
Create utils.py
strutive07 Sep 8, 2019
888897f
Create train.py
strutive07 Sep 8, 2019
3c6aac7
change dir to Relative path
strutive07 Sep 9, 2019
f4942c2
remove f string to run in python 3.5
strutive07 Sep 9, 2019
02044fc
move Trainer class to utils
strutive07 Sep 9, 2019
7c7311e
fix typo
strutive07 Sep 9, 2019
d9e7dc5
env install script
strutive07 Sep 9, 2019
d121fb6
fix imports
strutive07 Sep 9, 2019
d9e8cfc
fix imports
strutive07 Sep 9, 2019
391e9dc
add distributed training
strutive07 Sep 10, 2019
5c755c4
distributed dataset config
strutive07 Sep 11, 2019
a7470b8
fix typo
strutive07 Sep 11, 2019
b0502bf
add reducemean to distributed loss
strutive07 Sep 11, 2019
5024d3f
disable warning message
strutive07 Sep 11, 2019
e8fc019
add print global batch size
strutive07 Sep 11, 2019
7ef9e38
change function to tf.function
strutive07 Sep 11, 2019
fb0165b
remove tf.function
strutive07 Sep 11, 2019
1e02598
change loss
strutive07 Sep 12, 2019
cf8e7fd
fix import error
strutive07 Sep 12, 2019
55a947d
Update train.py
strutive07 Sep 13, 2019
b6f3055
Update distributed_train.py
strutive07 Sep 13, 2019
e305ddf
Update README.md
strutive07 Sep 23, 2019
977b556
Change preprocess from word tokenizing to byte pair encoding
strutive07 Sep 28, 2019
931f187
Update distributed_train.py
strutive07 Sep 28, 2019
cdbcd7d
Update distributed_train.py
strutive07 Sep 28, 2019
b28cce8
remove f string
strutive07 Sep 28, 2019
ca150e7
remove deprecated variables - vocab size
strutive07 Sep 28, 2019
b321751
typo fix
strutive07 Sep 28, 2019
2e74357
move create dataset to data loader
strutive07 Sep 28, 2019
b1dafa0
fix typo
strutive07 Sep 28, 2019
0f2761d
move dataset to data loader in single gpu train
strutive07 Sep 28, 2019
0cd882d
tmp py
strutive07 Sep 29, 2019
d1cbce2
add batch option to dataset
strutive07 Sep 29, 2019
196be99
delete tmp files
strutive07 Oct 3, 2019
43235f9
Merge branch 'feature/change_preprocessing_from_word_tokenize_to_bpe'
strutive07 Oct 3, 2019
d07ad2a
add test codes
strutive07 Oct 3, 2019
370dae5
fix paths
strutive07 Oct 3, 2019
830c705
fix dataset ranges
strutive07 Oct 3, 2019
cef536d
load checkpoint from None optimizer
strutive07 Oct 3, 2019
2fef613
add data encoder
strutive07 Oct 3, 2019
b03194a
fix typo
strutive07 Oct 3, 2019
e36c015
add multiprocessing
strutive07 Oct 3, 2019
bc6d1f6
remove multiprocess
strutive07 Oct 3, 2019
c300569
Add paper review as pdf file
strutive07 Oct 4, 2019
fbc575e
Update README.md
strutive07 Oct 4, 2019
60a9c8f
Update README.md
strutive07 Oct 4, 2019
1950cd9
add label smoothing, refactoring trainer
strutive07 Oct 13, 2019
87f83a1
transformer tf 2.0 colab guide
strutive07 Oct 20, 2019
ff20dbd
Merge pull request #6 from strutive07/feature/test_and_evaluation
strutive07 Oct 20, 2019
643b77f
Update README.md
strutive07 Oct 20, 2019
8e7a902
add BLEU score
strutive07 Oct 20, 2019
76e8ad6
Merge remote-tracking branch 'origin/master'
strutive07 Oct 20, 2019
3ad8d10
add requriments
strutive07 Oct 20, 2019
f03526e
Add custom dataset option to data loader. Change README
strutive07 Oct 20, 2019
ae2f13f
Update README.md
strutive07 Oct 20, 2019
100260a
Update README.md
strutive07 Oct 21, 2019
3b1ced1
bleu score calculator
strutive07 Dec 1, 2019
c0be89a
change tensorflwo version
strutive07 Dec 15, 2019
00c3e67
add train all dataset mode
strutive07 Dec 15, 2019
4c75d9c
change cuda, tensorflow version
strutive07 Dec 15, 2019
f50cbdc
fix: change trainer util option
strutive07 Jan 19, 2020
76a3601
fix: add bpe data loader
strutive07 Jan 19, 2020
72f5267
fix: move sequence to text to translate util
strutive07 Jan 19, 2020
1c5b08b
Merge pull request #8 from strutive07/test_in_tf-2.1rc1
strutive07 Jan 19, 2020
7a01d40
Create .deepsource.toml
strutive07 Jan 20, 2020
df9bace
remove wild card import
strutive07 Jan 27, 2020
4b6272d
sort imports
strutive07 Jan 27, 2020
56d64bb
bleu calculator wrap to python funciton
strutive07 Jan 27, 2020
8b81080
clean unused imports
strutive07 Jan 27, 2020
98f8a29
remove unused variables
strutive07 Jan 27, 2020
41d5e0d
solve: File opened without the with statement PTC-W0010
strutive07 Jan 27, 2020
1b522b7
solve: Unnecessary else / elif PYL-R1705
strutive07 Jan 27, 2020
915400b
solve: Unnecessary else / elif PYL-R1705
strutive07 Jan 27, 2020
a9b396f
reformat model file
strutive07 Jan 27, 2020
6b29c8d
solve: Re-defined variable from outer scope PYL-W0621
strutive07 Jan 27, 2020
7c9f727
convert class name to camelcase
strutive07 Jan 27, 2020
54493c4
remove unused parameters
strutive07 Jan 27, 2020
1cf4f97
solve: Module imported but unused PYL-W0611
strutive07 Jan 27, 2020
f506400
solve: Consider using in PYL-R1714
strutive07 Jan 27, 2020
f634ae6
solve: Trailing comma tuple detected PYL-R1707
strutive07 Jan 27, 2020
87fe30b
solve: Multiple spaces found after operator FLK-E222
strutive07 Jan 27, 2020
9c73654
solve: Expected 2 blank lines, found 0 FLK-E302
strutive07 Jan 27, 2020
9896785
remove unused docs
strutive07 Jan 27, 2020
aed3b8b
remove variable name 'input'
strutive07 Jan 27, 2020
0e96fef
remove variable name 'input'
strutive07 Jan 27, 2020
58c68d6
solve: Blank line contains whitespace FLK-W293
strutive07 Jan 27, 2020
54ebd5f
fix too long lines
strutive07 Jan 27, 2020
a8b4fda
add deepsource icon to readme
strutive07 Jan 27, 2020
454e3b0
change model parameter name
strutive07 Jan 27, 2020
d8036f8
Merge pull request #10 from strutive07/clean_code
strutive07 Jan 27, 2020
b4d88f1
add bpe download link to README
strutive07 Feb 7, 2020
9502c8e
fix: typo in readme
strutive07 Feb 7, 2020
d9d836e
fix: move standard download method
strutive07 Feb 7, 2020
fef066d
Merge pull request #12 from strutive07/feature/add_bpe_model_file_to_…
strutive07 Feb 7, 2020
5011b77
add new line to download links
strutive07 Feb 7, 2020
436c4db
change encoder residual connection error
strutive07 Apr 30, 2021
7 changes: 7 additions & 0 deletions .deepsource.toml
@@ -0,0 +1,7 @@
version = 1

[[analyzers]]
name = "python"
enabled = true
max_line_length = 120
runtime_version = "3.x.x"
13 changes: 13 additions & 0 deletions .github/workflows/greetings.yml
@@ -0,0 +1,13 @@
name: Greetings

on: [pull_request, issues]

jobs:
  greeting:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/first-interaction@v1
      with:
        repo-token: ${{ secrets.GITHUB_TOKEN }}
        issue-message: 'Message that will be displayed on users'' first issue'
        pr-message: 'Message that will be displayed on users'' first pr'
220 changes: 220 additions & 0 deletions .gitignore
@@ -0,0 +1,220 @@

# Created by https://www.gitignore.io/api/vim,macos,python,windows,virtualenv
# Edit at https://www.gitignore.io/?templates=vim,macos,python,windows,virtualenv

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

### Vim ###
# Swap
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim
Sessionx.vim

# Temporary
.netrwhist
*~
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~

### VirtualEnv ###
# Virtualenv
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
[Bb]in
[Ii]nclude
[Ll]ib
[Ll]ib64
[Ll]ocal
[Ss]cripts
pyvenv.cfg
pip-selfcheck.json

### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk

# End of https://www.gitignore.io/api/vim,macos,python,windows,virtualenv

datasets/
.idea/
Binary file added Attention is all you need.pdf
Binary file not shown.
3,002 changes: 3,002 additions & 0 deletions BLEU/WMT14-newstest2013en_de_with_label_smoothing.txt

Large diffs are not rendered by default.

2,739 changes: 2,739 additions & 0 deletions BLEU/WMT14-newstest2014en_de_with_label_smoothing.txt

Large diffs are not rendered by default.

2,171 changes: 2,171 additions & 0 deletions BLEU/WMT14-newstest2015en_de_with_label_smoothing.txt

Large diffs are not rendered by default.

112 changes: 110 additions & 2 deletions README.md
@@ -1,2 +1,110 @@
# transformer-tensorflow2.0
transformer in tensorflow 2.0
# Transformer-tensorflow2.0

[Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf) (Transformer) in TensorFlow 2.0

[paper review(pdf)](https://github.com/strutive07/transformer-tensorflow2.0/blob/master/Attention%20is%20all%20you%20need.pdf)

[colab guide](https://colab.research.google.com/github/strutive07/transformer-tensorflow2.0/blob/master/transformer_implement_tf2_0.ipynb)

[Download pre-trained model(checkpoint)](https://drive.google.com/file/d/1jsY7WMI9EU5ifhcxV_sMpK8znPA1mvkf/view?usp=sharing)

[Download pre-trained bpe data](https://drive.google.com/drive/folders/1YUABrVUz3oGKgGfMJNWQl0WCP_nVjhiS?usp=sharing)

[![DeepSource](https://static.deepsource.io/deepsource-badge-light-mini.svg)](https://deepsource.io/gh/strutive07/transformer-tensorflow2.0/?ref=repository-badge)

## How to train

1. Install the environment

   `bash ubuntu16_04_cuda10_cudnn7_tensorflow2.0_install.sh`

2. Training

   - Single GPU training
     1. Change the hyperparameters in train.py
     2. Run the training script

        ```bash
        python train.py
        ```

   - Multi GPU training
     1. Change the hyperparameters in distributed_train.py
     2. Run the training script

        ```bash
        python distributed_train.py
        ```

3. Test
   - If you have not trained a BPE model, train one or download the pre-trained BPE model: [Download pre-trained bpe data](https://drive.google.com/drive/folders/1YUABrVUz3oGKgGfMJNWQl0WCP_nVjhiS?usp=sharing). Save it in the *top dataset directory*, for example `./dataset/train.en.segmented.vocab` and so on.


## How to add dataset

Add a data config entry to `data_loader.py`:

```python
CONFIG = {
    'wmt14/en-de': {
        'source_lang': 'en',
        'target_lang': 'de',
        'base_url': 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/',
        'train_files': ['train.en', 'train.de'],
        'vocab_files': ['vocab.50K.en', 'vocab.50K.de'],
        'dictionary_files': ['dict.en-de'],
        'test_files': [
            'newstest2012.en', 'newstest2012.de',
            'newstest2013.en', 'newstest2013.de',
            'newstest2014.en', 'newstest2014.de',
            'newstest2015.en', 'newstest2015.de',
        ]
    }
}
```
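
As a rough illustration of how such an entry could be consumed, the sketch below joins `base_url` with every listed file name. The helper `list_download_urls` and the name `EXAMPLE_CONFIG` are hypothetical and not part of this repository; the actual download logic in `data_loader.py` may differ.

```python
from urllib.parse import urljoin

# Copy of the config entry shown above, used only for illustration.
EXAMPLE_CONFIG = {
    'wmt14/en-de': {
        'source_lang': 'en',
        'target_lang': 'de',
        'base_url': 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/',
        'train_files': ['train.en', 'train.de'],
        'vocab_files': ['vocab.50K.en', 'vocab.50K.de'],
        'dictionary_files': ['dict.en-de'],
        'test_files': [
            'newstest2012.en', 'newstest2012.de',
            'newstest2013.en', 'newstest2013.de',
            'newstest2014.en', 'newstest2014.de',
            'newstest2015.en', 'newstest2015.de',
        ]
    }
}


def list_download_urls(dataset_name, config=EXAMPLE_CONFIG):
    """Hypothetical helper: build one URL per file listed in a config entry."""
    entry = config[dataset_name]
    file_names = (
        entry['train_files']
        + entry['vocab_files']
        + entry['dictionary_files']
        + entry['test_files']
    )
    # base_url ends with '/', so urljoin appends the file name to it.
    return [urljoin(entry['base_url'], name) for name in file_names]


if __name__ == '__main__':
    for url in list_download_urls('wmt14/en-de'):
        print(url)
```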

To add a custom dataset, add a data config like the one below (without `base_url`) and pass `custom_dataset=True` to `DataLoader.load`:

```python
CONFIG = {
    'wmt14/en-de': {
        'source_lang': 'en',
        'target_lang': 'de',
        'train_files': ['train.en', 'train.de'],
        'vocab_files': ['vocab.50K.en', 'vocab.50K.de'],
        'dictionary_files': ['dict.en-de'],
        'test_files': [
            'newstest2012.en', 'newstest2012.de',
            'newstest2013.en', 'newstest2013.de',
            'newstest2014.en', 'newstest2014.de',
            'newstest2015.en', 'newstest2015.de',
        ]
    }
}

data_loader = DataLoader(
    dataset_name='wmt14/en-de',
    data_dir='./datasets',
    batch_size=GLOBAL_BATCH_SIZE,
    bpe_vocab_size=BPE_VOCAB_SIZE,
    seq_max_len_source=SEQ_MAX_LEN_SOURCE,
    seq_max_len_target=SEQ_MAX_LEN_TARGET,
    data_limit=DATA_LIMIT,
    train_ratio=TRAIN_RATIO
)

dataset, val_dataset = data_loader.load(custom_dataset=True)
```
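
Since a custom entry has no `base_url`, the files presumably have to exist locally before loading. The sketch below is one way to check that up front. It assumes the files sit directly under `data_dir`, which this README does not state, and `missing_dataset_files` is a hypothetical helper, not part of `DataLoader`.

```python
import os


def missing_dataset_files(entry, data_dir):
    """Hypothetical check: return the listed train/test files not found in data_dir.

    Assumes files live directly under data_dir; adjust the path logic if the
    loader expects a per-dataset subdirectory.
    """
    expected = entry.get('train_files', []) + entry.get('test_files', [])
    return [
        name for name in expected
        if not os.path.isfile(os.path.join(data_dir, name))
    ]


if __name__ == '__main__':
    example_entry = {
        'train_files': ['train.en', 'train.de'],
        'test_files': ['newstest2013.en', 'newstest2013.de'],
    }
    missing = missing_dataset_files(example_entry, './datasets')
    if missing:
        print('Missing files:', ', '.join(missing))
```

Running a check like this before `data_loader.load(custom_dataset=True)` gives a clearer error than a failure partway through preprocessing.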



## BLEU Score

| Test Dataset | BLEU Score |
| ------------ | ---------- |
| newstest2013 | 23.3 |
| newstest2014 | 22.85 |
| newstest2015 | 25.33 |
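
For readers unfamiliar with the metric, here is a minimal corpus-level BLEU sketch (clipped n-gram precision up to 4-grams, geometric mean, brevity penalty). It assumes whitespace tokenization and one reference per sentence, and it is not necessarily how the scores above were computed; the repository's own BLEU calculator may use different tokenization or smoothing.

```python
import math
from collections import Counter


def ngram_counts(tokens, n):
    """Count all n-grams of order n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(references, hypotheses, max_n=4):
    """Corpus-level BLEU on a 0-100 scale, without smoothing."""
    matches = [0] * max_n
    totals = [0] * max_n
    ref_len = 0
    hyp_len = 0
    for ref, hyp in zip(references, hypotheses):
        ref_tokens, hyp_tokens = ref.split(), hyp.split()
        ref_len += len(ref_tokens)
        hyp_len += len(hyp_tokens)
        for n in range(1, max_n + 1):
            hyp_counts = ngram_counts(hyp_tokens, n)
            ref_counts = ngram_counts(ref_tokens, n)
            # Counter intersection clips each n-gram count at the reference count.
            matches[n - 1] += sum((hyp_counts & ref_counts).values())
            totals[n - 1] += sum(hyp_counts.values())
    if hyp_len == 0 or min(matches) == 0:
        return 0.0
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # Penalize hypotheses that are shorter than the references overall.
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1.0 - ref_len / hyp_len)
    return 100.0 * brevity_penalty * math.exp(log_precision)


if __name__ == '__main__':
    refs = ['the cat sat on the mat today']
    hyps = ['the cat sat on the mat now']
    print(round(corpus_bleu(refs, hyps), 2))
```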