Commit 31c1132
PyTorch geometric quantization support (#494)
## What does this PR do?

**Type of change:** New feature

**Overview:** Support quantization of PyTorch Geometric (see the usage sketch below).

## Testing

`python -m pytest tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py -v`

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Additional Information

- #298

Signed-off-by: Riyad Islam <rislam@nvidia.com>
1 parent 7d0f7a9 commit 31c1132
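Since the PR description leaves the usage snippet out, here is a minimal sketch of how the new support is intended to be used, assembled from the plugin docstring and the unit tests in this commit; the `GCNModel`, tensor shapes, and calibration loop below are illustrative assumptions rather than part of the change:

```python
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

# Importing modelopt.torch.quantization loads the PyG plugin automatically
# (via plugins/__init__.py) when torch_geometric is installed.
import modelopt.torch.quantization as mtq


# Toy two-layer GCN; its GCNConv layers use PyG's Linear internally.
class GCNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(16, 32)
        self.conv2 = GCNConv(32, 8)

    def forward(self, x, edge_index):
        return self.conv2(torch.relu(self.conv1(x, edge_index)), edge_index)


def calibrate(model):
    """Run a few representative forward passes so the quantizers can collect statistics."""
    model.eval()
    with torch.no_grad():
        for _ in range(8):
            x = torch.randn(50, 16)
            edge_index = torch.randint(0, 50, (2, 100))
            model(x, edge_index)


model = GCNModel()
quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
```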

File tree

5 files changed: +264, -0 lines changed

CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
The diff adds a new `0.41 (2025-12-xx)` section above the existing `0.40 (2025-12-xx)` entry:

```rst
0.41 (2025-12-xx)
^^^^^^^^^^^^^^^^^

**Deprecations**

**New Features**

- Add support for PyTorch Geometric quantization.
```

modelopt/torch/quantization/plugins/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@
 - :meth:`huggingface<modelopt.torch.quantization.plugins.huggingface>`
 - :meth:`megatron<modelopt.torch.quantization.plugins.megatron>`
 - :meth:`peft<modelopt.torch.quantization.plugins.peft>`
+- :meth:`pytorch_geometric<modelopt.torch.quantization.plugins.pytorch_geometric>`
 - :meth:`transformer_engine<modelopt.torch.quantization.plugins.transformer_engine>`
 """

@@ -57,6 +58,9 @@
 with import_plugin("peft"):
     from .peft import *

+with import_plugin("torch_geometric"):
+    from .pytorch_geometric import *
+
 with import_plugin("transformer_engine"):
     from .transformer_engine import *
```
modelopt/torch/quantization/plugins/pytorch_geometric.py

Lines changed: 67 additions & 0 deletions (new file)

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch Geometric quantization plugin.

This plugin enables quantization support for PyTorch Geometric (PyG) layers by registering
PyG's custom Linear layer with ModelOpt's quantization registry.

Example:
    >>> import modelopt.torch.quantization as mtq
    >>> from torch_geometric.nn import GATConv
    >>>
    >>> # Create a model with PyG layers
    >>> class GATModel(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.gat1 = GATConv(10, 64, heads=4)
    ...         self.gat2 = GATConv(64 * 4, 32, heads=1)
    >>> model = GATModel()
    >>> # PyG layers are now automatically quantizable!
    >>> quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
"""

from torch_geometric.nn.dense.linear import Linear as PyGLinear

from modelopt.torch.quantization.nn.modules.quant_module import (
    QuantLinearConvBase,
    QuantModuleRegistry,
)
from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW


class QuantPyGLinear(QuantLinearConvBase):
    """Quantized version of PyTorch Geometric's Linear layer.

    PyTorch Geometric uses a custom Linear layer that is functionally equivalent to
    torch.nn.Linear but has a different API (in_channels/out_channels instead of
    in_features/out_features). This class enables quantization of PyG Linear layers
    by inheriting from QuantLinearConvBase, which handles all quantization logic.

    The quantization is handled automatically by the base classes:
    - Input quantization: Handled by QuantInputBase.forward()
    - Weight quantization: Handled by QuantLinearConvBase's dynamic weight attribute
    - Output quantization: Handled by QuantInputBase.forward()

    Note:
        Many PyTorch Geometric layers (GCNConv, GATConv, SAGEConv, TransformerConv, etc.)
        internally use PyG Linear layers, so registering this class enables quantization
        for a wide range of graph neural network layers.
    """

    default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW


QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})(QuantPyGLinear)
```
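For illustration only, the same registration pattern could in principle be applied to other linear-like layers. Everything below (`MyLinear`, `QuantMyLinear`, the registry key `"my_linear"`) is hypothetical and not part of this commit; it assumes the layer exposes a `weight` parameter and a linear-style forward, as PyG's Linear does:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelopt.torch.quantization.nn.modules.quant_module import (
    QuantLinearConvBase,
    QuantModuleRegistry,
)
from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW


class MyLinear(nn.Module):
    """Toy stand-in for a third-party linear layer with its own API."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels))
        nn.init.kaiming_uniform_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)


class QuantMyLinear(QuantLinearConvBase):
    """Quantized counterpart; input/weight quantization comes from the base class."""

    default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW


# Same call shape as the PyG registration above, with a hypothetical registry key.
QuantModuleRegistry.register({MyLinear: "my_linear"})(QuantMyLinear)
```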

setup.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -79,6 +79,7 @@
         "pytest-timeout",
         "timm",
         "torchvision",
+        "torch-geometric",
         "tox>4.18",
         "tox-current-env>=0.0.12",
     ],
```
tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py

Lines changed: 185 additions & 0 deletions (new file)

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for PyTorch Geometric quantization plugin."""

import pytest
import torch
import torch.nn as nn
from _test_utils.torch.misc import set_seed
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, TransformerConv

import modelopt.torch.quantization as mtq


class TestPyTorchGeometricPlugin:
    """Test PyTorch Geometric quantization support."""

    @pytest.fixture(autouse=True)
    def setup_seed(self):
        """Set seed before each test function."""
        set_seed()

    @pytest.fixture
    def device(self):
        """Get test device."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def create_graph_data(self, batch_size=2, num_nodes=20, in_channels=16, device="cpu"):
        """Create sample graph data for testing."""
        x = torch.randn(batch_size * num_nodes, in_channels, device=device)
        # Create batch assignment
        batch = torch.cat([torch.full((num_nodes,), i, device=device) for i in range(batch_size)])

        # Create edge indices for each graph
        edge_list = []
        offset = 0
        for _ in range(batch_size):
            # Create random edges within each graph
            src = torch.randint(0, num_nodes, (50,), device=device) + offset
            dst = torch.randint(0, num_nodes, (50,), device=device) + offset
            edge_list.append(torch.stack([src, dst]))
            offset += num_nodes

        edge_index = torch.cat(edge_list, dim=1)
        edge_attr = torch.randn(edge_index.size(1), 32, device=device)

        return x, edge_index, edge_attr, batch

    def test_gat_conv_quantization(self, device):
        """Test GATConv layer quantization."""

        class GATModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.gat1 = GATConv(16, 64, heads=4, edge_dim=32)
                self.gat2 = GATConv(256, 32, heads=1, edge_dim=32)

            def forward(self, x, edge_index, edge_attr):
                x = torch.relu(self.gat1(x, edge_index, edge_attr))
                return self.gat2(x, edge_index, edge_attr)

        model = GATModel().to(device)

        # Calibration function
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                for _ in range(5):
                    x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
                    _ = m(x, edge_index, edge_attr)

        # Quantize model
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Verify quantization
        quantizer_count = sum(
            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
        )
        assert quantizer_count > 0, "No quantizers were inserted"

        # Test forward pass
        x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
        with torch.no_grad():
            output = quantized(x, edge_index, edge_attr)
        assert output is not None

    def test_multiple_layer_types(self, device):
        """Test quantization of multiple PyG layer types."""

        class MultiLayerGNN(nn.Module):
            def __init__(self):
                super().__init__()
                self.gcn = GCNConv(16, 32)
                self.sage = SAGEConv(32, 64)
                self.transformer = TransformerConv(64, 32, heads=2)

            def forward(self, x, edge_index):
                x = torch.relu(self.gcn(x, edge_index))
                x = torch.relu(self.sage(x, edge_index))
                return self.transformer(x, edge_index)

        model = MultiLayerGNN().to(device)

        # Calibration
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                for _ in range(3):
                    x = torch.randn(50, 16, device=device)
                    edge_index = torch.randint(0, 50, (2, 100), device=device)
                    _ = m(x, edge_index)

        # Quantize
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Check that PyG Linear layers were quantized
        pyg_linear_count = 0
        for name, module in model.named_modules():
            if hasattr(module, "lin") and "torch_geometric" in str(type(module.lin)):
                pyg_linear_count += 1

        quantizer_count = sum(
            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
        )

        # Each PyG linear should have at least 2 quantizers (input, weight)
        assert quantizer_count >= pyg_linear_count * 2, (
            f"Expected at least {pyg_linear_count * 2} quantizers, got {quantizer_count}"
        )

    def test_quantization_accuracy(self, device):
        """Test that quantization maintains reasonable accuracy."""
        # Set seed for this test specifically to ensure reproducibility
        set_seed()

        model = GATConv(16, 32, heads=2, edge_dim=16).to(device)

        # Create test data
        x, edge_index, edge_attr, _ = self.create_graph_data(
            batch_size=1, in_channels=16, device=device
        )
        edge_attr = edge_attr[:, :16]  # Match edge_dim

        # Get original output
        model.eval()
        with torch.no_grad():
            original_output = model(x, edge_index, edge_attr)

        # Calibration with multiple samples for more stable quantization
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                # Use multiple calibration samples for better stability
                for _ in range(5):
                    x_cal, edge_index_cal, edge_attr_cal, _ = self.create_graph_data(
                        batch_size=1, in_channels=16, device=device
                    )
                    edge_attr_cal = edge_attr_cal[:, :16]  # Match edge_dim
                    _ = m(x_cal, edge_index_cal, edge_attr_cal)

        # Quantize
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Get quantized output
        with torch.no_grad():
            quantized_output = quantized(x, edge_index, edge_attr)

        # Check relative error
        abs_diff = torch.abs(original_output - quantized_output)
        relative_error = abs_diff / (torch.abs(original_output) + 1e-8)
        mean_relative_error = relative_error.mean().item()

        assert mean_relative_error < 0.1, f"Quantization error too large: {mean_relative_error:.2%}"
```
