From 3deb418e22cb95172e1c4bdd2e25541f97f62dad Mon Sep 17 00:00:00 2001 From: youn17 Date: Sat, 1 Nov 2025 23:48:24 +0900 Subject: [PATCH 1/3] build common used toy linear model Co-authored-by: Jerry Zhang --- test/sparsity/test_fast_sparse_training.py | 20 +------ torchao/testing/model_architectures.py | 66 ++++++++++++++++++++-- 2 files changed, 65 insertions(+), 21 deletions(-) diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 424306f897..7448e8181b 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -15,33 +15,20 @@ swap_linear_with_semi_sparse_linear, swap_semi_sparse_linear_with_linear, ) +from torchao.testing.model_architectures import ToyTwoLinearModel from torchao.utils import is_fbcode -class ToyModel(nn.Module): - def __init__(self): - super().__init__() - self.linear1 = nn.Linear(128, 256, bias=False) - self.linear2 = nn.Linear(256, 128, bias=False) - - def forward(self, x): - x = self.linear1(x) - x = torch.nn.functional.relu(x) - x = self.linear2(x) - return x - - class TestRuntimeSemiStructuredSparsity(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") - @unittest.skip("Temporarily skipping to unpin nightlies") def test_runtime_weight_sparsification(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() - model = ToyModel().half().cuda() + model = ToyTwoLinearModel(128, 256, 128, device="cuda", dtype=torch.float16) model_c = copy.deepcopy(model) for name, mod in model.named_modules(): @@ -82,14 +69,13 @@ def test_runtime_weight_sparsification(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") - @unittest.skip("Temporarily skipping to unpin nightlies") def test_runtime_weight_sparsification_compile(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() - model = ToyModel().half().cuda() + model = ToyTwoLinearModel(128, 256, 128, device="cuda", dtype=torch.float16) model_c = copy.deepcopy(model) for name, mod in model.named_modules(): diff --git a/torchao/testing/model_architectures.py b/torchao/testing/model_architectures.py index 8f41a8464c..fb07bb37b0 100644 --- a/torchao/testing/model_architectures.py +++ b/torchao/testing/model_architectures.py @@ -11,14 +11,72 @@ import torch.nn.functional as F +class ToySingleLinearModel(torch.nn.Module): + def __init__( + self, + input_dim, + output_dim, + dtype, + device, + has_bias=False, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.linear1 = torch.nn.Linear( + input_dim, output_dim, bias=has_bias, dtype=dtype, device=device + ) + + def example_inputs(self, batch_size=1): + return ( + torch.randn( + batch_size, + self.linear1.in_features, + dtype=self.dtype, + device=self.device, + ), + ) + + def forward(self, x): + x = self.linear1(x) + return x + + # TODO: Refactor torchao and tests to use these models -class ToyLinearModel(torch.nn.Module): - def __init__(self, k=64, n=32, dtype=torch.bfloat16): +class ToyTwoLinearModel(torch.nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + dtype, + device, + has_bias=False, + ): super().__init__() - self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype) + self.dtype = dtype + self.device = device + self.linear1 = torch.nn.Linear( + input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device + ) + self.linear2 = torch.nn.Linear( + hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device + ) + + # Note: Tiny-GEMM kernel only uses BF16 inputs + def example_inputs(self, batch_size=1): + return ( + torch.randn( + batch_size, + self.linear1.in_features, + dtype=self.dtype, + device=self.device, + ), + ) def forward(self, x): x = self.linear1(x) + x = self.linear2(x) return x @@ -179,7 +237,7 @@ def create_model_and_input_data( m, k, n (int): dimensions of the model and input data """ if model_type == "linear": - model = ToyLinearModel(k, n, high_precision_dtype).to(device) + model = ToySingleLinearModel(k, n, device=device, dtype=high_precision_dtype) input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) elif "ln_linear" in model_type: # Extract activation type from model_type string From 5e2310a69102edbb951835c4adcd5dd1a3b80fc5 Mon Sep 17 00:00:00 2001 From: youn17 Date: Wed, 5 Nov 2025 15:38:10 +0900 Subject: [PATCH 2/3] update model to use direct input --- test/sparsity/test_fast_sparse_training.py | 1 + torchao/testing/model_architectures.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 7448e8181b..1dd7faf2fa 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -22,6 +22,7 @@ class TestRuntimeSemiStructuredSparsity(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") + @unittest.skip("Temporarily skipping to unpin nightlies") def test_runtime_weight_sparsification(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT diff --git a/torchao/testing/model_architectures.py b/torchao/testing/model_architectures.py index fb07bb37b0..4100a3cd76 100644 --- a/torchao/testing/model_architectures.py +++ b/torchao/testing/model_architectures.py @@ -238,7 +238,7 @@ def create_model_and_input_data( """ if model_type == "linear": model = ToySingleLinearModel(k, n, device=device, dtype=high_precision_dtype) - input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) + input_data = model.example_inputs(batch_size=m)[0] elif "ln_linear" in model_type: # Extract activation type from model_type string match = re.search(r"ln_linear_?(\w+)?", model_type) From 92ba186abba04efd25048ce3dbbaef396a6656fe Mon Sep 17 00:00:00 2001 From: younn17 Date: Wed, 5 Nov 2025 15:42:34 +0900 Subject: [PATCH 3/3] revert unit test skip --- test/sparsity/test_fast_sparse_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 1dd7faf2fa..a9f57bb5a5 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -70,6 +70,7 @@ def test_runtime_weight_sparsification(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") + @unittest.skip("Temporarily skipping to unpin nightlies") def test_runtime_weight_sparsification_compile(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT