Skip to content

Commit 391628f

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Add a service registry to provide a generic way to register custom service implementations to be used in FastAPI server
To register a custom service: - Create a factory function that takes a URI and returns an instance of your custom service. This function will parse any details it needs from the URI. - Register your factory with the global service registry. You need to define a unique URI scheme for your service (e.g., custom). PiperOrigin-RevId: 822310466
1 parent 409df13 commit 391628f

File tree

3 files changed

+411
-67
lines changed

3 files changed

+411
-67
lines changed

src/google/adk/cli/fast_api.py

Lines changed: 18 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,16 @@
3333
from starlette.types import Lifespan
3434
from watchdog.observers import Observer
3535

36-
from ..artifacts.gcs_artifact_service import GcsArtifactService
3736
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
3837
from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService
3938
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
4039
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
4140
from ..memory.in_memory_memory_service import InMemoryMemoryService
42-
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
4341
from ..runners import Runner
4442
from ..sessions.in_memory_session_service import InMemorySessionService
45-
from ..sessions.vertex_ai_session_service import VertexAiSessionService
4643
from ..utils.feature_decorator import working_in_progress
4744
from .adk_web_server import AdkWebServer
45+
from .service_registry import get_service_registry
4846
from .utils import envs
4947
from .utils import evals
5048
from .utils.agent_change_handler import AgentChangeEventHandler
@@ -85,54 +83,14 @@ def get_fast_api_app(
8583
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
8684
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
8785

88-
def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name):
89-
if not agent_engine_id_or_resource_name:
90-
raise click.ClickException(
91-
"Agent engine resource name or resource id can not be empty."
92-
)
93-
94-
# "projects/my-project/locations/us-central1/reasoningEngines/1234567890",
95-
if "/" in agent_engine_id_or_resource_name:
96-
# Validate resource name.
97-
if len(agent_engine_id_or_resource_name.split("/")) != 6:
98-
raise click.ClickException(
99-
"Agent engine resource name is mal-formatted. It should be of"
100-
" format :"
101-
" projects/{project_id}/locations/{location}/reasoningEngines/{resource_id}"
102-
)
103-
project = agent_engine_id_or_resource_name.split("/")[1]
104-
location = agent_engine_id_or_resource_name.split("/")[3]
105-
agent_engine_id = agent_engine_id_or_resource_name.split("/")[-1]
106-
else:
107-
envs.load_dotenv_for_agent("", agents_dir)
108-
project = os.environ.get("GOOGLE_CLOUD_PROJECT", None)
109-
location = os.environ.get("GOOGLE_CLOUD_LOCATION", None)
110-
agent_engine_id = agent_engine_id_or_resource_name
111-
return project, location, agent_engine_id
86+
service_registry = get_service_registry()
11287

11388
# Build the Memory service
11489
if memory_service_uri:
115-
if memory_service_uri.startswith("rag://"):
116-
from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
117-
118-
rag_corpus = memory_service_uri.split("://")[1]
119-
if not rag_corpus:
120-
raise click.ClickException("Rag corpus can not be empty.")
121-
envs.load_dotenv_for_agent("", agents_dir)
122-
memory_service = VertexAiRagMemoryService(
123-
rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}'
124-
)
125-
elif memory_service_uri.startswith("agentengine://"):
126-
agent_engine_id_or_resource_name = memory_service_uri.split("://")[1]
127-
project, location, agent_engine_id = _parse_agent_engine_resource_name(
128-
agent_engine_id_or_resource_name
129-
)
130-
memory_service = VertexAiMemoryBankService(
131-
project=project,
132-
location=location,
133-
agent_engine_id=agent_engine_id,
134-
)
135-
else:
90+
memory_service = service_registry.create_memory_service(
91+
memory_service_uri, agents_dir=agents_dir
92+
)
93+
if not memory_service:
13694
raise click.ClickException(
13795
"Unsupported memory service URI: %s" % memory_service_uri
13896
)
@@ -141,34 +99,27 @@ def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name):
14199

142100
# Build the Session service
143101
if session_service_uri:
144-
if session_service_uri.startswith("agentengine://"):
145-
agent_engine_id_or_resource_name = session_service_uri.split("://")[1]
146-
project, location, agent_engine_id = _parse_agent_engine_resource_name(
147-
agent_engine_id_or_resource_name
148-
)
149-
session_service = VertexAiSessionService(
150-
project=project,
151-
location=location,
152-
agent_engine_id=agent_engine_id,
153-
)
154-
else:
102+
session_kwargs = session_db_kwargs or {}
103+
session_service = service_registry.create_session_service(
104+
session_service_uri, agents_dir=agents_dir, **session_kwargs
105+
)
106+
if not session_service:
107+
# Fallback to DatabaseSessionService if the service registry doesn't
108+
# support the session service URI scheme.
155109
from ..sessions.database_session_service import DatabaseSessionService
156110

157-
# Database session additional settings
158-
if session_db_kwargs is None:
159-
session_db_kwargs = {}
160111
session_service = DatabaseSessionService(
161-
db_url=session_service_uri, **session_db_kwargs
112+
db_url=session_service_uri, **session_kwargs
162113
)
163114
else:
164115
session_service = InMemorySessionService()
165116

166117
# Build the Artifact service
167118
if artifact_service_uri:
168-
if artifact_service_uri.startswith("gs://"):
169-
gcs_bucket = artifact_service_uri.split("://")[1]
170-
artifact_service = GcsArtifactService(bucket_name=gcs_bucket)
171-
else:
119+
artifact_service = service_registry.create_artifact_service(
120+
artifact_service_uri, agents_dir=agents_dir
121+
)
122+
if not artifact_service:
172123
raise click.ClickException(
173124
"Unsupported artifact service URI: %s" % artifact_service_uri
174125
)
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import os
18+
from typing import Any
19+
from typing import Dict
20+
from typing import Protocol
21+
from urllib.parse import urlparse
22+
23+
from ..artifacts.base_artifact_service import BaseArtifactService
24+
from ..memory.base_memory_service import BaseMemoryService
25+
from ..sessions.base_session_service import BaseSessionService
26+
27+
28+
def _load_gcp_config(
29+
agents_dir: str | None, service_name: str
30+
) -> tuple[str, str]:
31+
"""Loads GCP project and location from environment."""
32+
if not agents_dir:
33+
raise ValueError(f"agents_dir must be provided for {service_name}")
34+
35+
from .utils import envs
36+
37+
envs.load_dotenv_for_agent("", agents_dir)
38+
39+
project = os.environ.get("GOOGLE_CLOUD_PROJECT")
40+
location = os.environ.get("GOOGLE_CLOUD_LOCATION")
41+
42+
if not project or not location:
43+
raise ValueError("GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_LOCATION not set.")
44+
45+
return project, location
46+
47+
48+
def _parse_agent_engine_kwargs(
49+
uri_part: str, agents_dir: str | None
50+
) -> dict[str, Any]:
51+
"""Helper to parse agent engine resource name."""
52+
if not uri_part:
53+
raise ValueError(
54+
"Agent engine resource name or resource id can not be empty."
55+
)
56+
if "/" in uri_part:
57+
parts = uri_part.split("/")
58+
if not (
59+
len(parts) == 6
60+
and parts[0] == "projects"
61+
and parts[2] == "locations"
62+
and parts[4] == "reasoningEngines"
63+
):
64+
raise ValueError(
65+
"Agent engine resource name is mal-formatted. It should be of"
66+
" format :"
67+
" projects/{project_id}/locations/{location}/reasoningEngines/{resource_id}"
68+
)
69+
project = parts[1]
70+
location = parts[3]
71+
agent_engine_id = parts[5]
72+
else:
73+
project, location = _load_gcp_config(
74+
agents_dir, "short-form agent engine IDs"
75+
)
76+
agent_engine_id = uri_part
77+
return {
78+
"project": project,
79+
"location": location,
80+
"agent_engine_id": agent_engine_id,
81+
}
82+
83+
84+
class ServiceFactory(Protocol):
85+
"""Protocol for service factory functions."""
86+
87+
def __call__(
88+
self, uri: str, **kwargs
89+
) -> BaseSessionService | BaseArtifactService | BaseMemoryService:
90+
...
91+
92+
93+
class ServiceRegistry:
94+
"""Registry for custom service URI schemes."""
95+
96+
def __init__(self):
97+
self._session_factories: Dict[str, ServiceFactory] = {}
98+
self._artifact_factories: Dict[str, ServiceFactory] = {}
99+
self._memory_factories: Dict[str, ServiceFactory] = {}
100+
101+
def register_session_service(
102+
self, scheme: str, factory: ServiceFactory
103+
) -> None:
104+
"""Register a factory for a custom session service URI scheme.
105+
106+
Args:
107+
scheme: URI scheme (e.g., 'custom')
108+
factory: Callable that takes (uri, **kwargs) and returns
109+
BaseSessionService
110+
"""
111+
self._session_factories[scheme] = factory
112+
113+
def register_artifact_service(
114+
self, scheme: str, factory: ServiceFactory
115+
) -> None:
116+
"""Register a factory for a custom artifact service URI scheme."""
117+
self._artifact_factories[scheme] = factory
118+
119+
def register_memory_service(
120+
self, scheme: str, factory: ServiceFactory
121+
) -> None:
122+
"""Register a factory for a custom memory service URI scheme."""
123+
self._memory_factories[scheme] = factory
124+
125+
def create_session_service(
126+
self, uri: str, **kwargs
127+
) -> BaseSessionService | None:
128+
"""Create session service from URI using registered factories."""
129+
scheme = urlparse(uri).scheme
130+
if scheme and scheme in self._session_factories:
131+
return self._session_factories[scheme](uri, **kwargs)
132+
return None
133+
134+
def create_artifact_service(
135+
self, uri: str, **kwargs
136+
) -> BaseArtifactService | None:
137+
"""Create artifact service from URI using registered factories."""
138+
scheme = urlparse(uri).scheme
139+
if scheme and scheme in self._artifact_factories:
140+
return self._artifact_factories[scheme](uri, **kwargs)
141+
return None
142+
143+
def create_memory_service(
144+
self, uri: str, **kwargs
145+
) -> BaseMemoryService | None:
146+
"""Create memory service from URI using registered factories."""
147+
scheme = urlparse(uri).scheme
148+
if scheme and scheme in self._memory_factories:
149+
return self._memory_factories[scheme](uri, **kwargs)
150+
return None
151+
152+
153+
def _register_builtin_services(registry: ServiceRegistry) -> None:
154+
"""Register built-in service implementations."""
155+
156+
# -- Session Services --
157+
def agentengine_session_factory(uri: str, **kwargs):
158+
from ..sessions.vertex_ai_session_service import VertexAiSessionService
159+
160+
parsed = urlparse(uri)
161+
params = _parse_agent_engine_kwargs(
162+
parsed.netloc + parsed.path, kwargs.get("agents_dir")
163+
)
164+
return VertexAiSessionService(**params)
165+
166+
def database_session_factory(uri: str, **kwargs):
167+
from ..sessions.database_session_service import DatabaseSessionService
168+
169+
kwargs_copy = kwargs.copy()
170+
kwargs_copy.pop("agents_dir", None)
171+
return DatabaseSessionService(db_url=uri, **kwargs_copy)
172+
173+
registry.register_session_service("agentengine", agentengine_session_factory)
174+
for scheme in ["sqlite", "postgresql", "mysql"]:
175+
registry.register_session_service(scheme, database_session_factory)
176+
177+
# -- Artifact Services --
178+
def gcs_artifact_factory(uri: str, **kwargs):
179+
from ..artifacts.gcs_artifact_service import GcsArtifactService
180+
181+
kwargs_copy = kwargs.copy()
182+
kwargs_copy.pop("agents_dir", None)
183+
parsed_uri = urlparse(uri)
184+
bucket_name = parsed_uri.netloc
185+
return GcsArtifactService(bucket_name=bucket_name, **kwargs_copy)
186+
187+
registry.register_artifact_service("gs", gcs_artifact_factory)
188+
189+
# -- Memory Services --
190+
def rag_memory_factory(uri: str, **kwargs):
191+
from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
192+
193+
rag_corpus = urlparse(uri).netloc
194+
if not rag_corpus:
195+
raise ValueError("Rag corpus can not be empty.")
196+
agents_dir = kwargs.get("agents_dir")
197+
project, location = _load_gcp_config(agents_dir, "RAG memory service")
198+
return VertexAiRagMemoryService(
199+
rag_corpus=(
200+
f"projects/{project}/locations/{location}/ragCorpora/{rag_corpus}"
201+
)
202+
)
203+
204+
def agentengine_memory_factory(uri: str, **kwargs):
205+
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
206+
207+
parsed = urlparse(uri)
208+
params = _parse_agent_engine_kwargs(
209+
parsed.netloc + parsed.path, kwargs.get("agents_dir")
210+
)
211+
return VertexAiMemoryBankService(**params)
212+
213+
registry.register_memory_service("rag", rag_memory_factory)
214+
registry.register_memory_service("agentengine", agentengine_memory_factory)
215+
216+
217+
# Global registry instance
218+
_global_registry = ServiceRegistry()
219+
_register_builtin_services(_global_registry)
220+
221+
222+
def get_service_registry() -> ServiceRegistry:
223+
"""Get the global service registry instance."""
224+
return _global_registry

0 commit comments

Comments
 (0)