Skip to content

Commit 136c1fc

Browse files
committed
Fix tests and format code
1 parent c3656ef commit 136c1fc

File tree

1 file changed

+54
-34
lines changed

1 file changed

+54
-34
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -238,61 +238,81 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu
238238
"""
239239
input_hf_id = hf_id
240240

241-
hub_url = hub_url or constants.ENDPOINT
242-
hub_url_no_proto = re.sub(r"^https?://", "", hub_url).rstrip("/")
241+
# Get the hub_url (with or without protocol)
242+
full_hub_url = hub_url if hub_url is not None else constants.ENDPOINT
243+
hub_url_without_protocol = re.sub(r"https?://", "", full_hub_url)
243244

244-
hf_id_no_proto = re.sub(r"^https?://", "", hf_id)
245-
246-
is_hf_url = hf_id_no_proto.startswith(hub_url_no_proto) and "@" not in hf_id
247-
248-
if is_hf_url:
249-
hf_id = hf_id_no_proto[len(hub_url_no_proto) :].lstrip("/")
245+
# Check if hf_id is a URL containing the hub_url (check both with and without protocol)
246+
hf_id_without_protocol = re.sub(r"https?://", "", hf_id)
247+
is_hf_url = hub_url_without_protocol in hf_id_without_protocol and "@" not in hf_id
250248

251249
HFFS_PREFIX = "hf://"
252250
if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists
253251
hf_id = hf_id[len(HFFS_PREFIX) :]
254252

255-
url_segments = [s for s in hf_id.split("/") if s]
256-
seg_len = len(url_segments)
253+
# If it's a URL, strip the endpoint prefix to get the path
254+
if is_hf_url:
255+
# Remove protocol if present
256+
hf_id_normalized = re.sub(r"https?://", "", hf_id)
257257

258-
repo_type: Optional[str] = None
259-
namespace: Optional[str] = None
260-
repo_id: str
258+
# Remove the hub_url prefix to get the relative path
259+
if hf_id_normalized.startswith(hub_url_without_protocol):
260+
# Strip the hub URL and any leading slashes
261+
hf_id = hf_id_normalized[len(hub_url_without_protocol) :].lstrip("/")
261262

263+
url_segments = hf_id.split("/")
264+
is_hf_id = len(url_segments) <= 3
265+
266+
namespace: Optional[str]
262267
if is_hf_url:
263-
if seg_len == 1:
264-
repo_id = url_segments[0]
265-
namespace = None
266-
repo_type = None
267-
elif seg_len == 2:
268-
namespace, repo_id = url_segments
269-
repo_type = None
270-
else:
271-
namespace, repo_id = url_segments[-2:]
272-
repo_type = url_segments[-3] if seg_len >= 3 else None
268+
# For URLs, we need to extract repo_type, namespace, repo_id
269+
# Expected format after stripping endpoint: [repo_type]/namespace/repo_id or namespace/repo_id
270+
271+
if len(url_segments) >= 3:
272+
# Check if first segment is a repo type
273+
if url_segments[0] in constants.REPO_TYPES_MAPPING:
274+
repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]]
275+
namespace = url_segments[1]
276+
repo_id = url_segments[2]
277+
else:
278+
# First segment is namespace
279+
namespace = url_segments[0]
280+
repo_id = url_segments[1]
281+
repo_type = None
282+
elif len(url_segments) == 2:
283+
namespace = url_segments[0]
284+
repo_id = url_segments[1]
285+
286+
# Check if namespace is actually a repo type mapping
273287
if namespace in constants.REPO_TYPES_MAPPING:
274-
# canonical dataset/model
288+
# Mean canonical dataset or model
275289
repo_type = constants.REPO_TYPES_MAPPING[namespace]
276290
namespace = None
277-
278-
elif seg_len <= 3:
279-
if seg_len == 3:
291+
else:
292+
repo_type = None
293+
else:
294+
# Single segment
295+
repo_id = url_segments[0]
296+
namespace = None
297+
repo_type = None
298+
elif is_hf_id:
299+
if len(url_segments) == 3:
280300
# Passed <repo_type>/<user>/<model_id> or <repo_type>/<org>/<model_id>
281-
repo_type, namespace, repo_id = url_segments
282-
elif seg_len == 2:
301+
repo_type, namespace, repo_id = url_segments[-3:]
302+
elif len(url_segments) == 2:
283303
if url_segments[0] in constants.REPO_TYPES_MAPPING:
284304
# Passed '<model_id>' or 'datasets/<dataset_id>' for a canonical model or dataset
285305
repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]]
286306
namespace = None
287-
repo_id = url_segments[1]
307+
repo_id = hf_id.split("/")[-1]
288308
else:
289309
# Passed <user>/<model_id> or <org>/<model_id>
290-
namespace, repo_id = url_segments
310+
namespace, repo_id = hf_id.split("/")[-2:]
291311
repo_type = None
292312
else:
313+
# Passed <model_id>
293314
repo_id = url_segments[0]
294-
namespace = None
295-
repo_type = None
315+
namespace, repo_type = None, None
296316
else:
297317
raise ValueError(f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}")
298318

@@ -301,7 +321,7 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu
301321
repo_type = constants.REPO_TYPES_MAPPING[repo_type]
302322
if repo_type == "":
303323
repo_type = None
304-
if repo_type not in constants.REPO_TYPES and repo_type is not None:
324+
if repo_type not in constants.REPO_TYPES:
305325
raise ValueError(f"Unknown `repo_type`: '{repo_type}' ('{input_hf_id}')")
306326

307327
return repo_type, namespace, repo_id

0 commit comments

Comments
 (0)