21 | 21 | import shutil |
22 | 22 | import tempfile |
23 | 23 | from collections import namedtuple |
24 | | -from typing import Optional, Union, Dict |
| 24 | +from typing import List, Optional, Union, Dict |
25 | 25 | from packaging import version |
26 | 26 |
|
27 | 27 | import sagemaker.image_uris |
| 28 | +from sagemaker.instance_group import InstanceGroup |
28 | 29 | from sagemaker.s3_utils import s3_path_join |
29 | 30 | from sagemaker.session_settings import SessionSettings |
30 | 31 | import sagemaker.utils |
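The new `InstanceGroup` import backs the `List[InstanceGroup]` type hint added to `validate_distribution` below. For context, instance groups are what a caller supplies to describe a heterogeneous cluster; a minimal sketch follows, where the group names, instance types, and counts are illustrative and not taken from this change:

```python
from sagemaker.instance_group import InstanceGroup

# Hypothetical heterogeneous cluster: GPU trainers plus CPU data workers.
training_group = InstanceGroup(
    instance_group_name="train_group",
    instance_type="ml.p4d.24xlarge",
    instance_count=2,
)
data_group = InstanceGroup(
    instance_group_name="data_group",
    instance_type="ml.c5.18xlarge",
    instance_count=4,
)

instance_groups = [training_group, data_group]
```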
@@ -828,14 +829,14 @@ def _validate_smdataparallel_args( |
828 | 829 |
|
829 | 830 |
|
830 | 831 | def validate_distribution( |
831 | | - distribution, |
832 | | - instance_groups, |
833 | | - framework_name, |
834 | | - framework_version, |
835 | | - py_version, |
836 | | - image_uri, |
837 | | - kwargs, |
838 | | -): |
| 832 | + distribution: Dict, |
| 833 | + instance_groups: List[InstanceGroup], |
| 834 | + framework_name: str, |
| 835 | + framework_version: str, |
| 836 | + py_version: str, |
| 837 | + image_uri: str, |
| 838 | + kwargs: Dict, |
| 839 | +) -> Dict: |
839 | 840 | """Check if distribution strategy is correctly invoked by the user. |
840 | 841 |
|
841 | 842 | Currently, check for `dataparallel`, `modelparallel` and heterogeneous cluster set up. |
@@ -872,7 +873,9 @@ def validate_distribution( |
872 | 873 | strategy-specific inputs are incorrect/unsupported or |
873 | 874 | heterogeneous cluster set up is incorrect |
874 | 875 | """ |
875 | | - train_instance_groups = distribution.get("instance_groups", []) |
| 876 | + validated_distribution = dict(distribution) |
| 877 | + |
| 878 | + train_instance_groups = validated_distribution.get("instance_groups", []) |
876 | 879 | if instance_groups is None: |
877 | 880 | if len(train_instance_groups) >= 1: |
878 | 881 | # if estimator's instance_groups is not defined but |
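One detail worth noting in the hunk above: `dict(distribution)` makes a shallow copy, so the function can rewrite top-level keys (such as `instance_groups`) without mutating the caller's dictionary, while any nested strategy dicts remain shared. A small illustration of that behavior, independent of the SDK:

```python
original = {"mpi": {"enabled": True}, "instance_groups": ["group_object"]}
validated = dict(original)                      # shallow copy: new top-level dict

validated["instance_groups"] = ["train_group"]  # rebinding a key only affects the copy
assert original["instance_groups"] == ["group_object"]

validated["mpi"]["processes_per_host"] = 8      # nested dicts are still shared
assert original["mpi"]["processes_per_host"] == 8
```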
@@ -902,77 +905,77 @@ def validate_distribution( |
902 | 905 | instance_type = train_instance_group.instance_type |
903 | 906 | validate_distribution_for_instance_type( |
904 | 907 | instance_type=instance_type, |
905 | | - distribution=distribution, |
| 908 | + distribution=validated_distribution, |
906 | 909 | ) |
907 | 910 | validate_smdistributed( |
908 | 911 | instance_type=instance_type, |
909 | 912 | framework_name=framework_name, |
910 | 913 | framework_version=framework_version, |
911 | 914 | py_version=py_version, |
912 | | - distribution=distribution, |
| 915 | + distribution=validated_distribution, |
913 | 916 | image_uri=image_uri, |
914 | 917 | ) |
915 | 918 | if framework_name and framework_name == "pytorch": |
916 | 919 | # We need to validate only for PyTorch framework |
917 | 920 | validate_pytorch_distribution( |
918 | | - distribution=distribution, |
| 921 | + distribution=validated_distribution, |
919 | 922 | framework_name=framework_name, |
920 | 923 | framework_version=framework_version, |
921 | 924 | py_version=py_version, |
922 | 925 | image_uri=image_uri, |
923 | 926 | ) |
924 | 927 | validate_torch_distributed_distribution( |
925 | 928 | instance_type=instance_type, |
926 | | - distribution=distribution, |
| 929 | + distribution=validated_distribution, |
927 | 930 | framework_version=framework_version, |
928 | 931 | py_version=py_version, |
929 | 932 | image_uri=image_uri, |
930 | 933 | entry_point=kwargs["entry_point"], |
931 | 934 | ) |
932 | 935 | warn_if_parameter_server_with_multi_gpu( |
933 | | - training_instance_type=instance_type, distribution=distribution |
| 936 | + training_instance_type=instance_type, distribution=validated_distribution |
934 | 937 | ) |
935 | 938 | # get instance group names |
936 | 939 | instance_group_names.append(train_instance_group.instance_group_name) |
937 | | - distribution["instance_groups"] = instance_group_names |
| 940 | + validated_distribution["instance_groups"] = instance_group_names |
938 | 941 | else: |
939 | 942 | # in this case, we are handling a normal training job (without heterogeneous cluster) |
940 | 943 | instance_type = renamed_kwargs( |
941 | 944 | "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs |
942 | 945 | ) |
943 | 946 | validate_distribution_for_instance_type( |
944 | 947 | instance_type=instance_type, |
945 | | - distribution=distribution, |
| 948 | + distribution=validated_distribution, |
946 | 949 | ) |
947 | 950 | validate_smdistributed( |
948 | 951 | instance_type=instance_type, |
949 | 952 | framework_name=framework_name, |
950 | 953 | framework_version=framework_version, |
951 | 954 | py_version=py_version, |
952 | | - distribution=distribution, |
| 955 | + distribution=validated_distribution, |
953 | 956 | image_uri=image_uri, |
954 | 957 | ) |
955 | 958 | if framework_name and framework_name == "pytorch": |
956 | 959 | # We need to validate only for PyTorch framework |
957 | 960 | validate_pytorch_distribution( |
958 | | - distribution=distribution, |
| 961 | + distribution=validated_distribution, |
959 | 962 | framework_name=framework_name, |
960 | 963 | framework_version=framework_version, |
961 | 964 | py_version=py_version, |
962 | 965 | image_uri=image_uri, |
963 | 966 | ) |
964 | 967 | validate_torch_distributed_distribution( |
965 | 968 | instance_type=instance_type, |
966 | | - distribution=distribution, |
| 969 | + distribution=validated_distribution, |
967 | 970 | framework_version=framework_version, |
968 | 971 | py_version=py_version, |
969 | 972 | image_uri=image_uri, |
970 | 973 | entry_point=kwargs["entry_point"], |
971 | 974 | ) |
972 | 975 | warn_if_parameter_server_with_multi_gpu( |
973 | | - training_instance_type=instance_type, distribution=distribution |
| 976 | + training_instance_type=instance_type, distribution=validated_distribution |
974 | 977 | ) |
975 | | - return distribution |
| 978 | + return validated_distribution |
976 | 979 |
|
977 | 980 |
|
978 | 981 | def validate_distribution_for_instance_type(instance_type, distribution): |
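Taken together, a framework estimator could delegate to the rewritten helper roughly as follows. The strategy, instance type, and version strings are illustrative values rather than anything asserted by this change, and the import assumes the function lives in `sagemaker.fw_utils` as the surrounding imports suggest:

```python
# Hypothetical call site for the non-heterogeneous (no instance_groups) path.
from sagemaker.fw_utils import validate_distribution

distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}

validated = validate_distribution(
    distribution=distribution,
    instance_groups=None,            # plain training job, no heterogeneous cluster
    framework_name="pytorch",
    framework_version="1.12.0",
    py_version="py38",
    image_uri=None,                  # let the SDK resolve the default image URI
    kwargs={"instance_type": "ml.p4d.24xlarge", "entry_point": "train.py"},
)

# The caller's dict is left untouched; any normalized fields live on the returned copy.
assert distribution == {"smdistributed": {"dataparallel": {"enabled": True}}}
```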
|