Skip to content

Commit 3229124

Browse files
author
Yue Wang
committed
resolve hard-code path issue & improve readme
1 parent 71cc7ba commit 3229124

File tree

6 files changed

+24
-23
lines changed

6 files changed

+24
-23
lines changed

README.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,19 @@ We encourage users of this software to tell us about the applications in which t
9191

9292
## Download
9393
* [Pre-trained checkpoints & Fine-tuning data](https://console.cloud.google.com/storage/browser/sfr-codet5-data-research)
94+
* Fine-tuned checkpoints (TBA)
95+
* Extra C/C# pre-training data (TBA)
9496

9597
Instructions to download:
9698
```
9799
pip install gsutil
98100
101+
gsutil -m cp -r "gs://sfr-codet5-data-research/data/" .
102+
103+
mkdir pretrained_models; cd pretrained_models
99104
gsutil -m cp -r \
100-
"gs://sfr-codet5-data-research/data/" \
101-
"gs://sfr-codet5-data-research/pretrained_models/" \
105+
"gs://sfr-codet5-data-research/pretrained_models/codet5_small" \
106+
"gs://sfr-codet5-data-research/pretrained_models/codet5_base" \
102107
.
103108
```
104109

evaluator/CodeBLEU/bleu.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from collections import Counter
1717

1818
from evaluator.CodeBLEU.utils import ngrams
19-
import pdb
2019

2120

2221
def sentence_bleu(

evaluator/CodeBLEU/calc_code_bleu.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
3+
# https://github.com/microsoft/CodeXGLUE/tree/main/Code-Code/code-to-code-trans/evaluator/CodeBLEU
34

45
# -*- coding:utf-8 -*-
56
import argparse
7+
import os
68
from evaluator.CodeBLEU import bleu, weighted_ngram_match, syntax_match, dataflow_match
7-
# import evaluator.CodeBLEU.weighted_ngram_match
8-
# import evaluator.CodeBLEU.syntax_match
9-
# import evaluator.CodeBLEU.dataflow_match
109

1110

1211
def get_codebleu(refs, hyp, lang, params='0.25,0.25,0.25,0.25'):
@@ -36,7 +35,8 @@ def get_codebleu(refs, hyp, lang, params='0.25,0.25,0.25,0.25'):
3635
ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)
3736

3837
# calculate weighted ngram match
39-
keywords = [x.strip() for x in open('/export/share/wang.y/workspace/CodeT5Full/finetune/evaluator/CodeBLEU/keywords/' + lang + '.txt', 'r', encoding='utf-8').readlines()]
38+
root_dir = os.path.dirname(__file__)
39+
keywords = [x.strip() for x in open(root_dir + '/keywords/' + lang + '.txt', 'r', encoding='utf-8').readlines()]
4040

4141
def make_weights(reference_tokens, key_word_list):
4242
return {token: 1 if token in key_word_list else 0.2 for token in reference_tokens}
@@ -78,3 +78,4 @@ def make_weights(reference_tokens, key_word_list):
7878
args = parser.parse_args()
7979
code_bleu_score = get_codebleu(args.refs, args.hyp, args.lang, args.params)
8080
print('CodeBLEU score: ', code_bleu_score)
81+

evaluator/CodeBLEU/dataflow_match.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
index_to_code_token,
88
tree_to_variable_index)
99
from tree_sitter import Language, Parser
10-
import pdb
10+
import os
11+
12+
root_dir = os.path.dirname(__file__)
1113

12-
parser_path = '/export/share/wang.y/workspace/CodeT5Full/finetune/evaluator/CodeBLEU/parser'
1314
dfg_function = {
1415
'python': DFG_python,
1516
'java': DFG_java,
@@ -26,7 +27,7 @@ def calc_dataflow_match(references, candidate, lang):
2627

2728

2829
def corpus_dataflow_match(references, candidates, lang):
29-
LANGUAGE = Language('{}/my-languages.so'.format(parser_path), lang)
30+
LANGUAGE = Language(root_dir + '/parser/my-languages.so', lang)
3031
parser = Parser()
3132
parser.set_language(LANGUAGE)
3233
parser = [parser, dfg_function[lang]]

evaluator/CodeBLEU/syntax_match.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
index_to_code_token,
88
tree_to_variable_index)
99
from tree_sitter import Language, Parser
10+
import os
1011

11-
parser_path = '/export/share/wang.y/workspace/CodeT5Full/finetune/evaluator/CodeBLEU/parser'
12+
root_dir = os.path.dirname(__file__)
1213
dfg_function = {
1314
'python': DFG_python,
1415
'java': DFG_java,
@@ -25,7 +26,7 @@ def calc_syntax_match(references, candidate, lang):
2526

2627

2728
def corpus_syntax_match(references, candidates, lang):
28-
JAVA_LANGUAGE = Language('{}/my-languages.so'.format(parser_path), lang)
29+
JAVA_LANGUAGE = Language(root_dir + '/parser/my-languages.so', lang)
2930
parser = Parser()
3031
parser.set_language(JAVA_LANGUAGE)
3132
match_count = 0

run_gen.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
from tqdm import tqdm
2929
import multiprocessing
3030
import time
31-
import sys
32-
import pdb
3331

3432
from torch.utils.tensorboard import SummaryWriter
3533
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
@@ -112,7 +110,7 @@ def eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, split_tag,
112110
max_length=args.max_target_length)
113111
top_preds = list(preds.cpu().numpy())
114112
pred_ids.extend(top_preds)
115-
# pdb.set_trace()
113+
116114
pred_nls = [tokenizer.decode(id, skip_special_tokens=True, clean_up_tokenization_spaces=False) for id in pred_ids]
117115

118116
output_fn = os.path.join(args.res_dir, "test_{}.output".format(criteria))
@@ -146,20 +144,17 @@ def eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, split_tag,
146144
f1.write(gold.target.strip() + '\n')
147145
f2.write(gold.source.strip() + '\n')
148146

149-
if args.task in ['summarize']:
147+
if args.task == 'summarize':
150148
(goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn)
151149
bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
152150
else:
153151
bleu = round(_bleu(gold_fn, output_fn), 2)
154-
if split_tag == 'test' and args.task in ['refine', 'translate', 'concode']:
152+
if args.task == 'concode':
155153
codebleu = calc_code_bleu.get_codebleu(gold_fn, output_fn, args.lang)
156-
# except:
157-
# bleu = 0.0
158-
# codebleu = 0.0
159154

160155
em = np.mean(dev_accs) * 100
161156
result = {'em': em, 'bleu': bleu}
162-
if not args.task == 'summarize' and split_tag == 'test':
157+
if args.task == 'concode':
163158
result['codebleu'] = codebleu * 100
164159

165160
logger.info("***** Eval results *****")
@@ -364,7 +359,7 @@ def main():
364359
logger.info(" " + "***** Testing *****")
365360
logger.info(" Batch size = %d", args.eval_batch_size)
366361

367-
for criteria in ['best-bleu', 'best-ppl']: # 'best-bleu', 'best-ppl', 'last'
362+
for criteria in ['best-bleu', 'best-ppl']:
368363
file = os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(criteria))
369364
logger.info("Reload model from {}".format(file))
370365
model.load_state_dict(torch.load(file))
@@ -386,5 +381,4 @@ def main():
386381

387382

388383
if __name__ == "__main__":
389-
# print(' '.join(sys.argv[:]))
390384
main()

0 commit comments

Comments
 (0)