@@ -632,10 +632,15 @@ def test_validate_smdataparallel_args_not_raises():
632632 (None , None , None , None , smdataparallel_disabled ),
633633 ("ml.p3.16xlarge" , "tensorflow" , "2.3.1" , "py3" , smdataparallel_enabled ),
634634 ("ml.p3.16xlarge" , "tensorflow" , "2.3.2" , "py3" , smdataparallel_enabled ),
635+ ("ml.p3.16xlarge" , "tensorflow" , "2.3" , "py3" , smdataparallel_enabled ),
635636 ("ml.p3.16xlarge" , "tensorflow" , "2.4.1" , "py3" , smdataparallel_enabled ),
637+ ("ml.p3.16xlarge" , "tensorflow" , "2.4" , "py3" , smdataparallel_enabled ),
636638 ("ml.p3.16xlarge" , "pytorch" , "1.6.0" , "py3" , smdataparallel_enabled ),
639+ ("ml.p3.16xlarge" , "pytorch" , "1.6" , "py3" , smdataparallel_enabled ),
637640 ("ml.p3.16xlarge" , "pytorch" , "1.7.1" , "py3" , smdataparallel_enabled ),
641+ ("ml.p3.16xlarge" , "pytorch" , "1.7" , "py3" , smdataparallel_enabled ),
638642 ("ml.p3.16xlarge" , "pytorch" , "1.8.0" , "py3" , smdataparallel_enabled ),
643+ ("ml.p3.16xlarge" , "pytorch" , "1.8" , "py3" , smdataparallel_enabled ),
639644 ]
640645 for instance_type , framework_name , framework_version , py_version , distribution in good_args :
641646 fw_utils ._validate_smdataparallel_args (
0 commit comments