Skip to content

Commit 7c2e627

Browse files
committed
fixed type hints; added seq length 512 to encoder test defaults
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
1 parent 6788aa4 commit 7c2e627

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __download_file(url, filename):
4747
print(f"An error occurred: {e}")
4848

4949
def __sample_requests(
50-
prompt_list,
50+
prompt_list: List[str],
5151
num_requests: int,
5252
tokenizer: BaseTokenizer,
5353
prompt_length_min: int = 32,

tests/models/test_encoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
SQUAD_V2_DATASET_PATH = os.environ.get("SQUAD_V2_DATASET_PATH", os.path.expanduser("~/squad_v2"))
1919
common_model_paths = os.environ.get("FMS_TEST_SHAPES_COMMON_MODEL_PATHS", [ROBERTA_SQUAD_V2])
2020
common_batch_sizes = os.environ.get("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1, 2, 4, 8])
21-
common_seq_lengths = os.environ.get("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64])
21+
common_seq_lengths = os.environ.get("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64, 512])
2222

2323
# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/roberta,/tmp/models/roberta-base-squad2"
2424
if isinstance(common_model_paths, str):

0 commit comments

Comments
 (0)