33from diffgram import __version__
44
55from diffgram .file .view import get_label_file_dict
6- from diffgram .core .directory import get_directory_list
7- from diffgram .core .directory import set_directory_by_name
86from diffgram .convert .convert import convert_label
97from diffgram .label .label_new import label_new
108
@@ -29,7 +27,10 @@ def __init__(
2927 client_secret = None ,
3028 debug = False ,
3129 staging = False ,
32- host = None
30+ host = None ,
31+ set_default_directory = True ,
32+ refresh_local_label_dict = True
33+
3334 ):
3435
3536 self .session = requests .Session ()
@@ -50,24 +51,41 @@ def __init__(
5051 self .host = "https://diffgram.com"
5152 else :
5253 self .host = host
53- self .directory_id = None
54- self .name_to_file_id = None
54+
5555 self .auth (
5656 project_string_id = project_string_id ,
5757 client_id = client_id ,
5858 client_secret = client_secret )
59- self .client_id = client_id
60- self .client_secret = client_secret
6159
6260 self .file = FileConstructor (self )
63- self .train = Train (self )
61+ # self.train = Train(self)
6462 self .job = Job (self )
6563 self .guide = Guide (self )
66- self .directory = Directory (self , validate_ids = False )
64+ self .directory = Directory (self ,
65+ init_file_ids = False ,
66+ validate_ids = False )
6767 self .export = Export (self )
6868 self .task = Task (client = self )
69+
70+ self .directory_id = None
71+ self .name_to_file_id = None
72+
73+
74+ if set_default_directory is True :
75+ self .set_default_directory ()
76+ print ("Default directory set:" , self .directory_id )
77+
78+ if refresh_local_label_dict is True :
79+ self .get_label_file_dict ()
80+
81+ self .client_id = client_id
82+ self .client_secret = client_secret
83+
6984 self .label_schema_list = self .get_label_schema_list ()
7085
86+ self .directory_list = []
87+
88+
7189 def get_member_list (self ):
7290 url = '/api/project/{}/view' .format (self .project_string_id )
7391 response = self .session .get (url = self .host + url )
@@ -216,9 +234,7 @@ def handle_errors(self,
216234 def auth (self ,
217235 project_string_id ,
218236 client_id = None ,
219- client_secret = None ,
220- set_default_directory = True ,
221- refresh_local_label_dict = True
237+ client_secret = None
222238 ):
223239 """
224240 Define authorization configuration
@@ -242,55 +258,65 @@ def auth(self,
242258 if client_id and client_secret :
243259 self .session .auth = (client_id , client_secret )
244260
245- if set_default_directory is True :
246- self .set_default_directory ()
247261
248- if refresh_local_label_dict is True :
249- # Refresh local labels from Diffgram project
250- self .get_label_file_dict ()
262+ def set_directory_by_name (self , name ):
263+ """
264+
265+ Arguments
266+ self
267+ name, string
268+
269+ """
270+
271+ if name is None :
272+ raise Exception ("No name provided." )
273+
274+ # Don't refresh by default, just set from existing
275+
276+ names_attempted = []
277+ did_set = False
278+
279+ if not self .directory_list :
280+ self .directory_list = self .directory .get_directory_list ()
281+
282+ for directory in self .directory_list :
283+
284+ if directory .nickname == name :
285+ self .set_default_directory (directory = directory )
286+ did_set = True
287+ break
288+ else :
289+ names_attempted .append (directory .nickname )
290+
291+ if did_set is False :
292+ raise Exception (name , " does not exist. Valid names are: " +
293+ str (names_attempted ))
294+
251295
252296 def set_default_directory (self ,
253- directory_id = None ):
297+ directory_id = None ,
298+ directory = None ):
254299 """
255300 -> If no id is provided fetch directory list for project
256301 and set first directory to default.
257302 -> Sets the headers of self.session
258303
259- Arguments
260- directory_id, int, defaults to None
261-
262- Returns
263- None
264-
265- Future
266- TODO return error if invalid directory?
267-
268304 """
269305
270306 if directory_id :
271- # TODO check if valid?
272- # data = {}
273- # data["directory_id"] = directory_id
274307 self .directory_id = directory_id
275- else :
276-
277- data = self .get_directory_list ()
278-
279- self .default_directory = data ['default_directory' ]
280-
281- # Hold over till refactoring (would prefer to
282- # just call self.directory_default.id
283- self .directory_id = self .default_directory ['id' ]
308+ if directory :
309+ self .directory_id = directory .id
310+ self .default_directory = directory
311+
312+ self .directory_list = self .directory .get_directory_list ()
284313
285- self .directory_list = data ["directory_list" ]
286314 self .session .headers .update (
287315 {'directory_id' : str (self .directory_id )})
288316
289317
290318# TODO review not using this pattern anymore
291319
292320setattr (Project , "get_label_file_dict" , get_label_file_dict )
293- setattr (Project , "get_directory_list" , get_directory_list )
294321setattr (Project , "convert_label" , convert_label )
295322setattr (Project , "label_new" , label_new )
296- setattr (Project , "set_directory_by_name" , set_directory_by_name )
0 commit comments