|
1 | 1 | from collections import defaultdict |
| 2 | +from functools import cmp_to_key |
2 | 3 | from typing import Dict, List, cast |
3 | 4 |
|
4 | 5 | from ...type import ( |
5 | 6 | GraphQLAbstractType, |
6 | | - GraphQLSchema, |
7 | 7 | GraphQLOutputType, |
| 8 | + GraphQLSchema, |
8 | 9 | is_abstract_type, |
9 | 10 | is_interface_type, |
10 | 11 | is_object_type, |
@@ -62,34 +63,51 @@ def get_suggested_type_names( |
62 | 63 |
|
63 | 64 | Go through all of the implementations of type, as well as the interfaces |
64 | 65 | that they implement. If any of those types include the provided field, |
65 | | - suggest them, sorted by how often the type is referenced, starting with |
66 | | - Interfaces. |
| 66 | + suggest them, sorted by how often the type is referenced. |
67 | 67 | """ |
68 | | - if is_abstract_type(type_): |
69 | | - type_ = cast(GraphQLAbstractType, type_) |
70 | | - suggested_object_types = [] |
71 | | - interface_usage_count: Dict[str, int] = defaultdict(int) |
72 | | - for possible_type in schema.get_possible_types(type_): |
73 | | - if field_name not in possible_type.fields: |
| 68 | + if not is_abstract_type(type_): |
| 69 | + # Must be an Object type, which does not have possible fields. |
| 70 | + return [] |
| 71 | + |
| 72 | + type_ = cast(GraphQLAbstractType, type_) |
| 73 | + suggested_types = set() |
| 74 | + usage_count: Dict[str, int] = defaultdict(int) |
| 75 | + for possible_type in schema.get_possible_types(type_): |
| 76 | + if field_name not in possible_type.fields: |
| 77 | + continue |
| 78 | + |
| 79 | + # This object type defines this field. |
| 80 | + suggested_types.add(possible_type) |
| 81 | + usage_count[possible_type.name] = 1 |
| 82 | + |
| 83 | + for possible_interface in possible_type.interfaces: |
| 84 | + if field_name not in possible_interface.fields: |
74 | 85 | continue |
75 | | - # This object type defines this field. |
76 | | - suggested_object_types.append(possible_type.name) |
77 | | - for possible_interface in possible_type.interfaces: |
78 | | - if field_name not in possible_interface.fields: |
79 | | - continue |
80 | | - # This interface type defines this field. |
81 | | - interface_usage_count[possible_interface.name] += 1 |
82 | | - |
83 | | - # Suggest interface types based on how common they are. |
84 | | - suggested_interface_types = sorted( |
85 | | - interface_usage_count, key=lambda k: -interface_usage_count[k] |
86 | | - ) |
87 | 86 |
|
88 | | - # Suggest both interface and object types. |
89 | | - return suggested_interface_types + suggested_object_types |
| 87 | + # This interface type defines this field. |
| 88 | + suggested_types.add(possible_interface) |
| 89 | + usage_count[possible_interface.name] += 1 |
90 | 90 |
|
91 | | - # Otherwise, must be an Object type, which does not have possible fields. |
92 | | - return [] |
| 91 | + def cmp(type_a, type_b) -> int: |
| 92 | + # Suggest both interface and object types based on how common they are. |
| 93 | + usage_count_diff = usage_count[type_b.name] - usage_count[type_a.name] |
| 94 | + if usage_count_diff: |
| 95 | + return usage_count_diff |
| 96 | + |
| 97 | + # Suggest super types first followed by subtypes |
| 98 | + if is_abstract_type(type_a) and schema.is_sub_type(type_a, type_b): |
| 99 | + return -1 |
| 100 | + if is_abstract_type(type_b) and schema.is_sub_type(type_b, type_a): |
| 101 | + return 1 |
| 102 | + |
| 103 | + if type_a.name > type_b.name: |
| 104 | + return 1 |
| 105 | + elif type_a.name < type_b.name: |
| 106 | + return -1 |
| 107 | + |
| 108 | + return 0 |
| 109 | + |
| 110 | + return [type_.name for type_ in sorted(suggested_types, key=cmp_to_key(cmp))] |
93 | 111 |
|
94 | 112 |
|
95 | 113 | def get_suggested_field_names(type_: GraphQLOutputType, field_name: str) -> List[str]: |
|
0 commit comments