@@ -17,7 +17,54 @@ def _safe_import_all_objects():
1717
1818 return all_objects , None
1919 except Exception as e : # pragma: no cover - defensive
20- return None , e
20+ # fallback to manual discovery if registry fails (e.g., skbase not available)
21+ return _manual_model_discovery (), e
22+
23+
24+ def _manual_model_discovery ():
25+ """Fallback model discovery when registry is not available."""
26+ import importlib
27+ import inspect
28+ from pathlib import Path
29+
30+ models = []
31+ models_dir = Path (__file__ ).parent .parent .parent .parent / "pytorch_forecasting" / "models"
32+
33+ # Known model packages
34+ model_packages = [
35+ "deepar" , "dlinear" , "mlp" , "nbeats" , "nhits" ,
36+ "rnn" , "temporal_fusion_transformer" , "tide" , "timexer" , "xlstm"
37+ ]
38+
39+ for pkg_name in model_packages :
40+ try :
41+ module_name = f"pytorch_forecasting.models.{ pkg_name } "
42+ module = importlib .import_module (module_name )
43+
44+ # Look for model classes in the module
45+ for name , obj in inspect .getmembers (module ):
46+ if (inspect .isclass (obj ) and
47+ hasattr (obj , '__module__' ) and
48+ obj .__module__ .startswith (module_name ) and
49+ not name .startswith ('_' )):
50+
51+ # Determine estimator type based on package name
52+ if pkg_name in ["deepar" , "dlinear" , "mlp" , "nbeats" , "nhits" , "rnn" ]:
53+ estimator_type = "forecaster_v1"
54+ else :
55+ estimator_type = "forecaster_v2"
56+
57+ models .append ({
58+ 'names' : name ,
59+ 'objects' : obj ,
60+ 'object_type' : estimator_type ,
61+ 'authors' : getattr (obj , 'authors' , ['pytorch-forecasting developers' ]),
62+ 'python_dependencies' : getattr (obj , 'python_dependencies' , [])
63+ })
64+ except Exception :
65+ continue
66+
67+ return models
2168
2269
2370def _render_lines () -> list [str ]:
@@ -42,29 +89,36 @@ def _render_lines() -> list[str]:
4289 )
4390 return lines
4491
45- try :
46- df = all_objects (
47- object_types = ["forecaster_pytorch_v1" , "forecaster_pytorch_v2" ],
48- as_dataframe = True ,
49- return_tags = [
50- "object_type" ,
51- "info:name" ,
52- "authors" ,
53- "python_dependencies" ,
54- ],
55- return_names = True ,
56- )
57- except Exception as e : # pragma: no cover - defensive
58- lines .extend (
59- [
60- ".. note::" ,
61- f" Registry query failed: ``{ e } ``" ,
62- "" ,
63- ]
64- )
65- return lines
66-
67- if df is None or len (df ) == 0 :
92+ # Handle both registry DataFrame and manual discovery list
93+ if isinstance (all_objects , list ):
94+ # Manual discovery fallback
95+ models_data = all_objects
96+ else :
97+ # Registry DataFrame
98+ try :
99+ df = all_objects (
100+ object_types = ["forecaster_pytorch_v1" , "forecaster_pytorch_v2" ],
101+ as_dataframe = True ,
102+ return_tags = [
103+ "object_type" ,
104+ "info:name" ,
105+ "authors" ,
106+ "python_dependencies" ,
107+ ],
108+ return_names = True ,
109+ )
110+ models_data = df .to_dict ('records' ) if df is not None else []
111+ except Exception as e : # pragma: no cover - defensive
112+ lines .extend (
113+ [
114+ ".. note::" ,
115+ f" Registry query failed: ``{ e } ``" ,
116+ "" ,
117+ ]
118+ )
119+ return lines
120+
121+ if not models_data :
68122 lines .extend ([".. note::" , " No models found in registry." , "" ])
69123 return lines
70124
@@ -83,7 +137,10 @@ def _render_lines() -> list[str]:
83137 lines .append (" * - " + "\n - " .join (header_cols ))
84138
85139 # rows
86- for _ , row in df .sort_values ("names" ).iterrows ():
140+ # Sort models by name
141+ sorted_models = sorted (models_data , key = lambda x : x .get ("names" , "" ))
142+
143+ for row in sorted_models :
87144 pkg_cls = row ["objects" ]
88145 try :
89146 model_cls = pkg_cls .get_model_cls ()
@@ -141,7 +198,13 @@ def _is_safe_mode() -> bool:
141198 return False
142199
143200
144- def _write_models_rst (app ) -> None :
201+ def _write_models_rst (app , config = None ) -> None :
202+ """Write models.rst file.
203+
204+ This can be called either from:
205+ - config-inited event (app, config)
206+ - builder-inited event (app)
207+ """
145208 # confdir is docs/source
146209 out_file = os .path .join (app .confdir , "models.rst" )
147210 try :
@@ -169,14 +232,19 @@ def _write_models_rst(app) -> None:
169232 os .makedirs (os .path .dirname (out_file ), exist_ok = True )
170233 with open (out_file , "w" , encoding = "utf-8" ) as f :
171234 f .write ("\n " .join (lines ))
235+
236+ # Print confirmation for debugging
237+ print (f"[model_overview] Wrote { len (lines )} lines to { out_file } " )
172238
173239
174240def setup (app ):
175- # generate as early as possible so Sphinx
176- # sees the written file during source discovery
177- app .connect ("config-inited" , _write_models_rst )
241+ """Setup the Sphinx extension."""
242+ # Use builder-inited instead of config-inited
243+ # This ensures the extension is fully loaded
244+ app .connect ("builder-inited" , lambda app : _write_models_rst (app ))
245+
178246 return {
179247 "version" : "1.0" ,
180248 "parallel_read_safe" : True ,
181249 "parallel_write_safe" : True ,
182- }
250+ }
0 commit comments