File tree Expand file tree Collapse file tree 1 file changed +19
-2
lines changed Expand file tree Collapse file tree 1 file changed +19
-2
lines changed Original file line number Diff line number Diff line change @@ -165,10 +165,27 @@ def matrix_shapes(draw, stack_shapes=shapes()):
165165 allow_infinity = False ))
166166
167167def mutually_broadcastable_shapes (
168- num_shapes : int , ** kw
168+ num_shapes : int ,
169+ * ,
170+ base_shape : Shape = (),
171+ min_dims : int = 0 ,
172+ max_dims : Optional [int ] = None ,
173+ min_side : int = 0 ,
174+ max_side : Optional [int ] = None ,
169175) -> SearchStrategy [Tuple [Shape , ...]]:
176+ if max_dims is None :
177+ max_dims = min (max (len (base_shape ), min_dims ) + 5 , 32 )
178+ if max_side is None :
179+ max_side = max (base_shape [- max_dims :] + (min_side ,)) + 5
170180 return (
171- xps .mutually_broadcastable_shapes (num_shapes , ** kw )
181+ xps .mutually_broadcastable_shapes (
182+ num_shapes ,
183+ base_shape = base_shape ,
184+ min_dims = min_dims ,
185+ max_dims = max_dims ,
186+ min_side = min_side ,
187+ max_side = max_side ,
188+ )
172189 .map (lambda BS : BS .input_shapes )
173190 .filter (lambda shapes : all (
174191 prod (i for i in s if i > 0 ) < MAX_ARRAY_SIZE for s in shapes
You can’t perform that action at this time.
0 commit comments