Skip to content

Commit f5bdb76

Browse files
committed
docs: generate models.rst and switch model_overview to builder-inited; add debug print; load extension via extensions list
.
1 parent 1d9f000 commit f5bdb76

File tree

3 files changed

+335
-40
lines changed

3 files changed

+335
-40
lines changed

docs/source/_ext/model_overview.py

Lines changed: 98 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,54 @@ def _safe_import_all_objects():
1717

1818
return all_objects, None
1919
except Exception as e: # pragma: no cover - defensive
20-
return None, e
20+
# fallback to manual discovery if registry fails (e.g., skbase not available)
21+
return _manual_model_discovery(), e
22+
23+
24+
def _manual_model_discovery():
25+
"""Fallback model discovery when registry is not available."""
26+
import importlib
27+
import inspect
28+
from pathlib import Path
29+
30+
models = []
31+
models_dir = Path(__file__).parent.parent.parent.parent / "pytorch_forecasting" / "models"
32+
33+
# Known model packages
34+
model_packages = [
35+
"deepar", "dlinear", "mlp", "nbeats", "nhits",
36+
"rnn", "temporal_fusion_transformer", "tide", "timexer", "xlstm"
37+
]
38+
39+
for pkg_name in model_packages:
40+
try:
41+
module_name = f"pytorch_forecasting.models.{pkg_name}"
42+
module = importlib.import_module(module_name)
43+
44+
# Look for model classes in the module
45+
for name, obj in inspect.getmembers(module):
46+
if (inspect.isclass(obj) and
47+
hasattr(obj, '__module__') and
48+
obj.__module__.startswith(module_name) and
49+
not name.startswith('_')):
50+
51+
# Determine estimator type based on package name
52+
if pkg_name in ["deepar", "dlinear", "mlp", "nbeats", "nhits", "rnn"]:
53+
estimator_type = "forecaster_v1"
54+
else:
55+
estimator_type = "forecaster_v2"
56+
57+
models.append({
58+
'names': name,
59+
'objects': obj,
60+
'object_type': estimator_type,
61+
'authors': getattr(obj, 'authors', ['pytorch-forecasting developers']),
62+
'python_dependencies': getattr(obj, 'python_dependencies', [])
63+
})
64+
except Exception:
65+
continue
66+
67+
return models
2168

2269

2370
def _render_lines() -> list[str]:
@@ -42,29 +89,36 @@ def _render_lines() -> list[str]:
4289
)
4390
return lines
4491

45-
try:
46-
df = all_objects(
47-
object_types=["forecaster_pytorch_v1", "forecaster_pytorch_v2"],
48-
as_dataframe=True,
49-
return_tags=[
50-
"object_type",
51-
"info:name",
52-
"authors",
53-
"python_dependencies",
54-
],
55-
return_names=True,
56-
)
57-
except Exception as e: # pragma: no cover - defensive
58-
lines.extend(
59-
[
60-
".. note::",
61-
f" Registry query failed: ``{e}``",
62-
"",
63-
]
64-
)
65-
return lines
66-
67-
if df is None or len(df) == 0:
92+
# Handle both registry DataFrame and manual discovery list
93+
if isinstance(all_objects, list):
94+
# Manual discovery fallback
95+
models_data = all_objects
96+
else:
97+
# Registry DataFrame
98+
try:
99+
df = all_objects(
100+
object_types=["forecaster_pytorch_v1", "forecaster_pytorch_v2"],
101+
as_dataframe=True,
102+
return_tags=[
103+
"object_type",
104+
"info:name",
105+
"authors",
106+
"python_dependencies",
107+
],
108+
return_names=True,
109+
)
110+
models_data = df.to_dict('records') if df is not None else []
111+
except Exception as e: # pragma: no cover - defensive
112+
lines.extend(
113+
[
114+
".. note::",
115+
f" Registry query failed: ``{e}``",
116+
"",
117+
]
118+
)
119+
return lines
120+
121+
if not models_data:
68122
lines.extend([".. note::", " No models found in registry.", ""])
69123
return lines
70124

@@ -83,7 +137,10 @@ def _render_lines() -> list[str]:
83137
lines.append(" * - " + "\n - ".join(header_cols))
84138

85139
# rows
86-
for _, row in df.sort_values("names").iterrows():
140+
# Sort models by name
141+
sorted_models = sorted(models_data, key=lambda x: x.get("names", ""))
142+
143+
for row in sorted_models:
87144
pkg_cls = row["objects"]
88145
try:
89146
model_cls = pkg_cls.get_model_cls()
@@ -141,7 +198,13 @@ def _is_safe_mode() -> bool:
141198
return False
142199

143200

144-
def _write_models_rst(app) -> None:
201+
def _write_models_rst(app, config=None) -> None:
202+
"""Write models.rst file.
203+
204+
This can be called either from:
205+
- config-inited event (app, config)
206+
- builder-inited event (app)
207+
"""
145208
# confdir is docs/source
146209
out_file = os.path.join(app.confdir, "models.rst")
147210
try:
@@ -169,14 +232,19 @@ def _write_models_rst(app) -> None:
169232
os.makedirs(os.path.dirname(out_file), exist_ok=True)
170233
with open(out_file, "w", encoding="utf-8") as f:
171234
f.write("\n".join(lines))
235+
236+
# Print confirmation for debugging
237+
print(f"[model_overview] Wrote {len(lines)} lines to {out_file}")
172238

173239

174240
def setup(app):
175-
# generate as early as possible so Sphinx
176-
# sees the written file during source discovery
177-
app.connect("config-inited", _write_models_rst)
241+
"""Setup the Sphinx extension."""
242+
# Use builder-inited instead of config-inited
243+
# This ensures the extension is fully loaded
244+
app.connect("builder-inited", lambda app: _write_models_rst(app))
245+
178246
return {
179247
"version": "1.0",
180248
"parallel_read_safe": True,
181249
"parallel_write_safe": True,
182-
}
250+
}

docs/source/conf.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import shutil
1616
import sys
1717

18+
# Add _ext directory to Python path for local extensions
19+
sys.path.append(os.path.abspath("../_ext"))
20+
1821
from sphinx.application import Sphinx
1922
from sphinx.ext.autosummary import Autosummary
2023
from sphinx.pycode import ModuleAnalyzer
@@ -56,6 +59,7 @@
5659
"sphinx.ext.viewcode",
5760
"sphinx.ext.githubpages",
5861
"sphinx.ext.napoleon",
62+
"model_overview",
5963
]
6064

6165
# Add any paths that contain templates here, relative to this directory.
@@ -147,15 +151,7 @@ def setup(app: Sphinx):
147151
app.connect("autodoc-skip-member", skip)
148152
app.add_directive("moduleautosummary", ModuleAutoSummary)
149153
app.add_js_file("https://buttons.github.io/buttons.js", **{"async": "async"})
150-
# load custom model overview generator if available
151-
try:
152-
if "model_overview" not in extensions:
153-
extensions.append("model_overview")
154-
except Exception as exc:
155-
# avoid hard-failing docs builds; make the reason visible in Sphinx logs
156-
sphinx_logging.getLogger(__name__).warning(
157-
"model_overview extension not loaded: %s", exc
158-
)
154+
# model_overview extension is loaded via extensions list
159155

160156

161157
# extension configuration

0 commit comments

Comments
 (0)