
Commit 350772f
committed
Change the location for reading test data
1 parent fc8f8f6
25 files changed, +130 -130 lines changed
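
Every hunk below is the same mechanical change: the hard-coded 'test/' prefix on test-data paths becomes 'tests/'. As an aside, a helper like the following sketch (hypothetical, not part of this commit; the module name and layout are assumptions) would resolve fixture paths relative to the test package itself, so a future directory rename would not have to touch every call site:

# tests/helper.py (hypothetical): anchor fixture paths to this file's
# location instead of hard-coding a 'tests/' prefix at every call site.
from pathlib import Path

DATA_ROOT = Path(__file__).resolve().parent / 'data_for_tests'

def data_path(*parts):
    """Return the absolute path of a fixture under tests/data_for_tests."""
    return str(DATA_ROOT.joinpath(*parts))

A call site would then read, e.g., BertEmbedding(vocab, model_dir_or_name=data_path('embedding', 'small_bert')), and would also stop depending on the tests being launched from the repository root.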

tests/core/test_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ def test_apply2(self):
         def split_sent(ins):
             return ins['raw_sentence'].split()
         csv_loader = CSVLoader(headers=['raw_sentence', 'label'], sep='\t')
-        data_bundle = csv_loader.load('test/data_for_tests/tutorial_sample_dataset.csv')
+        data_bundle = csv_loader.load('tests/data_for_tests/tutorial_sample_dataset.csv')
         dataset = data_bundle.datasets['train']
         dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True)
         dataset.apply(split_sent, new_field_name='words', is_input=True)

tests/core/test_utils.py

Lines changed: 11 additions & 11 deletions
@@ -120,8 +120,8 @@ class TestCache(unittest.TestCase):
     def test_cache_save(self):
         try:
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
-                                             'test/data_for_tests/cws_train')
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
+                                             'tests/data_for_tests/cws_train')
             end_time = time.time()
             pre_time = end_time - start_time
             with open('test/demo1.pkl', 'rb') as f:
@@ -130,8 +130,8 @@ def test_cache_save(self):
             for i in range(embed.shape[0]):
                 self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
-                                             'test/data_for_tests/cws_train')
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
+                                             'tests/data_for_tests/cws_train')
             end_time = time.time()
             read_time = end_time - start_time
             print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
@@ -142,7 +142,7 @@ def test_cache_save(self):
     def test_cache_save_overwrite_path(self):
         try:
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt', 'test/data_for_tests/cws_train',
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt', 'tests/data_for_tests/cws_train',
                                              _cache_fp='test/demo_overwrite.pkl')
             end_time = time.time()
             pre_time = end_time - start_time
@@ -152,8 +152,8 @@ def test_cache_save_overwrite_path(self):
             for i in range(embed.shape[0]):
                 self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
-                                             'test/data_for_tests/cws_train',
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
+                                             'tests/data_for_tests/cws_train',
                                              _cache_fp='test/demo_overwrite.pkl')
             end_time = time.time()
             read_time = end_time - start_time
@@ -165,8 +165,8 @@ def test_cache_save_overwrite_path(self):
     def test_cache_refresh(self):
         try:
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
-                                             'test/data_for_tests/cws_train',
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
+                                             'tests/data_for_tests/cws_train',
                                              _refresh=True)
             end_time = time.time()
             pre_time = end_time - start_time
@@ -176,8 +176,8 @@ def test_cache_refresh(self):
             for i in range(embed.shape[0]):
                 self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
-                                             'test/data_for_tests/cws_train',
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
+                                             'tests/data_for_tests/cws_train',
                                              _refresh=True)
             end_time = time.time()
             read_time = end_time - start_time
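
For context on the test above: process_data_1 is wrapped in a caching decorator, which is why the test times a first "prepare" call against a much faster cache-read call and passes _cache_fp and _refresh at call time. A minimal sketch of that caching pattern follows (my own simplification of the behavior the test exercises, not fastNLP's actual implementation):

# Simplified caching decorator (assumption: illustrative only).
import functools
import os
import pickle

def cache_results(cache_fp):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, _cache_fp=cache_fp, _refresh=False, **kwargs):
            # Fast path: reuse the pickled results unless a refresh is forced.
            if _cache_fp and os.path.exists(_cache_fp) and not _refresh:
                with open(_cache_fp, 'rb') as f:
                    return pickle.load(f)
            results = func(*args, **kwargs)  # slow path: recompute
            if _cache_fp:
                with open(_cache_fp, 'wb') as f:
                    pickle.dump(results, f)  # write-through for the next call
            return results
        return wrapper
    return decorator

Note also that the cache-file paths in the context lines ('test/demo1.pkl', 'test/demo_overwrite.pkl') still reference the old test/ directory; this commit only updates the input-data paths.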

tests/embeddings/test_bert_embedding.py

Lines changed: 8 additions & 8 deletions
@@ -32,22 +32,22 @@ def test_word_drop(self):
 class TestBertEmbedding(unittest.TestCase):
     def test_bert_embedding_1(self):
         vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1)
         requires_grad = embed.requires_grad
         embed.requires_grad = not requires_grad
         embed.train()
         words = torch.LongTensor([[2, 3, 4, 0]])
         result = embed(words)
         self.assertEqual(result.size(), (1, 4, 16))

-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1)
         embed.eval()
         words = torch.LongTensor([[2, 3, 4, 0]])
         result = embed(words)
         self.assertEqual(result.size(), (1, 4, 16))

         # truncate automatically instead of raising an error
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1,
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
                               auto_truncate=True)

         words = torch.LongTensor([[2, 3, 4, 1]*10,
@@ -60,7 +60,7 @@ def test_save_load(self):
         try:
             os.makedirs(bert_save_test, exist_ok=True)
             vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
-            embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1,
+            embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
                                   auto_truncate=True)

             embed.save(bert_save_test)
@@ -76,15 +76,15 @@ def test_save_load(self):

 class TestBertWordPieceEncoder(unittest.TestCase):
     def test_bert_word_piece_encoder(self):
-        embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
+        embed = BertWordPieceEncoder(model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1)
         ds = DataSet({'words': ["this is a test . [SEP]".split()]})
         embed.index_datasets(ds, field_name='words')
         self.assertTrue(ds.has_field('word_pieces'))
         result = embed(torch.LongTensor([[1,2,3,4]]))

     def test_bert_embed_eq_bert_piece_encoder(self):
         ds = DataSet({'words': ["this is a texta model vocab".split(), 'this is'.split()]})
-        encoder = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert')
+        encoder = BertWordPieceEncoder(model_dir_or_name='tests/data_for_tests/embedding/small_bert')
         encoder.eval()
         encoder.index_datasets(ds, field_name='words')
         word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
@@ -95,7 +95,7 @@ def test_bert_embed_eq_bert_piece_encoder(self):
         vocab.index_dataset(ds, field_name='words', new_field_name='words')
         ds.set_input('words')
         words = torch.LongTensor(ds['words'].get([0, 1]))
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               pool_method='first', include_cls_sep=True, pooled_cls=False, min_freq=1)
         embed.eval()
         words_res = embed(words)
@@ -109,7 +109,7 @@ def test_save_load(self):
         bert_save_test = 'bert_save_test'
         try:
             os.makedirs(bert_save_test, exist_ok=True)
-            embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.0,
+            embed = BertWordPieceEncoder(model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.0,
                                          layers='-2')
             ds = DataSet({'words': ["this is a test . [SEP]".split()]})
             embed.index_datasets(ds, field_name='words')

tests/embeddings/test_elmo_embedding.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ def test_download_small(self):
 class TestRunElmo(unittest.TestCase):
     def test_elmo_embedding(self):
         vocab = Vocabulary().add_word_lst("This is a test .".split())
-        elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', layers='0,1')
+        elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_elmo', layers='0,1')
         words = torch.LongTensor([[0, 1, 2]])
         hidden = elmo_embed(words)
         print(hidden.size())
@@ -30,7 +30,7 @@ def test_elmo_embedding(self):
     def test_elmo_embedding_layer_assertion(self):
         vocab = Vocabulary().add_word_lst("This is a test .".split())
         try:
-            elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo',
+            elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_elmo',
                                        layers='0,1,2')
         except AssertionError as e:
             print(e)

tests/embeddings/test_gpt2_embedding.py

Lines changed: 12 additions & 12 deletions
@@ -21,7 +21,7 @@ def test_download(self):
         print(embed(words).size())

     def test_gpt2_embedding(self):
-        weight_path = 'test/data_for_tests/embedding/small_gpt2'
+        weight_path = 'tests/data_for_tests/embedding/small_gpt2'
         vocab = Vocabulary().add_word_lst("this is a texta sentence".split())
         embed = GPT2Embedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1)
         requires_grad = embed.requires_grad
@@ -49,7 +49,7 @@ def test_gpt2_embedding(self):
     def test_gpt2_ebembedding_2(self):
         # test that only_use_pretrain_vocab and truncate_embed work as expected
         Embedding = GPT2Embedding
-        weight_path = 'test/data_for_tests/embedding/small_gpt2'
+        weight_path = 'tests/data_for_tests/embedding/small_gpt2'
         vocab = Vocabulary().add_word_lst("this is a texta and".split())
         embed1 = Embedding(vocab, model_dir_or_name=weight_path,layers=list(range(3)),
                            only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
@@ -89,13 +89,13 @@ def test_gpt2_ebembedding_2(self):
     def test_gpt2_tokenizer(self):
         from fastNLP.modules.tokenizer import GPT2Tokenizer

-        tokenizer = GPT2Tokenizer.from_pretrained('test/data_for_tests/embedding/small_gpt2')
+        tokenizer = GPT2Tokenizer.from_pretrained('tests/data_for_tests/embedding/small_gpt2')
         print(tokenizer.encode("this is a texta a sentence"))
         print(tokenizer.encode('this is'))

     def test_gpt2_embed_eq_gpt2_piece_encoder(self):
         # mainly check that the embedding output is consistent with the word-piece encoder output
-        weight_path = 'test/data_for_tests/embedding/small_gpt2'
+        weight_path = 'tests/data_for_tests/embedding/small_gpt2'
         ds = DataSet({'words': ["this is a texta a sentence".split(), 'this is'.split()]})
         encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path)
         encoder.eval()
@@ -187,7 +187,7 @@ def test_generate_small_gpt2(self):

         print(used_pairs)
         import json
-        with open('test/data_for_tests/embedding/small_gpt2/vocab.json', 'w') as f:
+        with open('tests/data_for_tests/embedding/small_gpt2/vocab.json', 'w') as f:
             new_used_vocab = {}
             for idx, key in enumerate(used_vocab.keys()):
                 new_used_vocab[key] = len(new_used_vocab)
@@ -201,12 +201,12 @@ def test_generate_small_gpt2(self):

             json.dump(new_used_vocab, f)

-        with open('test/data_for_tests/embedding/small_gpt2/merges.txt', 'w') as f:
+        with open('tests/data_for_tests/embedding/small_gpt2/merges.txt', 'w') as f:
             f.write('#version: small\n')
             for k,v in sorted(sorted(used_pairs.items(), key=lambda kv:kv[1])):
                 f.write('{} {}\n'.format(k[0], k[1]))

-        new_tokenizer = GPT2Tokenizer.from_pretrained('test/data_for_tests/embedding/small_gpt2')
+        new_tokenizer = GPT2Tokenizer.from_pretrained('tests/data_for_tests/embedding/small_gpt2')
         new_all_tokens = []
         for sent in [sent1, sent2, sent3]:
             tokens = new_tokenizer.tokenize(sent, add_prefix_space=True)
@@ -227,21 +227,21 @@ def test_generate_small_gpt2(self):
             "n_positions": 20,
             "vocab_size": len(new_used_vocab)
         }
-        with open('test/data_for_tests/embedding/small_gpt2/config.json', 'w') as f:
+        with open('tests/data_for_tests/embedding/small_gpt2/config.json', 'w') as f:
             json.dump(config, f)

         # generate smaller merges.txt and vocab.json by recording the values used in the tokenizer
         from fastNLP.modules.encoder.gpt2 import GPT2LMHeadModel, GPT2Config

-        config = GPT2Config.from_pretrained('test/data_for_tests/embedding/small_gpt2')
+        config = GPT2Config.from_pretrained('tests/data_for_tests/embedding/small_gpt2')

         model = GPT2LMHeadModel(config)
-        torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin')
+        torch.save(model.state_dict(), 'tests/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin')
         print(model(torch.LongTensor([[0,1,2,3]])))

     def test_gpt2_word_piece_encoder(self):
         # mainly check that it runs
-        weight_path = 'test/data_for_tests/embedding/small_gpt2'
+        weight_path = 'tests/data_for_tests/embedding/small_gpt2'
         ds = DataSet({'words': ["this is a test sentence".split()]})
         embed = GPT2WordPieceEncoder(model_dir_or_name=weight_path, word_dropout=0.1)
         embed.index_datasets(ds, field_name='words')
@@ -256,7 +256,7 @@ def test_gpt2_word_piece_encoder(self):

     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_generate(self):
-        # weight_path = 'test/data_for_tests/embedding/small_gpt2'
+        # weight_path = 'tests/data_for_tests/embedding/small_gpt2'
         weight_path = 'en'

         encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, language_model=True)
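
With 25 files changed by hand-editing string literals, a regression guard is cheap insurance. The sketch below is hypothetical (not part of this commit; names are illustrative): a test that fails if any test module reintroduces the stale prefix.

# Hypothetical guard test: fail if any test module still hard-codes the
# old test/data_for_tests prefix after the rename to tests/.
import pathlib
import re
import unittest

# Match a quote immediately followed by the stale prefix, so this file's
# own source (where the prefix appears unquoted) does not trip the check.
STALE = re.compile(r"['\"]test/data_for_tests")

class TestNoStaleDataPaths(unittest.TestCase):
    def test_no_stale_prefix(self):
        # Assumes the suite is launched from the repository root.
        for path in pathlib.Path('tests').rglob('*.py'):
            text = path.read_text(encoding='utf-8')
            self.assertIsNone(STALE.search(text),
                              msg='stale path in {}'.format(path))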
