@@ -238,61 +238,81 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu
238238 """
239239 input_hf_id = hf_id
240240
241- hub_url = hub_url or constants .ENDPOINT
242- hub_url_no_proto = re .sub (r"^https?://" , "" , hub_url ).rstrip ("/" )
241+ # Get the hub_url (with or without protocol)
242+ full_hub_url = hub_url if hub_url is not None else constants .ENDPOINT
243+ hub_url_without_protocol = re .sub (r"https?://" , "" , full_hub_url )
243244
244- hf_id_no_proto = re .sub (r"^https?://" , "" , hf_id )
245-
246- is_hf_url = hf_id_no_proto .startswith (hub_url_no_proto ) and "@" not in hf_id
247-
248- if is_hf_url :
249- hf_id = hf_id_no_proto [len (hub_url_no_proto ) :].lstrip ("/" )
245+ # Check if hf_id is a URL containing the hub_url (check both with and without protocol)
246+ hf_id_without_protocol = re .sub (r"https?://" , "" , hf_id )
247+ is_hf_url = hub_url_without_protocol in hf_id_without_protocol and "@" not in hf_id
250248
251249 HFFS_PREFIX = "hf://"
252250 if hf_id .startswith (HFFS_PREFIX ): # Remove "hf://" prefix if exists
253251 hf_id = hf_id [len (HFFS_PREFIX ) :]
254252
255- url_segments = [s for s in hf_id .split ("/" ) if s ]
256- seg_len = len (url_segments )
253+ # If it's a URL, strip the endpoint prefix to get the path
254+ if is_hf_url :
255+ # Remove protocol if present
256+ hf_id_normalized = re .sub (r"https?://" , "" , hf_id )
257257
258- repo_type : Optional [str ] = None
259- namespace : Optional [str ] = None
260- repo_id : str
258+ # Remove the hub_url prefix to get the relative path
259+ if hf_id_normalized .startswith (hub_url_without_protocol ):
260+ # Strip the hub URL and any leading slashes
261+ hf_id = hf_id_normalized [len (hub_url_without_protocol ) :].lstrip ("/" )
261262
263+ url_segments = hf_id .split ("/" )
264+ is_hf_id = len (url_segments ) <= 3
265+
266+ namespace : Optional [str ]
262267 if is_hf_url :
263- if seg_len == 1 :
264- repo_id = url_segments [0 ]
265- namespace = None
266- repo_type = None
267- elif seg_len == 2 :
268- namespace , repo_id = url_segments
269- repo_type = None
270- else :
271- namespace , repo_id = url_segments [- 2 :]
272- repo_type = url_segments [- 3 ] if seg_len >= 3 else None
268+ # For URLs, we need to extract repo_type, namespace, repo_id
269+ # Expected format after stripping endpoint: [repo_type]/namespace/repo_id or namespace/repo_id
270+
271+ if len (url_segments ) >= 3 :
272+ # Check if first segment is a repo type
273+ if url_segments [0 ] in constants .REPO_TYPES_MAPPING :
274+ repo_type = constants .REPO_TYPES_MAPPING [url_segments [0 ]]
275+ namespace = url_segments [1 ]
276+ repo_id = url_segments [2 ]
277+ else :
278+ # First segment is namespace
279+ namespace = url_segments [0 ]
280+ repo_id = url_segments [1 ]
281+ repo_type = None
282+ elif len (url_segments ) == 2 :
283+ namespace = url_segments [0 ]
284+ repo_id = url_segments [1 ]
285+
286+ # Check if namespace is actually a repo type mapping
273287 if namespace in constants .REPO_TYPES_MAPPING :
274- # canonical dataset/ model
288+ # Mean canonical dataset or model
275289 repo_type = constants .REPO_TYPES_MAPPING [namespace ]
276290 namespace = None
277-
278- elif seg_len <= 3 :
279- if seg_len == 3 :
291+ else :
292+ repo_type = None
293+ else :
294+ # Single segment
295+ repo_id = url_segments [0 ]
296+ namespace = None
297+ repo_type = None
298+ elif is_hf_id :
299+ if len (url_segments ) == 3 :
280300 # Passed <repo_type>/<user>/<model_id> or <repo_type>/<org>/<model_id>
281- repo_type , namespace , repo_id = url_segments
282- elif seg_len == 2 :
301+ repo_type , namespace , repo_id = url_segments [ - 3 :]
302+ elif len ( url_segments ) == 2 :
283303 if url_segments [0 ] in constants .REPO_TYPES_MAPPING :
284304 # Passed '<model_id>' or 'datasets/<dataset_id>' for a canonical model or dataset
285305 repo_type = constants .REPO_TYPES_MAPPING [url_segments [0 ]]
286306 namespace = None
287- repo_id = url_segments [ 1 ]
307+ repo_id = hf_id . split ( "/" )[ - 1 ]
288308 else :
289309 # Passed <user>/<model_id> or <org>/<model_id>
290- namespace , repo_id = url_segments
310+ namespace , repo_id = hf_id . split ( "/" )[ - 2 :]
291311 repo_type = None
292312 else :
313+ # Passed <model_id>
293314 repo_id = url_segments [0 ]
294- namespace = None
295- repo_type = None
315+ namespace , repo_type = None , None
296316 else :
297317 raise ValueError (f"Unable to retrieve user and repo ID from the passed HF ID: { hf_id } " )
298318
@@ -301,7 +321,7 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu
301321 repo_type = constants .REPO_TYPES_MAPPING [repo_type ]
302322 if repo_type == "" :
303323 repo_type = None
304- if repo_type not in constants .REPO_TYPES and repo_type is not None :
324+ if repo_type not in constants .REPO_TYPES :
305325 raise ValueError (f"Unknown `repo_type`: '{ repo_type } ' ('{ input_hf_id } ')" )
306326
307327 return repo_type , namespace , repo_id
0 commit comments