Skip to content

Commit 7304e07

Browse files
committed
fix: 1938 generate models overview from registry
Signed-off-by: Pinaka07 <anandsingh.as1996@gmail.com>
1 parent aef2d2b commit 7304e07

File tree

5 files changed

+400
-12
lines changed

5 files changed

+400
-12
lines changed

docs/source/_static/custom.css

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,81 @@ span.highlighted {
5656
.highlight > pre > .s2 {
5757
color: #647db6 !important;
5858
}
59+
60+
/* Model overview table styling */
61+
.model-overview-table {
62+
width: 100%;
63+
font-size: 0.9em;
64+
border-collapse: collapse;
65+
margin: 20px 0;
66+
}
67+
68+
.model-overview-table th {
69+
background-color: #f0f0f0;
70+
font-weight: bold;
71+
padding: 12px;
72+
text-align: left;
73+
border-bottom: 2px solid #ddd;
74+
}
75+
76+
.model-overview-table td {
77+
padding: 10px 12px;
78+
border-bottom: 1px solid #ddd;
79+
}
80+
81+
.model-overview-table tr:hover {
82+
background-color: #f9f9f9;
83+
}
84+
85+
.model-overview-table a {
86+
color: #647db6;
87+
text-decoration: none;
88+
}
89+
90+
.model-overview-table a:hover {
91+
text-decoration: underline;
92+
color: #ee4c2c;
93+
}
94+
95+
#model-overview-container {
96+
margin: 20px 0;
97+
}
98+
99+
#model-filters {
100+
margin-bottom: 15px;
101+
padding: 10px;
102+
background-color: #f9f9f9;
103+
border-radius: 4px;
104+
}
105+
106+
#model-filters label {
107+
margin-right: 15px;
108+
font-weight: 500;
109+
}
110+
111+
#model-filters select {
112+
margin-left: 5px;
113+
padding: 5px 10px;
114+
border: 1px solid #ddd;
115+
border-radius: 4px;
116+
background-color: white;
117+
}
118+
119+
/* DataTables styling overrides */
120+
#model-table_wrapper {
121+
margin-top: 20px;
122+
}
123+
124+
#model-table_wrapper .dataTables_filter input {
125+
margin-left: 10px;
126+
padding: 5px;
127+
border: 1px solid #ddd;
128+
border-radius: 4px;
129+
}
130+
131+
#model-table_wrapper .dataTables_length select {
132+
padding: 5px;
133+
border: 1px solid #ddd;
134+
border-radius: 4px;
135+
margin: 0 5px;
136+
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/**
2+
* JavaScript for interactive model overview table.
3+
*
4+
* This script loads the model overview data from JSON and creates
5+
* an interactive DataTable with search and filtering capabilities.
6+
*/
7+
8+
$(document).ready(function() {
9+
// Determine the correct path to the JSON file
10+
// In built HTML, the path should be relative to the current page
11+
var jsonPath = '_static/model_overview_db.json';
12+
13+
// Load model data from JSON
14+
$.getJSON(jsonPath, function(data) {
15+
// Initialize DataTable
16+
var table = $('#model-table').DataTable({
17+
data: data,
18+
columns: [
19+
{
20+
data: 'Model Name',
21+
title: 'Model Name',
22+
render: function(data, type, row) {
23+
// If data is already HTML (from pandas), return as-is
24+
if (type === 'display' && data && data.includes('<a')) {
25+
return data;
26+
}
27+
return data || row['Model Name'] || '';
28+
}
29+
},
30+
{ data: 'Type', title: 'Type' },
31+
{ data: 'Covariates', title: 'Covariates' },
32+
{ data: 'Multiple targets', title: 'Multiple targets' },
33+
{ data: 'Uncertainty', title: 'Uncertainty' },
34+
{ data: 'Flexible history', title: 'Flexible history' },
35+
{ data: 'Cold-start', title: 'Cold-start' },
36+
{ data: 'Compute', title: 'Compute' }
37+
],
38+
pageLength: 25,
39+
order: [[0, 'asc']],
40+
responsive: true,
41+
dom: 'lfrtip',
42+
language: {
43+
search: "Search models:",
44+
lengthMenu: "Show _MENU_ models per page",
45+
info: "Showing _START_ to _END_ of _TOTAL_ models",
46+
infoEmpty: "No models found",
47+
infoFiltered: "(filtered from _MAX_ total models)"
48+
}
49+
});
50+
51+
// Filter by type
52+
$('#type-filter').on('change', function() {
53+
var val = $(this).val();
54+
table.column(1).search(val).draw();
55+
});
56+
57+
// Filter by capability
58+
$('#capability-filter').on('change', function() {
59+
var val = $(this).val();
60+
if (val === '') {
61+
// Clear all capability filters
62+
table.columns([2, 3, 4, 5, 6]).search('').draw();
63+
} else {
64+
// Map capability name to column index
65+
var capabilityMap = {
66+
'Covariates': 2,
67+
'Multiple targets': 3,
68+
'Uncertainty': 4,
69+
'Flexible history': 5,
70+
'Cold-start': 6
71+
};
72+
73+
var colIdx = capabilityMap[val];
74+
if (colIdx !== undefined) {
75+
// Clear all capability columns first
76+
table.columns([2, 3, 4, 5, 6]).search('');
77+
// Then search in the specific column
78+
table.column(colIdx).search('✓').draw();
79+
} else {
80+
// If capability not found, search in all columns
81+
table.search(val).draw();
82+
}
83+
}
84+
});
85+
86+
// Clear filters when "All" is selected
87+
$('#type-filter, #capability-filter').on('change', function() {
88+
if ($(this).val() === '') {
89+
if ($(this).attr('id') === 'type-filter') {
90+
table.column(1).search('').draw();
91+
}
92+
}
93+
});
94+
}).fail(function(jqXHR, textStatus, errorThrown) {
95+
// Handle error loading JSON
96+
console.error('Error loading model overview data:', textStatus, errorThrown);
97+
$('#model-table').html(
98+
'<tr><td colspan="8" style="text-align: center; padding: 20px;">' +
99+
'Error loading model overview data. Please ensure the documentation was built correctly.' +
100+
'</td></tr>'
101+
);
102+
});
103+
});
104+

docs/source/conf.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,152 @@ def get_items(self, names):
132132
return new_items
133133

134134

135+
def _make_estimator_overview(app):
136+
"""Make estimator/model overview table.
137+
138+
This function generates a dynamic table of all models in pytorch-forecasting
139+
by querying the registry system. The table is written as HTML and JSON files
140+
for inclusion in the documentation.
141+
"""
142+
try:
143+
import pandas as pd
144+
from pytorch_forecasting._registry import all_objects
145+
146+
# Base classes to exclude from the overview
147+
BASE_CLASSES = {
148+
"BaseModel",
149+
"BaseModelWithCovariates",
150+
"AutoRegressiveBaseModel",
151+
"AutoRegressiveBaseModelWithCovariates",
152+
"_BaseObject",
153+
"_BasePtForecaster",
154+
"_BasePtForecasterV2",
155+
"_BasePtForecaster_Common",
156+
}
157+
158+
# Get all objects from registry
159+
all_objs = all_objects(return_names=True, suppress_import_stdout=True)
160+
161+
records = []
162+
163+
for obj_name, obj_class in all_objs:
164+
# Skip base classes
165+
if obj_name in BASE_CLASSES:
166+
continue
167+
168+
# Skip if it's not a model class (check if it has get_class_tag method)
169+
if not hasattr(obj_class, "get_class_tag"):
170+
continue
171+
172+
try:
173+
# Get model name from tags or use class name
174+
model_name = obj_class.get_class_tag("info:name", obj_name)
175+
176+
# Get authors
177+
authors = obj_class.get_class_tag("authors", None)
178+
if authors is None:
179+
authors = "pytorch-forecasting developers"
180+
elif isinstance(authors, list):
181+
authors = ", ".join(authors)
182+
183+
# Get object type
184+
object_type = obj_class.get_class_tag("object_type", "model")
185+
if isinstance(object_type, list):
186+
object_type = ", ".join(object_type)
187+
188+
# Get capabilities
189+
has_exogenous = obj_class.get_class_tag("capability:exogenous", False)
190+
has_multivariate = obj_class.get_class_tag("capability:multivariate", False)
191+
has_pred_int = obj_class.get_class_tag("capability:pred_int", False)
192+
has_flexible_history = obj_class.get_class_tag("capability:flexible_history_length", False)
193+
has_cold_start = obj_class.get_class_tag("capability:cold_start", False)
194+
195+
# Get compute requirement
196+
compute = obj_class.get_class_tag("info:compute", None)
197+
198+
# Get module path for documentation link
199+
module_path = obj_class.__module__
200+
class_name = obj_class.__name__
201+
202+
# Construct documentation link
203+
# Convert module path to API documentation path
204+
api_path = module_path.replace(".", "/")
205+
doc_link = f"api/{api_path}.html#{module_path}.{class_name}"
206+
207+
# Create model name with link
208+
model_name_link = f'<a href="{doc_link}">{model_name}</a>'
209+
210+
# Build capabilities string
211+
capabilities = []
212+
if has_exogenous:
213+
capabilities.append("Covariates")
214+
if has_multivariate:
215+
capabilities.append("Multiple targets")
216+
if has_pred_int:
217+
capabilities.append("Uncertainty")
218+
if has_flexible_history:
219+
capabilities.append("Flexible history")
220+
if has_cold_start:
221+
capabilities.append("Cold-start")
222+
223+
capabilities_str = ", ".join(capabilities) if capabilities else ""
224+
225+
records.append({
226+
"Model Name": model_name_link,
227+
"Type": object_type,
228+
"Authors": authors,
229+
"Covariates": "✓" if has_exogenous else "",
230+
"Multiple targets": "✓" if has_multivariate else "",
231+
"Uncertainty": "✓" if has_pred_int else "",
232+
"Flexible history": "✓" if has_flexible_history else "",
233+
"Cold-start": "✓" if has_cold_start else "",
234+
"Compute": str(compute) if compute is not None else "",
235+
"Capabilities": capabilities_str,
236+
"Module": module_path,
237+
})
238+
except Exception as e:
239+
# Skip objects that can't be processed
240+
print(f"Warning: Could not process {obj_name}: {e}")
241+
continue
242+
243+
if not records:
244+
print("Warning: No models found in registry")
245+
return
246+
247+
# Create DataFrame
248+
df = pd.DataFrame(records)
249+
250+
# Ensure _static directory exists
251+
static_dir = SOURCE_PATH.joinpath("_static")
252+
static_dir.mkdir(exist_ok=True)
253+
254+
# Write HTML table
255+
html_file = static_dir.joinpath("model_overview_table.html")
256+
html_content = df[["Model Name", "Type", "Covariates", "Multiple targets",
257+
"Uncertainty", "Flexible history", "Cold-start", "Compute"]].to_html(
258+
classes="model-overview-table", index=False, border=0, escape=False
259+
)
260+
html_file.write_text(html_content, encoding="utf-8")
261+
print(f"Generated model overview table: {html_file}")
262+
263+
# Write JSON database for interactive filtering (optional)
264+
json_file = static_dir.joinpath("model_overview_db.json")
265+
df.to_json(json_file, orient="records", indent=2)
266+
print(f"Generated model overview JSON: {json_file}")
267+
268+
except ImportError as e:
269+
print(f"Warning: Could not generate model overview (missing dependency): {e}")
270+
except Exception as e:
271+
print(f"Warning: Error generating model overview: {e}")
272+
273+
135274
def setup(app: Sphinx):
136275
app.add_css_file("custom.css")
137276
app.connect("autodoc-skip-member", skip)
138277
app.add_directive("moduleautosummary", ModuleAutoSummary)
139278
app.add_js_file("https://buttons.github.io/buttons.js", **{"async": "async"})
279+
# Connect model overview generator to builder-inited event
280+
app.connect("builder-inited", _make_estimator_overview)
140281

141282

142283
# extension configuration

0 commit comments

Comments
 (0)