Skip to content

Commit e82030f

Browse files
committed
Making IPUConfig class compatible with Python 3.9
Summary: Python 3.9 makes breaking changes to ast.Subscript. In for the IPUConfig class, we use the ast module for parsing type annotations. This change allows IPUConfig to support Python 3.9 by checking the Python version and selecting the appropriate behaviour. Test Plan: Existing IPUConfig tests are robust. Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, zigmasb Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, zigmasb Subscribers: zigmasb Maniphest Tasks: T71807 Differential Revision: https://phabricator.sourcevertex.net/D78433
1 parent 8baad9a commit e82030f

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

tensorflow/python/ipu/config.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import os
2828
import pydoc
2929
import typing
30+
import sys
3031

3132
from tensorflow.python.eager.context import executing_eagerly
3233
from tensorflow.compiler.plugin.poplar.driver import config_pb2
@@ -38,6 +39,16 @@
3839
from tensorflow.python.framework import ops
3940
from tensorflow.python.platform import tf_logging as logging
4041

42+
if sys.version_info >= (3, 9):
43+
# Python 3.9 makes breaking changes to ast.Subscript.
44+
def _get_subscript_slice(subscript_node):
45+
return subscript_node.slice
46+
else:
47+
48+
def _get_subscript_slice(subscript_node):
49+
assert isinstance(subscript_node.slice, ast.Index)
50+
return subscript_node.slice.value
51+
4152

4253
def _annotation_to_str(node):
4354
"""
@@ -134,11 +145,11 @@ def helper(node):
134145
if isinstance(node, ast.Subscript):
135146
lhs = node.value
136147
if is_typing_module_attr(lhs):
148+
slice_value = _get_subscript_slice(node)
137149
# e.g. Union[int, str], check v for all union types
138150
if lhs.attr == "Union":
139-
assert isinstance(node.slice, ast.Index)
140-
assert isinstance(node.slice.value, ast.Tuple)
141-
types = [helper(n) for n in node.slice.value.elts]
151+
assert isinstance(slice_value, ast.Tuple)
152+
types = [helper(n) for n in slice_value.elts]
142153
type_tys = [ty for _, ty in types]
143154
if int in type_tys and any([issubclass(Enum, ty)
144155
for ty in type_tys]):
@@ -150,31 +161,29 @@ def helper(node):
150161
if lhs.attr == "Tuple":
151162
check_tuple = lambda v: isinstance(v, tuple)
152163
# single element Tuple: check the single element in v for type
153-
if isinstance(node.slice.value, ast.Name):
154-
type_fn, _ = helper(node.slice.value)
164+
if isinstance(slice_value, ast.Name):
165+
type_fn, _ = helper(slice_value)
155166
return lambda v: check_tuple(v) and len(v) == 1 and type_fn(v[
156167
0]), tuple
157168
# more than one element Tuple
158-
if isinstance(node.slice.value, ast.Tuple):
169+
if isinstance(slice_value, ast.Tuple):
159170
# e.g. Tuple[int, ...], check each element in v for the same type
160-
if len(node.slice.value.elts) > 1 and isinstance(
161-
node.slice.value.elts[1], ast.Ellipsis):
162-
type_fn, _ = helper(node.slice.value.elts[0])
171+
if len(slice_value.elts) > 1 and isinstance(
172+
slice_value.elts[1], ast.Ellipsis):
173+
type_fn, _ = helper(slice_value.elts[0])
163174
return lambda v: check_tuple(v) and all([type_fn(e)
164175
for e in v]), tuple
165176
# e.g. Tuple[int, str], pair-wise (element, type) check
166-
type_fns = [
167-
fn for fn, _ in [helper(n) for n in node.slice.value.elts]
168-
]
177+
type_fns = [fn for fn, _ in [helper(n) for n in slice_value.elts]]
169178
return lambda v: check_tuple(v) and len(v) == len(
170179
type_fns) and all([fn(e) for fn, e in zip(type_fns, v)]), tuple
171180
# e.g. List[int], check each element in v for the same type
172181
if lhs.attr == "List":
173182
assert not isinstance(
174-
node.slice.value,
183+
slice_value,
175184
ast.Tuple), "List with more than one type not allowed."
176185
check_list = lambda v: isinstance(v, list)
177-
type_fn, _ = helper(node.slice.value)
186+
type_fn, _ = helper(slice_value)
178187
return lambda v: check_list(v) and all([type_fn(e) for e in v]), list
179188
raise Exception(f"Unsupported 'typing' attribute {lhs.attr}")
180189
raise Exception(

0 commit comments

Comments
 (0)