diff --git a/function_app.py b/function_app.py
index b487a70..3ff8459 100644
--- a/function_app.py
+++ b/function_app.py
@@ -10,6 +10,7 @@
     get_resources_by_batch,
     search_resources,
 )
+from shared.azure_search_client import get_search_client
 from shared.database import initialize_database
 
 # Initialize the function app
@@ -18,8 +19,11 @@
 # Initialize database connection
 db, collection = initialize_database()
 
+# Initialize Azure Search client
+search_client = get_search_client()
+
 # Register functions
 get_resources_by_batch.register_function(app, collection)
-search_resources.register_function(app, collection)
+search_resources.register_function(app, search_client)
 get_filters.register_function(app, collection, db["filter_values"])
 get_dependent_workloads.register_function(app, collection)
diff --git a/functions/search_resources.py b/functions/search_resources.py
index 65d7177..4641823 100644
--- a/functions/search_resources.py
+++ b/functions/search_resources.py
@@ -5,7 +5,6 @@
 import logging
 
 import azure.functions as func
-from bson import json_util
 
 from shared.utils import (
     create_error_response,
@@ -15,14 +14,14 @@
 )
 
 
-def register_function(app, collection):
+def register_function(app, search_client):
     """Register the function with the app."""
 
     @app.function_name(name="search_resources")
     @app.route(route="resources/search", auth_level=func.AuthLevel.ANONYMOUS)
     def search_resources(req: func.HttpRequest) -> func.HttpResponse:
         """
-        Search resources with filtering capabilities.
+        Search resources with filtering capabilities using Azure AI Search.
 
         Route: /resources/search
 
@@ -73,13 +72,11 @@ def search_resources(req: func.HttpRequest) -> func.HttpResponse:
                     400, "Invalid pagination parameters"
                 )
 
-            # Create query object similar to the one used in the Data API
             query_object = {
                 "query": contains_str,
                 "sort": sort_param if sort_param else "default",
             }
 
-            # Parse filter criteria similar to MongoDB implementation
             if must_include:
                 try:
                     # Parse must-include parameter format: field1,value1,value2;field2,value1,value2
@@ -99,57 +96,100 @@ def search_resources(req: func.HttpRequest) -> func.HttpResponse:
                             400, "Invalid filter value format"
                         )
 
-                    # Add to query object similar to original implementation
                     query_object[field] = values
                 except Exception as e:
                     logging.error(f"Error parsing filter criteria: {str(e)}")
                     return create_error_response(400, "Invalid filter format")
 
-            # Build the aggregation pipeline
-            pipeline = []
+            odata_filter = build_odata_filter(query_object)
 
-            # Add search query stage if a query is provided
+            # Build the search text with Lucene query syntax.
+            # Mimics MongoDB Atlas Search behavior:
+            # - Text search across id, description, category, architecture, tags
+            # - Boost matches on the id field
+            # - Word order doesn't matter (each term is searched independently)
             if contains_str:
-                pipeline.extend(get_search_pipeline(query_object))
-
-            # Add filter pipeline stages
-            pipeline.extend(get_filter_pipeline(query_object))
-
-            # Add latest version pipeline
-            pipeline.extend(get_latest_version_pipeline())
+                # Split into terms - Atlas Search tokenizes and matches each term
+                terms = contains_str.split()
+
+                # Build a query that requires all terms (like the Atlas "must"
+                # clause) but allows them in any order and in any field
+                term_clauses = []
+                for term in terms:
+                    escaped_term = escape_lucene_query(term)
+                    # Search each term across all searchable fields.
+                    # Use a quoted phrase for exact matching, plus an unquoted
+                    # wildcard for partial matching.
+                    # Boost id matches by 10x (like the Atlas Search boost)
+                    term_clauses.append(
+                        f'(id:"{escaped_term}"^10 OR id:{escaped_term}*^10 OR '
+                        f'description:"{escaped_term}" OR description:{escaped_term}* OR '
+                        f'category:"{escaped_term}" OR category:{escaped_term}* OR '
+                        f'architecture:"{escaped_term}" OR architecture:{escaped_term}* OR '
+                        f'tags:"{escaped_term}" OR tags:{escaped_term}* OR '
+                        f'"{escaped_term}" OR {escaped_term}*)'
+                    )
 
-            # Add sort pipeline
-            pipeline.extend(get_sort_pipeline(query_object))
+                # Join with AND - all terms must match (like the Atlas "must" clause)
+                search_text = " AND ".join(term_clauses)
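+                # For example (illustrative), contains-str "riscv ubuntu" becomes
+                #   (id:"riscv"^10 OR id:riscv*^10 OR ... OR riscv*)
+                #   AND (id:"ubuntu"^10 OR id:ubuntu*^10 OR ... OR ubuntu*)
+                # so every term must match, in any field and in any order.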
f'(id:"{escaped_term}"^10 OR id:{escaped_term}*^10 OR ' + f'description:"{escaped_term}" OR description:{escaped_term}* OR ' + f'category:"{escaped_term}" OR category:{escaped_term}* OR ' + f'architecture:"{escaped_term}" OR architecture:{escaped_term}* OR ' + f'tags:"{escaped_term}" OR tags:{escaped_term}* OR ' + f'"{escaped_term}" OR {escaped_term}*)' + ) - # Add sort pipeline - pipeline.extend(get_sort_pipeline(query_object)) + # Join with AND - all terms must match (like Atlas "must" clause) + search_text = " AND ".join(term_clauses) + query_type = "full" # Lucene query syntax + else: + search_text = "*" + query_type = "simple" + + # Fetch ALL results from Azure AI Search (no server-side pagination) + # We need all results to deduplicate by id and keep only latest version + all_results = [] + + results = search_client.search( + search_text=search_text, + filter=odata_filter, + include_total_count=True, + top=1000, # Max allowed by Azure AI Search + query_type=query_type, + ) - # Add pagination - pipeline.extend(get_page_pipeline(page, page_size)) + for result in results: + # Convert to dict and add score + doc = dict(result) + doc["score"] = result.get("@search.score", 0) + all_results.append(doc) - # Execute the aggregation - results = list(collection.aggregate(pipeline)) + logging.info(f"Raw results count: {len(all_results)}") - # Process results to match expected output format - processed_results = [] - total_count = 0 + unique_resources = keep_latest_versions(all_results) - if results: - processed_results = results - total_count = results[0].get("totalCount", 0) if results else 0 + # Apply sorting + sorted_resources = apply_sorting( + unique_resources, query_object.get("sort", "default") + ) - # Remove MongoDB _id field and ensure database field is added - for resource in processed_results: - if "_id" in resource: - del resource["_id"] - resource["database"] = ( - "gem5-vision" # Add database field like in original implementation - ) + # Calculate total count before pagination + total_count = len(sorted_resources) + + # Apply pagination + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + paginated_results = sorted_resources[start_idx:end_idx] + + # Clean up results for response + for resource in paginated_results: + # Remove Azure Search internal fields + keys_to_remove = [ + "@search.score", + "@search.highlights", + "@search.reranker_score", + ] + for key in keys_to_remove: + resource.pop(key, None) + resource["database"] = "gem5-vision" response_data = { - "documents": processed_results, + "documents": paginated_results, "totalCount": total_count, } return func.HttpResponse( - body=json.dumps(response_data, default=json_util.default), + body=json.dumps(response_data), headers={"Content-Type": "application/json"}, status_code=200, ) @@ -159,317 +199,189 @@ def search_resources(req: func.HttpRequest) -> func.HttpResponse: return create_error_response(500, "Internal server error") -def get_sort(sort): - """ - Returns a MongoDB-compatible sort dictionary based on the provided sort string. - - Parameters: - - sort (str): Sort parameter. One of 'date', 'name', 'version', 'id_asc', 'id_desc'. - - Returns: - - dict: Sort specification for MongoDB $sort stage. +def escape_lucene_query(query_str): """ - switch_dict = { - "date": {"date": -1}, - "name": {"id": 1}, - "version": {"ver_latest": -1}, - "id_asc": {"id": 1}, - "id_desc": {"id": -1}, - } - return switch_dict.get(sort, {"score": -1}) + Escape special characters in a Lucene query string. 
+    https://learn.microsoft.com/en-us/azure/search/query-lucene-syntax#escaping-special-characters
+
+    Lucene special characters:
+    + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /
 
-def get_latest_version_pipeline():
-    """
-    Constructs an aggregation pipeline to extract the latest version of each resource.
-
-    This stage:
-    - Parses semantic version strings into integer arrays for proper comparison.
-    - Sorts and groups by resource ID.
-    - Selects the latest version document per resource.
+    Parameters:
+    - query_str (str): The raw query string.
 
     Returns:
-    - list: List of aggregation pipeline stages.
+    - str: The escaped query string safe for Lucene syntax.
     """
-    return [
-        {
-            "$addFields": {
-                "resource_version_parts": {
-                    "$map": {
-                        "input": {
-                            "$split": ["$resource_version", "."],
-                        },
-                        "as": "item",
-                        "in": {"$toInt": "$$item"},
-                    },
-                },
-            },
-        },
-        {
-            "$sort": {
-                "id": 1,
-                "resource_version_parts.0": -1,
-                "resource_version_parts.1": -1,
-                "resource_version_parts.2": -1,
-                "resource_version_parts.3": -1,
-            },
-        },
-        {
-            "$group": {
-                "_id": "$id",
-                "latest_version": {
-                    "$first": "$resource_version",
-                },
-                "document": {"$first": "$$ROOT"},
-            },
-        },
-        {
-            "$replaceRoot": {
-                "newRoot": {
-                    "$mergeObjects": [
-                        "$document",
-                        {
-                            "id": "$_id",
-                            "latest_version": "$latest_version",
-                        },
-                    ],
-                },
-            },
-        },
+    # Characters that need escaping in Lucene. Double quotes are escaped
+    # separately below so that user-supplied quotes cannot break the phrase
+    # queries built above.
+    special_chars = [
+        "\\",
+        "+",
+        "-",
+        "&",
+        "|",
+        "!",
+        "(",
+        ")",
+        "{",
+        "}",
+        "[",
+        "]",
+        "^",
+        "~",
+        "*",
+        "?",
+        ":",
+        "/",
     ]
 
+    escaped = query_str
+    for char in special_chars:
+        escaped = escaped.replace(char, f"\\{char}")
 
-def get_search_pipeline(query_object):
-    """
-    Constructs a MongoDB Atlas Search pipeline based on the input search query.
+    escaped = escaped.replace('"', '\\"')
 
-    The pipeline:
-    - Performs fuzzy full-text search on fields like id, description, category, architecture, and tags.
-    - Boosts matches on the id and specific gem5_versions.
-    - Adds a 'score' field representing search relevance.
+    return escaped
 
-    Parameters:
-    - query_object (dict): Dictionary containing a 'query' key for the search string.
-
-    Returns:
-    - list: List of aggregation pipeline stages for search.
+
+def build_odata_filter(query_object):
     """
-
-    pipeline = [
-        {
-            "$search": {
-                "compound": {
-                    "should": [
-                        {
-                            "text": {
-                                "path": "id",
-                                "query": query_object["query"],
-                                "score": {"boost": {"value": 10}},
-                            }
-                        },
-                        {
-                            "text": {
-                                "path": "gem5_versions",
-                                "query": "24.1",
-                                "score": {"boost": {"value": 10}},
-                            }
-                        },
-                    ],
-                    "must": [
-                        {
-                            "text": {
-                                "query": query_object["query"],
-                                "path": [
-                                    "id",
-                                    "description",
-                                    "category",
-                                    "architecture",
-                                    "tags",
-                                ],
-                                "fuzzy": {"maxEdits": 2, "maxExpansions": 100},
-                            }
-                        }
-                    ],
-                }
-            }
-        },
-        {"$addFields": {"score": {"$meta": "searchScore"}}},
-    ]
-
-    return pipeline
-
-
-def get_filter_pipeline(query_object):
-    """
-    Constructs a MongoDB aggregation pipeline to filter documents based on
-    multiple fields.
+    Build an OData filter string for Azure AI Search based on query parameters.
 
-    Supported filters include:
-    - tags (unwound and matched individually)
-    - gem5_versions (unwound and matched individually)
-    - category (exact match)
-    - architecture (exact match)
 
     Parameters:
     - query_object (dict): Dictionary containing filter keys and values.
 
     Returns:
-    - list: List of aggregation pipeline stages for filtering documents.
+    - str or None: OData filter string, or None if no filters.
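+
+    Example (illustrative):
+        {"architecture": ["RISCV", "ARM"], "tags": ["kernel"]} maps to
+        "(architecture eq 'RISCV' or architecture eq 'ARM') and (tags/any(t: t eq 'kernel'))"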
""" - pipeline = [] + filters = [] + + # Filter by category + if query_object.get("category"): + values = query_object["category"] + if len(values) == 1: + filters.append(f"category eq '{values[0]}'") + else: + category_filters = " or ".join( + [f"category eq '{v}'" for v in values] + ) + filters.append(f"({category_filters})") + + # Filter by architecture + if query_object.get("architecture"): + values = query_object["architecture"] + if len(values) == 1: + filters.append(f"architecture eq '{values[0]}'") + else: + arch_filters = " or ".join( + [f"architecture eq '{v}'" for v in values] + ) + filters.append(f"({arch_filters})") - # Filter by tags + # Filter by tags (collection field) if query_object.get("tags"): - pipeline.extend( - [ - { - "$addFields": { - "tag": "$tags", - }, - }, - { - "$unwind": "$tag", - }, - { - "$match": { - "tag": { - "$in": query_object["tags"], - }, - }, - }, - { - "$group": { - "_id": "$_id", - "doc": { - "$first": "$$ROOT", - }, - }, - }, - { - "$replaceRoot": { - "newRoot": "$doc", - }, - }, - ] - ) + values = query_object["tags"] + tag_filters = " or ".join([f"tags/any(t: t eq '{v}')" for v in values]) + filters.append(f"({tag_filters})") - # Filter by gem5_versions + # Filter by gem5_versions (collection field) if query_object.get("gem5_versions"): - pipeline.extend( - [ - { - "$addFields": { - "version": "$gem5_versions", - }, - }, - { - "$unwind": "$version", - }, - { - "$match": { - "version": { - "$in": query_object["gem5_versions"], - }, - }, - }, - { - "$group": { - "_id": "$_id", - "doc": { - "$first": "$$ROOT", - }, - }, - }, - { - "$replaceRoot": { - "newRoot": "$doc", - }, - }, - ] - ) - - # Add other filters (category and architecture) - match_conditions = [] - if query_object.get("category"): - match_conditions.append( - {"category": {"$in": query_object["category"]}} + values = query_object["gem5_versions"] + version_filters = " or ".join( + [f"gem5_versions/any(v: v eq '{v}')" for v in values] ) + filters.append(f"({version_filters})") - if query_object.get("architecture"): - match_conditions.append( - {"architecture": {"$in": query_object["architecture"]}} - ) + if filters: + return " and ".join(filters) + return None - if match_conditions: - pipeline.append({"$match": {"$and": match_conditions}}) - return pipeline +def parse_version(version_str): + """ + Parse a semantic version string (x.y.z) into a tuple of integers for comparison. + Parameters: + - version_str (str): Version string like "1.2.3" -def get_sort_pipeline(query_object): + Returns: + - tuple: Tuple of integers (major, minor, patch) or (0, 0, 0) if parsing fails. """ - Constructs an aggregation pipeline to sort documents based on a sort - parameter. + try: + parts = version_str.split(".") + return tuple(int(p) for p in parts) + except (ValueError, AttributeError): + return (0, 0, 0) - Adds a field `ver_latest` to represent the maximum gem5 version, - then sorts based on the value provided in query_object["sort"]. + +def keep_latest_versions(results): + """ + Deduplicate results by resource id, keeping only the document with + the highest semantic version. Also preserves the maximum search score + across all versions for proper relevance sorting. Parameters: - - query_object (dict): Dictionary containing the 'sort' key. + - results (list): List of search result documents. Returns: - - list: List of aggregation pipeline stages for sorting. + - list: List of unique resources with only the latest version. 
""" - return [ - { - "$addFields": { - "ver_latest": { - "$max": {"$ifNull": ["$gem5_versions", []]}, - }, - }, - }, - { - "$sort": get_sort(query_object.get("sort")), - }, - ] + latest_by_id = {} + max_score_by_id = {} + + for doc in results: + resource_id = doc.get("id") + if not resource_id: + continue + + current_version = parse_version(doc.get("resource_version", "0.0.0")) + current_score = doc.get("score", 0) + + if resource_id not in max_score_by_id: + max_score_by_id[resource_id] = current_score + else: + max_score_by_id[resource_id] = max( + max_score_by_id[resource_id], current_score + ) + if resource_id not in latest_by_id: + latest_by_id[resource_id] = doc + else: + existing_version = parse_version( + latest_by_id[resource_id].get("resource_version", "0.0.0") + ) + if current_version > existing_version: + latest_by_id[resource_id] = doc -def get_page_pipeline(current_page, page_size): - """ - Constructs an aggregation pipeline to paginate documents. + for resource_id, doc in latest_by_id.items(): + doc["score"] = max_score_by_id.get(resource_id, 0) + + return list(latest_by_id.values()) - This stage: - - Groups all items and total count. - - Unwinds grouped documents back into individual records. - - Applies skip and limit to paginate results. + +def apply_sorting(results, sort_param): + """ + Sort results based on the sort parameter. Parameters: - - current_page (int): Current page number (1-indexed). - - page_size (int): Number of results per page. + - results (list): List of documents to sort. + - sort_param (str): Sort parameter ('date', 'name', 'version', 'id_asc', 'id_desc', or 'default'). Returns: - - list: List of aggregation pipeline stages for pagination. + - list: Sorted list of documents. """ - return [ - { - "$group": { - "_id": None, - "totalCount": {"$sum": 1}, - "items": {"$push": "$$ROOT"}, - } - }, - {"$unwind": "$items"}, - { - "$replaceRoot": { - "newRoot": { - "$mergeObjects": ["$items", {"totalCount": "$totalCount"}] - } - } - }, - { - "$skip": (current_page - 1) * page_size, - }, - { - "$limit": page_size, - }, - ] + if sort_param == "date": + return sorted(results, key=lambda x: x.get("date", ""), reverse=True) + elif sort_param == "name" or sort_param == "id_asc": + return sorted(results, key=lambda x: x.get("id", "").lower()) + elif sort_param == "id_desc": + return sorted( + results, key=lambda x: x.get("id", "").lower(), reverse=True + ) + elif sort_param == "version": + return sorted( + results, + key=lambda x: max(x.get("gem5_versions", ["0"]), default="0"), + reverse=True, + ) + else: + return sorted(results, key=lambda x: x.get("score", 0), reverse=True) diff --git a/requirements.txt b/requirements.txt index 2d24be2..50d2cca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ # Manually managing azure-functions-worker may cause unexpected issues azure-functions +azure-search-documents pymongo==4.11.2 python-dotenv requests==2.32.3 diff --git a/shared/azure_search_client.py b/shared/azure_search_client.py new file mode 100644 index 0000000..dcdebfa --- /dev/null +++ b/shared/azure_search_client.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025 The Regents of the University of California +# SPDX-License-Identifier: BSD-3-Clause + +import logging +import os + +from azure.core.credentials import AzureKeyCredential +from azure.search.documents import SearchClient + + +def get_search_client(): + """ + Creates and returns an Azure AI Search client. 
+
+    Required environment variables:
+    - AZURE_SEARCH_ENDPOINT: The Azure AI Search service endpoint
+    - AZURE_SEARCH_API_KEY: The API key for authentication
+    - AZURE_SEARCH_INDEX_NAME: The name of the search index
+
+    Returns:
+        SearchClient: An Azure AI Search client instance
+    """
+    endpoint = os.environ.get("AZURE_SEARCH_ENDPOINT")
+    api_key = os.environ.get("AZURE_SEARCH_API_KEY")
+    index_name = os.environ.get("AZURE_SEARCH_INDEX_NAME")
+
+    if not all([endpoint, api_key, index_name]):
+        logging.error("Missing Azure Search configuration")
+        raise ValueError(
+            "Missing required environment variables: "
+            "AZURE_SEARCH_ENDPOINT, AZURE_SEARCH_API_KEY, or AZURE_SEARCH_INDEX_NAME"
+        )
+
+    credential = AzureKeyCredential(api_key)
+    client = SearchClient(
+        endpoint=endpoint, index_name=index_name, credential=credential
+    )
+
+    return client
diff --git a/shared/utils.py b/shared/utils.py
index 9fd7d88..b0dbecd 100644
--- a/shared/utils.py
+++ b/shared/utils.py
@@ -49,10 +49,11 @@ def sanitize_contains_str(value):
 
 def sanitize_must_include(value):
     # Only allow field,value1,value2;field2,value1,value2, max 500 chars
+    # Allow dots for version numbers like "24.1"
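+    # (e.g. "gem5_versions,24.1;architecture,RISCV" now keeps the dot in
+    # "24.1" instead of collapsing it to "241")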
     if not isinstance(value, str):
         return ""
     value = value.strip()
-    value = re.sub(r"[^\w,;\-]", "", value)
+    value = re.sub(r"[^\w,;\-\.]", "", value)
     return value[:500]
 
diff --git a/tests/resources_api_unit_tests.py b/tests/resources_api_unit_tests.py
index de8c27d..be44fd5 100644
--- a/tests/resources_api_unit_tests.py
+++ b/tests/resources_api_unit_tests.py
@@ -398,7 +398,6 @@ def test_search_multiple_gem5_versions(self):
             len({"22.0", "23.0"}.intersection(gem5_versions)) > 0
         )
 
-    # EDGE CASE AND STRESS TESTS
     def test_search_with_special_characters(self):
         """Test search with special characters in the search string."""
         params = {"contains-str": "test-resource_with.special-chars"}
@@ -428,6 +427,147 @@ def test_batch_with_maximum_resources(self):
         )
         self.assertEqual(response.status_code, 200)
 
+    def test_search_without_contains_str(self):
+        """Test search without contains-str returns all resources."""
+        params = {"page": 1, "page-size": 5}
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertIn("documents", data)
+        self.assertIn("totalCount", data)
+        self.assertGreater(len(data["documents"]), 0)
+
+    def test_search_sort_by_id_asc(self):
+        """Test search with sort by id ascending."""
+        params = {"contains-str": "arm", "sort": "id_asc", "page-size": 10}
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        resources = data["documents"]
+        if len(resources) > 1:
+            ids = [r["id"].lower() for r in resources]
+            self.assertEqual(ids, sorted(ids))
+
+    def test_search_sort_by_id_desc(self):
+        """Test search with sort by id descending."""
+        params = {"contains-str": "arm", "sort": "id_desc", "page-size": 10}
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        resources = data["documents"]
+        if len(resources) > 1:
+            ids = [r["id"].lower() for r in resources]
+            self.assertEqual(ids, sorted(ids, reverse=True))
+
+    def test_search_total_count(self):
+        """Test that totalCount is accurate."""
+        params = {"contains-str": "arm", "page": 1, "page-size": 2}
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertIn("totalCount", data)
+        self.assertIsInstance(data["totalCount"], int)
+        self.assertGreaterEqual(data["totalCount"], len(data["documents"]))
+
+    def test_search_with_multiple_architectures(self):
+        """Test search with multiple architecture values."""
+        params = {
+            "contains-str": "hello",
+            "must-include": "architecture,x86,ARM",
+        }
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        for resource in data["documents"]:
+            self.assertIn(resource["architecture"], ["x86", "ARM"])
+
+    def test_search_pagination_beyond_results(self):
+        """Test pagination when page is beyond available results."""
+        params = {
+            "contains-str": "arm-hello64-static",
+            "page": 1000,
+            "page-size": 10,
+        }
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertEqual(len(data["documents"]), 0)  # No results on this page
+
+    def test_search_pagination_max_page_size(self):
+        """Test pagination with maximum page-size (100)."""
+        params = {"contains-str": "resource", "page": 1, "page-size": 100}
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 200)
+
+    def test_search_pagination_exceeds_max_page_size(self):
+        """Test pagination with page-size exceeding max (should fail)."""
+        params = {"contains-str": "resource", "page": 1, "page-size": 101}
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 400)
+
+    def test_search_pagination_invalid_page(self):
+        """Test pagination with invalid page number (should fail)."""
+        params = {"contains-str": "resource", "page": 0, "page-size": 10}
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 400)
+
+    def test_search_returns_latest_version_only(self):
+        """Test that search returns only the latest version of each resource."""
+        params = {"contains-str": "ubuntu", "page-size": 50}
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        resources = data["documents"]
+
+        # Check no duplicate IDs (each resource appears only once)
+        ids = [r["id"] for r in resources]
+        self.assertEqual(len(ids), len(set(ids)))
+
+    def test_search_combined_filters_sort_pagination(self):
+        """Test search with filters, sorting, and pagination combined."""
+        params = {
+            "contains-str": "ubuntu",
+            "must-include": "architecture,ARM",
+            "sort": "name",
+            "page": 1,
+            "page-size": 5,
+        }
+        response = requests.get(
+            f"{self.base_url}/resources/search", params=params
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        resources = data["documents"]
+
+        # Validate architecture filter
+        for resource in resources:
+            self.assertEqual(resource["architecture"], "ARM")
+
+        # Validate sorting (by name/id ascending)
+        if len(resources) > 1:
+            ids = [r["id"].lower() for r in resources]
+            self.assertEqual(ids, sorted(ids))
+
 
 if __name__ == "__main__":
     unittest.main()