44import time
55from typing import Any , Dict , Optional
66
7- import pandas as pd
7+ import pyarrow
88import requests
99from pandas import DataFrame , Series
1010
@@ -32,12 +32,13 @@ def __init__(
3232 self ._namespace = namespace
3333 self ._server_version = server_version
3434 self ._compute_cluster_web_uri = f"http://{ compute_cluster_ip } :5005"
35+ self ._compute_cluster_arrow_uri = f"grpc://{ compute_cluster_ip } :8815"
3536 self ._compute_cluster_mlflow_uri = f"http://{ compute_cluster_ip } :8080"
3637 self ._encrypted_db_password = encrypted_db_password
3738 self ._arrow_uri = arrow_uri
3839
3940 @property
40- def model (self ):
41+ def model (self ) -> "KgeRunner" :
4142 return self
4243
4344 # @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0))
@@ -75,7 +76,7 @@ def train(
7576 mlflow_experiment_name : Optional [str ] = None ,
7677 ) -> Series :
7778 if epochs_per_checkpoint is None :
78- epochs_per_checkpoint = max (num_epochs / 10 , 1 )
79+ epochs_per_checkpoint = max (int ( num_epochs / 10 ) , 1 )
7980 if loss_function_kwargs is None :
8081 loss_function_kwargs = dict (margin = 1.0 , adversarial_temperature = 1.0 , gamma = 20.0 )
8182 if lr_scheduler_kwargs is None :
@@ -92,7 +93,7 @@ def train(
9293 }
9394 print (algo_config )
9495
95- graph_config = {"name" : G .name ()}
96+ graph_config = {"name" : G .name (), "config_type" : "GdsGraphConfig" }
9697
9798 config = {
9899 "user_name" : "DUMMY_USER" ,
@@ -133,7 +134,6 @@ def predict(
133134 rel_types : list [str ],
134135 mlflow_experiment_name : Optional [str ] = None ,
135136 ) -> DataFrame :
136-
137137 algo_config = {
138138 "top_k" : top_k ,
139139 "node_ids" : node_ids ,
@@ -144,8 +144,10 @@ def predict(
144144 "user_name" : "DUMMY_USER" ,
145145 "task" : "KGE_PREDICT_PYG" ,
146146 "task_config" : {
147+ "graph_config" : {"config_type" : "GdsGraphConfig" , "name" : "NOGRAPH" },
147148 "modelname" : model_name ,
148149 "task_config" : algo_config ,
150+ "stream_rel_results" : True ,
149151 },
150152 "graph_arrow_uri" : self ._arrow_uri ,
151153 }
@@ -162,7 +164,7 @@ def predict(
162164
163165 self ._wait_for_job (job_id )
164166
165- return self ._stream_results (config [ "user_name" ], config [ "task_config" ][ "modelname" ] , job_id )
167+ return self ._stream_results (config , job_id )
166168
167169 @client_only_endpoint ("gds.kge.model" )
168170 def score_triplets (
@@ -171,7 +173,6 @@ def score_triplets(
171173 triplets : list [tuple [int , str , int ]],
172174 mlflow_experiment_name : Optional [str ] = None ,
173175 ) -> DataFrame :
174-
175176 algo_config = {
176177 "triplets" : triplets ,
177178 }
@@ -180,8 +181,10 @@ def score_triplets(
180181 "user_name" : "DUMMY_USER" ,
181182 "task" : "KGE_SCORE_TRIPLETS_PYG" ,
182183 "task_config" : {
184+ "graph_config" : {"config_type" : "GdsGraphConfig" , "name" : "NOGRAPH" },
183185 "modelname" : model_name ,
184186 "task_config" : algo_config ,
187+ "stream_rel_results" : True ,
185188 },
186189 "graph_arrow_uri" : self ._arrow_uri ,
187190 }
@@ -198,22 +201,20 @@ def score_triplets(
198201
199202 self ._wait_for_job (job_id )
200203
201- return self ._stream_results (config [ "user_name" ], config [ "task_config" ][ "modelname" ] , job_id )
204+ return self ._stream_results (config , job_id )
202205
203- def _stream_results (self , user_name : str , model_name : str , job_id : str ) -> DataFrame :
204- res = requests .get (
205- f"{ self ._compute_cluster_web_uri } /internal/fetch-result" ,
206- params = {"user_name" : user_name , "modelname" : model_name , "job_id" : job_id },
207- )
208- res .raise_for_status ()
206+ def _stream_results (self , config : dict , job_id : str ) -> DataFrame :
207+ client = pyarrow .flight .connect (self ._compute_cluster_arrow_uri )
209208
210- res_file_name = f"res_{ job_id } .json"
211- with open (res_file_name , mode = "wb+" ) as f :
212- f .write (res .content )
209+ if config ["task_config" ].get ("stream_rel_results" , False ):
210+ upload_descriptor = pyarrow .flight .FlightDescriptor .for_path (f"{ job_id } .relationships" )
211+ else :
212+ raise ValueError ("No results to fetch: need to set stream_rel_results or stream_graph_results to True" )
213+ flight = client .get_flight_info (upload_descriptor )
214+ reader = client .do_get (flight .endpoints [0 ].ticket )
215+ read_table = reader .read_all ()
213216
214- df = pd .read_json (res_file_name , orient = "records" , lines = True )
215- os .remove (res_file_name )
216- return df
217+ return read_table .to_pandas ()
217218
218219 def _get_metrics (self , user_name : str , model_name : str , job_id : str ) -> DataFrame :
219220 res = requests .get (
0 commit comments