@@ -362,59 +362,35 @@ def test_compile_with_tensorflow_neo_in_ml_inf(session):
362362 )
363363
364364
365- def test_compile_validates_framework_version (sagemaker_session ):
366- sagemaker_session .wait_for_compilation_job = Mock (
367- return_value = {
368- "CompilationJobStatus" : "Completed" ,
369- "ModelArtifacts" : {"S3ModelArtifacts" : "s3://output-path/model.tar.gz" },
370- "InferenceImage" : None ,
371- }
372- )
365+ @pytest .mark .parametrize (
366+ "target,framework,fx_version,expected_fx_version" ,
367+ [
368+ ("ml_c4" , "pytorch" , "1.6" , "1.6" ),
369+ ("rasp3b" , "pytorch" , "1.6.1" , "1.6" ),
370+ ("amba_cv2" , "pytorch" , "1.6.1" , None ),
371+ ("ml_c4" , "tensorflow" , "1.15.1" , "1.15" ),
372+ ("ml_c4" , "tensorflow" , "2.15.1" , "2.15" ),
373+ ("ml_inf1" , "tensorflow" , "2.15.1" , "2.15" ),
374+ ("ml_inf2" , "pytorch" , "2.0" , "2.0" ),
375+ ("ml_inf2" , "pytorch" , "2.0.1" , "2.0" ),
376+ ("ml_trn1" , "pytorch" , "2.0.1" , "2.0" ),
377+ ("ml_trn1" , "tensorflow" , "2.0.1" , "2.0" ),
378+ ],
379+ )
380+ def test_compile_validates_framework_version (
381+ sagemaker_session , target , framework , fx_version , expected_fx_version
382+ ):
373383 model = _create_model (sagemaker_session )
374- model .compile (
375- target_instance_family = "ml_c4" ,
376- input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
377- output_path = "s3://output" ,
378- role = "role" ,
379- framework = "pytorch" ,
380- framework_version = "1.6.1" ,
381- job_name = "compile-model" ,
382- )
383-
384- assert model .image_uri is None
385-
386- sagemaker_session .wait_for_compilation_job = Mock (
387- return_value = {
388- "CompilationJobStatus" : "Completed" ,
389- "ModelArtifacts" : {"S3ModelArtifacts" : "s3://output-path/model.tar.gz" },
390- "InferenceImage" : None ,
391- }
392- )
393-
394- config = model ._compilation_job_config (
395- "rasp3b" ,
396- {"data" : [1 , 3 , 1024 , 1024 ]},
397- "s3://output" ,
398- "role" ,
399- 900 ,
400- "compile-model" ,
401- "pytorch" ,
402- None ,
403- framework_version = "1.6.1" ,
404- )
405-
406- assert config ["input_model_config" ]["FrameworkVersion" ] == "1.6"
407-
408384 config = model ._compilation_job_config (
409- "amba_cv2" ,
385+ target ,
410386 {"data" : [1 , 3 , 1024 , 1024 ]},
411387 "s3://output" ,
412388 "role" ,
413389 900 ,
414390 "compile-model" ,
415- "pytorch" ,
391+ framework ,
416392 None ,
417- framework_version = "1.6.1" ,
393+ framework_version = fx_version ,
418394 )
419395
420- assert config ["input_model_config" ].get ("FrameworkVersion" , None ) is None
396+ assert config ["input_model_config" ].get ("FrameworkVersion" , None ) == expected_fx_version
0 commit comments