Skip to content

Commit 9b0c214

Browse files
Enhance repo_type_and_id_from_hf_id of hf_api (#3507)
* Enhance repo_type_and_id_from_hf_id of hf_api * Add missing line * Fix failed tests * Format code * Fix tests and format code * Apply suggestions from code review * Optimize code --------- Co-authored-by: Lucain <lucain@huggingface.co> Co-authored-by: Lucain <lucainp@gmail.com>
1 parent ee6a7d9 commit 9b0c214

File tree

2 files changed

+68
-27
lines changed

2 files changed

+68
-27
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@
195195

196196
USERNAME_PLACEHOLDER = "hf_user"
197197
_REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$")
198+
_REGEX_HTTP_PROTOCOL = re.compile(r"https?://")
198199

199200
_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE = (
200201
"\nNote: Creating a commit assumes that the repo already exists on the"
@@ -239,28 +240,62 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu
239240
"""
240241
input_hf_id = hf_id
241242

242-
hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else constants.ENDPOINT)
243-
is_hf_url = hub_url in hf_id and "@" not in hf_id
243+
# Get the hub_url (with or without protocol)
244+
full_hub_url = hub_url if hub_url is not None else constants.ENDPOINT
245+
hub_url_without_protocol = _REGEX_HTTP_PROTOCOL.sub("", full_hub_url)
246+
247+
# Check if hf_id is a URL containing the hub_url (check both with and without protocol)
248+
hf_id_without_protocol = _REGEX_HTTP_PROTOCOL.sub("", hf_id)
249+
is_hf_url = hub_url_without_protocol in hf_id_without_protocol and "@" not in hf_id
244250

245251
HFFS_PREFIX = "hf://"
246252
if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists
247253
hf_id = hf_id[len(HFFS_PREFIX) :]
248254

255+
# If it's a URL, strip the endpoint prefix to get the path
256+
if is_hf_url:
257+
# Remove protocol if present
258+
hf_id_normalized = _REGEX_HTTP_PROTOCOL.sub("", hf_id)
259+
260+
# Remove the hub_url prefix to get the relative path
261+
if hf_id_normalized.startswith(hub_url_without_protocol):
262+
# Strip the hub URL and any leading slashes
263+
hf_id = hf_id_normalized[len(hub_url_without_protocol) :].lstrip("/")
264+
249265
url_segments = hf_id.split("/")
250266
is_hf_id = len(url_segments) <= 3
251267

252268
namespace: Optional[str]
253269
if is_hf_url:
254-
namespace, repo_id = url_segments[-2:]
255-
if namespace == hub_url:
256-
namespace = None
257-
if len(url_segments) > 2 and hub_url not in url_segments[-3]:
258-
repo_type = url_segments[-3]
259-
elif namespace in constants.REPO_TYPES_MAPPING:
260-
# Mean canonical dataset or model
261-
repo_type = constants.REPO_TYPES_MAPPING[namespace]
262-
namespace = None
270+
# For URLs, we need to extract repo_type, namespace, repo_id
271+
# Expected format after stripping endpoint: [repo_type]/namespace/repo_id or namespace/repo_id
272+
273+
if len(url_segments) >= 3:
274+
# Check if first segment is a repo type
275+
if url_segments[0] in constants.REPO_TYPES_MAPPING:
276+
repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]]
277+
namespace = url_segments[1]
278+
repo_id = url_segments[2]
279+
else:
280+
# First segment is namespace
281+
namespace = url_segments[0]
282+
repo_id = url_segments[1]
283+
repo_type = None
284+
elif len(url_segments) == 2:
285+
namespace = url_segments[0]
286+
repo_id = url_segments[1]
287+
288+
# Check if namespace is actually a repo type mapping
289+
if namespace in constants.REPO_TYPES_MAPPING:
290+
# Mean canonical dataset or model
291+
repo_type = constants.REPO_TYPES_MAPPING[namespace]
292+
namespace = None
293+
else:
294+
repo_type = None
263295
else:
296+
# Single segment
297+
repo_id = url_segments[0]
298+
namespace = None
264299
repo_type = None
265300
elif is_hf_id:
266301
if len(url_segments) == 3:

tests/test_hf_api.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2800,25 +2800,31 @@ def test_git_push_end_to_end(self):
28002800
class ParseHFUrlTest(unittest.TestCase):
28012801
def test_repo_type_and_id_from_hf_id_on_correct_values(self):
28022802
possible_values = {
2803-
"https://huggingface.co/id": [None, None, "id"],
2804-
"https://huggingface.co/user/id": [None, "user", "id"],
2805-
"https://huggingface.co/datasets/user/id": ["dataset", "user", "id"],
2806-
"https://huggingface.co/spaces/user/id": ["space", "user", "id"],
2807-
"user/id": [None, "user", "id"],
2808-
"dataset/user/id": ["dataset", "user", "id"],
2809-
"space/user/id": ["space", "user", "id"],
2810-
"id": [None, None, "id"],
2811-
"hf://id": [None, None, "id"],
2812-
"hf://user/id": [None, "user", "id"],
2813-
"hf://model/user/name": ["model", "user", "name"], # 's' is optional
2814-
"hf://models/user/name": ["model", "user", "name"],
2803+
"hub": {
2804+
"https://huggingface.co/id": [None, None, "id"],
2805+
"https://huggingface.co/user/id": [None, "user", "id"],
2806+
"https://huggingface.co/datasets/user/id": ["dataset", "user", "id"],
2807+
"https://huggingface.co/spaces/user/id": ["space", "user", "id"],
2808+
"user/id": [None, "user", "id"],
2809+
"dataset/user/id": ["dataset", "user", "id"],
2810+
"space/user/id": ["space", "user", "id"],
2811+
"id": [None, None, "id"],
2812+
"hf://id": [None, None, "id"],
2813+
"hf://user/id": [None, "user", "id"],
2814+
"hf://model/user/name": ["model", "user", "name"], # 's' is optional
2815+
"hf://models/user/name": ["model", "user", "name"],
2816+
},
2817+
"self-hosted": {
2818+
"http://localhost:8080/hf/user/id": [None, "user", "id"],
2819+
"http://localhost:8080/hf/datasets/user/id": ["dataset", "user", "id"],
2820+
"http://localhost:8080/hf/models/user/id": ["model", "user", "id"],
2821+
},
28152822
}
28162823

28172824
for key, value in possible_values.items():
2818-
self.assertEqual(
2819-
repo_type_and_id_from_hf_id(key, hub_url=ENDPOINT_PRODUCTION),
2820-
tuple(value),
2821-
)
2825+
hub_url = ENDPOINT_PRODUCTION if key == "hub" else "http://localhost:8080/hf"
2826+
for key, value in value.items():
2827+
assert repo_type_and_id_from_hf_id(key, hub_url=hub_url) == tuple(value)
28222828

28232829
def test_repo_type_and_id_from_hf_id_on_wrong_values(self):
28242830
for hub_id in [

0 commit comments

Comments
 (0)