@@ -132,11 +132,152 @@ def get_items(self, names):
132132 return new_items
133133
134134
135+ def _make_estimator_overview (app ):
136+ """Make estimator/model overview table.
137+
138+ This function generates a dynamic table of all models in pytorch-forecasting
139+ by querying the registry system. The table is written as HTML and JSON files
140+ for inclusion in the documentation.
141+ """
142+ try :
143+ import pandas as pd
144+ from pytorch_forecasting ._registry import all_objects
145+
146+ # Base classes to exclude from the overview
147+ BASE_CLASSES = {
148+ "BaseModel" ,
149+ "BaseModelWithCovariates" ,
150+ "AutoRegressiveBaseModel" ,
151+ "AutoRegressiveBaseModelWithCovariates" ,
152+ "_BaseObject" ,
153+ "_BasePtForecaster" ,
154+ "_BasePtForecasterV2" ,
155+ "_BasePtForecaster_Common" ,
156+ }
157+
158+ # Get all objects from registry
159+ all_objs = all_objects (return_names = True , suppress_import_stdout = True )
160+
161+ records = []
162+
163+ for obj_name , obj_class in all_objs :
164+ # Skip base classes
165+ if obj_name in BASE_CLASSES :
166+ continue
167+
168+ # Skip if it's not a model class (check if it has get_class_tag method)
169+ if not hasattr (obj_class , "get_class_tag" ):
170+ continue
171+
172+ try :
173+ # Get model name from tags or use class name
174+ model_name = obj_class .get_class_tag ("info:name" , obj_name )
175+
176+ # Get authors
177+ authors = obj_class .get_class_tag ("authors" , None )
178+ if authors is None :
179+ authors = "pytorch-forecasting developers"
180+ elif isinstance (authors , list ):
181+ authors = ", " .join (authors )
182+
183+ # Get object type
184+ object_type = obj_class .get_class_tag ("object_type" , "model" )
185+ if isinstance (object_type , list ):
186+ object_type = ", " .join (object_type )
187+
188+ # Get capabilities
189+ has_exogenous = obj_class .get_class_tag ("capability:exogenous" , False )
190+ has_multivariate = obj_class .get_class_tag ("capability:multivariate" , False )
191+ has_pred_int = obj_class .get_class_tag ("capability:pred_int" , False )
192+ has_flexible_history = obj_class .get_class_tag ("capability:flexible_history_length" , False )
193+ has_cold_start = obj_class .get_class_tag ("capability:cold_start" , False )
194+
195+ # Get compute requirement
196+ compute = obj_class .get_class_tag ("info:compute" , None )
197+
198+ # Get module path for documentation link
199+ module_path = obj_class .__module__
200+ class_name = obj_class .__name__
201+
202+ # Construct documentation link
203+ # Convert module path to API documentation path
204+ api_path = module_path .replace ("." , "/" )
205+ doc_link = f"api/{ api_path } .html#{ module_path } .{ class_name } "
206+
207+ # Create model name with link
208+ model_name_link = f'<a href="{ doc_link } ">{ model_name } </a>'
209+
210+ # Build capabilities string
211+ capabilities = []
212+ if has_exogenous :
213+ capabilities .append ("Covariates" )
214+ if has_multivariate :
215+ capabilities .append ("Multiple targets" )
216+ if has_pred_int :
217+ capabilities .append ("Uncertainty" )
218+ if has_flexible_history :
219+ capabilities .append ("Flexible history" )
220+ if has_cold_start :
221+ capabilities .append ("Cold-start" )
222+
223+ capabilities_str = ", " .join (capabilities ) if capabilities else ""
224+
225+ records .append ({
226+ "Model Name" : model_name_link ,
227+ "Type" : object_type ,
228+ "Authors" : authors ,
229+ "Covariates" : "✓" if has_exogenous else "" ,
230+ "Multiple targets" : "✓" if has_multivariate else "" ,
231+ "Uncertainty" : "✓" if has_pred_int else "" ,
232+ "Flexible history" : "✓" if has_flexible_history else "" ,
233+ "Cold-start" : "✓" if has_cold_start else "" ,
234+ "Compute" : str (compute ) if compute is not None else "" ,
235+ "Capabilities" : capabilities_str ,
236+ "Module" : module_path ,
237+ })
238+ except Exception as e :
239+ # Skip objects that can't be processed
240+ print (f"Warning: Could not process { obj_name } : { e } " )
241+ continue
242+
243+ if not records :
244+ print ("Warning: No models found in registry" )
245+ return
246+
247+ # Create DataFrame
248+ df = pd .DataFrame (records )
249+
250+ # Ensure _static directory exists
251+ static_dir = SOURCE_PATH .joinpath ("_static" )
252+ static_dir .mkdir (exist_ok = True )
253+
254+ # Write HTML table
255+ html_file = static_dir .joinpath ("model_overview_table.html" )
256+ html_content = df [["Model Name" , "Type" , "Covariates" , "Multiple targets" ,
257+ "Uncertainty" , "Flexible history" , "Cold-start" , "Compute" ]].to_html (
258+ classes = "model-overview-table" , index = False , border = 0 , escape = False
259+ )
260+ html_file .write_text (html_content , encoding = "utf-8" )
261+ print (f"Generated model overview table: { html_file } " )
262+
263+ # Write JSON database for interactive filtering (optional)
264+ json_file = static_dir .joinpath ("model_overview_db.json" )
265+ df .to_json (json_file , orient = "records" , indent = 2 )
266+ print (f"Generated model overview JSON: { json_file } " )
267+
268+ except ImportError as e :
269+ print (f"Warning: Could not generate model overview (missing dependency): { e } " )
270+ except Exception as e :
271+ print (f"Warning: Error generating model overview: { e } " )
272+
273+
135274def setup (app : Sphinx ):
136275 app .add_css_file ("custom.css" )
137276 app .connect ("autodoc-skip-member" , skip )
138277 app .add_directive ("moduleautosummary" , ModuleAutoSummary )
139278 app .add_js_file ("https://buttons.github.io/buttons.js" , ** {"async" : "async" })
279+ # Connect model overview generator to builder-inited event
280+ app .connect ("builder-inited" , _make_estimator_overview )
140281
141282
142283# extension configuration
0 commit comments