Skip to content

Commit 40365df

Browse files
committed
Add proper types and type usages
1 parent f6c9699 commit 40365df

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

graphdatascience/graph/graph_cypher_runner.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from collections import defaultdict, namedtuple
2-
from typing import Any, NamedTuple, Optional, Tuple
1+
from collections import defaultdict
2+
from typing import Any, Dict, NamedTuple, Optional, Tuple
33

44
from pandas import Series
55

@@ -33,7 +33,7 @@ class RelationshipProjection(NamedTuple):
3333
properties: Optional[list[RelationshipProperty]] = None
3434

3535

36-
class MatchPart(NamedTuple):
36+
class MatchParts(NamedTuple):
3737
match: str = ""
3838
source_where: str = ""
3939
optional_match: str = ""
@@ -113,15 +113,15 @@ def project(
113113
Additional configuration for the projection.
114114
"""
115115

116-
query_params = {"graph_name": graph_name}
116+
query_params: Dict[str, Any] = {"graph_name": graph_name}
117117

118-
data_config = {}
118+
data_config: Dict[str, Any] = {}
119119
data_config_is_static = True
120120

121121
nodes = self._node_projections_spec(nodes)
122122
rels = self._rel_projections_spec(relationships)
123123

124-
match_part = MatchPart()
124+
match_parts = MatchParts()
125125
match_pattern = MatchPattern(
126126
left_arrow="<-" if inverse else "-",
127127
right_arrow="-" if inverse else "->",
@@ -141,11 +141,11 @@ def project(
141141
source_labels_filter = " OR ".join(f"source:{spec.source_label}" for spec in nodes)
142142
target_labels_filter = " OR ".join(f"target:{spec.source_label}" for spec in nodes)
143143
if allow_disconnected_nodes:
144-
match_part = match_part._replace(
144+
match_parts = match_parts._replace(
145145
source_where=f"WHERE {source_labels_filter}", optional_where=f"WHERE {target_labels_filter}"
146146
)
147147
else:
148-
match_part = match_part._replace(
148+
match_parts = match_parts._replace(
149149
source_where=f"WHERE ({source_labels_filter}) AND ({target_labels_filter})"
150150
)
151151

@@ -177,13 +177,13 @@ def project(
177177

178178
source = f"(source{match_pattern.label_filter})"
179179
if allow_disconnected_nodes:
180-
match_part = match_part._replace(
180+
match_parts = match_parts._replace(
181181
match=f"MATCH {source}", optional_match=f"OPTIONAL MATCH (source){match_pattern}"
182182
)
183183
else:
184-
match_part = match_part._replace(match=f"MATCH {source}{match_pattern}")
184+
match_parts = match_parts._replace(match=f"MATCH {source}{match_pattern}")
185185

186-
match_part = str(match_part)
186+
match_part = str(match_parts)
187187

188188
print("nodes", nodes)
189189
print("labels", label_mappings)
@@ -196,8 +196,8 @@ def project(
196196
case_part.append("CASE")
197197

198198
for label, mappings in label_mappings.items():
199-
mappings = ", ".join(f".{key.property_key}" for key in mappings)
200-
when_part = f"WHEN '{label}' in labels({kind}) THEN [{kind} {{{mappings}}}]"
199+
mapping_projection = ", ".join(f".{key.property_key}" for key in mappings)
200+
when_part = f"WHEN '{label}' in labels({kind}) THEN [{kind} {{{mapping_projection}}}]"
201201
case_part.append(when_part)
202202

203203
case_part.append(f"END AS {kind}NodeProperties")
@@ -223,12 +223,10 @@ def project(
223223

224224
query = "\n".join(part for part in [match_part, *case_part, return_part] if part)
225225

226-
result = self._query_runner.run_query_with_logging(
227-
query,
228-
query_params,
229-
).squeeze()
226+
result = self._query_runner.run_query_with_logging(query, query_params)
227+
result = result.squeeze()
230228

231-
return Graph(graph_name, self._query_runner, self._server_version), result
229+
return Graph(graph_name, self._query_runner, self._server_version), result # type: ignore
232230

233231
def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
234232
if spec is None or spec is False:
@@ -305,10 +303,10 @@ def _rel_projection_spec(self, spec: Any, name: Optional[str] = None) -> Relatio
305303

306304
raise TypeError(f"Invalid relationship projection specification: {spec}")
307305

308-
def _rel_properties_spec(self, properties: dict[str, Any]) -> list[RelationshipProperty]:
306+
def _rel_properties_spec(self, properties: Dict[str, Any]) -> list[RelationshipProperty]:
309307
raise TypeError(f"Invalid relationship projection specification: {properties}")
310308

311-
def _render_map(self, mapping: dict[str, Any]) -> str:
309+
def _render_map(self, mapping: Dict[str, Any]) -> str:
312310
return "{" + ", ".join(f"{key}: {value}" for key, value in mapping.items()) + "}"
313311

314312
#

0 commit comments

Comments
 (0)