Skip to content

Commit 1280603

Browse files
committed
feat: update discover functions to return fastapi app
1 parent 05f034a commit 1280603

File tree

1 file changed

+48
-12
lines changed

1 file changed

+48
-12
lines changed

src/fastapi_cli/discover.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass
44
from logging import getLogger
55
from pathlib import Path
6-
from typing import List, Union
6+
from typing import List, Tuple, Union
77

88
from fastapi_cli.exceptions import FastAPICLIException
99

@@ -45,27 +45,34 @@ class ModuleData:
4545
def get_module_data_from_path(path: Path) -> ModuleData:
4646
use_path = path.resolve()
4747
module_path = use_path
48+
4849
if use_path.is_file() and use_path.stem == "__init__":
4950
module_path = use_path.parent
51+
5052
module_paths = [module_path]
5153
extra_sys_path = module_path.parent
54+
5255
for parent in module_path.parents:
5356
init_path = parent / "__init__.py"
57+
5458
if init_path.is_file():
5559
module_paths.insert(0, parent)
5660
extra_sys_path = parent.parent
5761
else:
5862
break
5963

6064
module_str = ".".join(p.stem for p in module_paths)
65+
6166
return ModuleData(
6267
module_import_str=module_str,
6368
extra_sys_path=extra_sys_path.resolve(),
6469
module_paths=module_paths,
6570
)
6671

6772

68-
def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) -> str:
73+
def get_app_infos(
74+
*, mod_data: ModuleData, app_name: Union[str, None] = None
75+
) -> Tuple[str, str | None, str | None, str | None]:
6976
try:
7077
mod = importlib.import_module(mod_data.module_import_str)
7178
except (ImportError, ValueError) as e:
@@ -74,32 +81,41 @@ def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) ->
7481
"Ensure all the package directories have an [blue]__init__.py[/blue] file"
7582
)
7683
raise
84+
7785
if not FastAPI: # type: ignore[truthy-function]
7886
raise FastAPICLIException(
7987
"Could not import FastAPI, try running 'pip install fastapi'"
8088
) from None
89+
8190
object_names = dir(mod)
8291
object_names_set = set(object_names)
92+
8393
if app_name:
8494
if app_name not in object_names_set:
8595
raise FastAPICLIException(
8696
f"Could not find app name {app_name} in {mod_data.module_import_str}"
8797
)
98+
8899
app = getattr(mod, app_name)
100+
89101
if not isinstance(app, FastAPI):
90102
raise FastAPICLIException(
91103
f"The app name {app_name} in {mod_data.module_import_str} doesn't seem to be a FastAPI app"
92104
)
93-
return app_name
105+
106+
return app_name, app.openapi_url, app.docs_url, app.redoc_url
107+
94108
for preferred_name in ["app", "api"]:
95109
if preferred_name in object_names_set:
96110
obj = getattr(mod, preferred_name)
97111
if isinstance(obj, FastAPI):
98-
return preferred_name
112+
return preferred_name, obj.openapi_url, obj.docs_url, obj.redoc_url
113+
99114
for name in object_names:
100115
obj = getattr(mod, name)
101116
if isinstance(obj, FastAPI):
102-
return name
117+
return name, obj.openapi_url, obj.docs_url, obj.redoc_url
118+
103119
raise FastAPICLIException("Could not find FastAPI app in module, try using --app")
104120

105121

@@ -108,6 +124,9 @@ class ImportData:
108124
app_name: str
109125
module_data: ModuleData
110126
import_string: str
127+
openapi_url: str | None = None
128+
docs_url: str | None = None
129+
redoc_url: str | None = None
111130

112131

113132
def get_import_data(
@@ -121,14 +140,22 @@ def get_import_data(
121140

122141
if not path.exists():
123142
raise FastAPICLIException(f"Path does not exist {path}")
143+
124144
mod_data = get_module_data_from_path(path)
125145
sys.path.insert(0, str(mod_data.extra_sys_path))
126-
use_app_name = get_app_name(mod_data=mod_data, app_name=app_name)
146+
use_app_name, openapi_url, docs_url, redoc_url = get_app_infos(
147+
mod_data=mod_data, app_name=app_name
148+
)
127149

128150
import_string = f"{mod_data.module_import_str}:{use_app_name}"
129151

130152
return ImportData(
131-
app_name=use_app_name, module_data=mod_data, import_string=import_string
153+
app_name=use_app_name,
154+
module_data=mod_data,
155+
import_string=import_string,
156+
openapi_url=openapi_url,
157+
docs_url=docs_url,
158+
redoc_url=redoc_url,
132159
)
133160

134161

@@ -144,12 +171,21 @@ def get_import_data_from_import_string(import_string: str) -> ImportData:
144171

145172
sys.path.insert(0, str(here))
146173

174+
module_data = ModuleData(
175+
module_import_str=module_str,
176+
extra_sys_path=here,
177+
module_paths=[],
178+
)
179+
180+
_, openapi_url, docs_url, redoc_url = get_app_infos(
181+
mod_data=module_data, app_name=app_name
182+
)
183+
147184
return ImportData(
148185
app_name=app_name,
149-
module_data=ModuleData(
150-
module_import_str=module_str,
151-
extra_sys_path=here,
152-
module_paths=[],
153-
),
186+
module_data=module_data,
154187
import_string=import_string,
188+
openapi_url=openapi_url,
189+
docs_url=docs_url,
190+
redoc_url=redoc_url,
155191
)

0 commit comments

Comments
 (0)