Skip to content

Commit b4ec4cb

Browse files
authored
Update common used toy linear model (#3275)
* build common used toy linear model Co-authored-by: Jerry Zhang <jerryzh168@gmail.com> * update model to use direct input * revert unit test skip
1 parent 9e93ab1 commit b4ec4cb

File tree

2 files changed

+66
-20
lines changed

2 files changed

+66
-20
lines changed

test/sparsity/test_fast_sparse_training.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,10 @@
1515
swap_linear_with_semi_sparse_linear,
1616
swap_semi_sparse_linear_with_linear,
1717
)
18+
from torchao.testing.model_architectures import ToyTwoLinearModel
1819
from torchao.utils import is_fbcode
1920

2021

21-
class ToyModel(nn.Module):
22-
def __init__(self):
23-
super().__init__()
24-
self.linear1 = nn.Linear(128, 256, bias=False)
25-
self.linear2 = nn.Linear(256, 128, bias=False)
26-
27-
def forward(self, x):
28-
x = self.linear1(x)
29-
x = torch.nn.functional.relu(x)
30-
x = self.linear2(x)
31-
return x
32-
33-
3422
class TestRuntimeSemiStructuredSparsity(TestCase):
3523
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
3624
@unittest.skipIf(is_fbcode(), "broken in fbcode")
@@ -41,7 +29,7 @@ def test_runtime_weight_sparsification(self):
4129

4230
input = torch.rand((128, 128)).half().cuda()
4331
grad = torch.rand((128, 128)).half().cuda()
44-
model = ToyModel().half().cuda()
32+
model = ToyTwoLinearModel(128, 256, 128, device="cuda", dtype=torch.float16)
4533
model_c = copy.deepcopy(model)
4634

4735
for name, mod in model.named_modules():
@@ -89,7 +77,7 @@ def test_runtime_weight_sparsification_compile(self):
8977

9078
input = torch.rand((128, 128)).half().cuda()
9179
grad = torch.rand((128, 128)).half().cuda()
92-
model = ToyModel().half().cuda()
80+
model = ToyTwoLinearModel(128, 256, 128, device="cuda", dtype=torch.float16)
9381
model_c = copy.deepcopy(model)
9482

9583
for name, mod in model.named_modules():

torchao/testing/model_architectures.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,72 @@
1111
import torch.nn.functional as F
1212

1313

14+
class ToySingleLinearModel(torch.nn.Module):
15+
def __init__(
16+
self,
17+
input_dim,
18+
output_dim,
19+
dtype,
20+
device,
21+
has_bias=False,
22+
):
23+
super().__init__()
24+
self.dtype = dtype
25+
self.device = device
26+
self.linear1 = torch.nn.Linear(
27+
input_dim, output_dim, bias=has_bias, dtype=dtype, device=device
28+
)
29+
30+
def example_inputs(self, batch_size=1):
31+
return (
32+
torch.randn(
33+
batch_size,
34+
self.linear1.in_features,
35+
dtype=self.dtype,
36+
device=self.device,
37+
),
38+
)
39+
40+
def forward(self, x):
41+
x = self.linear1(x)
42+
return x
43+
44+
1445
# TODO: Refactor torchao and tests to use these models
15-
class ToyLinearModel(torch.nn.Module):
16-
def __init__(self, k=64, n=32, dtype=torch.bfloat16):
46+
class ToyTwoLinearModel(torch.nn.Module):
47+
def __init__(
48+
self,
49+
input_dim,
50+
hidden_dim,
51+
output_dim,
52+
dtype,
53+
device,
54+
has_bias=False,
55+
):
1756
super().__init__()
18-
self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype)
57+
self.dtype = dtype
58+
self.device = device
59+
self.linear1 = torch.nn.Linear(
60+
input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device
61+
)
62+
self.linear2 = torch.nn.Linear(
63+
hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device
64+
)
65+
66+
# Note: Tiny-GEMM kernel only uses BF16 inputs
67+
def example_inputs(self, batch_size=1):
68+
return (
69+
torch.randn(
70+
batch_size,
71+
self.linear1.in_features,
72+
dtype=self.dtype,
73+
device=self.device,
74+
),
75+
)
1976

2077
def forward(self, x):
2178
x = self.linear1(x)
79+
x = self.linear2(x)
2280
return x
2381

2482

@@ -179,8 +237,8 @@ def create_model_and_input_data(
179237
m, k, n (int): dimensions of the model and input data
180238
"""
181239
if model_type == "linear":
182-
model = ToyLinearModel(k, n, high_precision_dtype).to(device)
183-
input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
240+
model = ToySingleLinearModel(k, n, device=device, dtype=high_precision_dtype)
241+
input_data = model.example_inputs(batch_size=m)[0]
184242
elif "ln_linear" in model_type:
185243
# Extract activation type from model_type string
186244
match = re.search(r"ln_linear_?(\w+)?", model_type)

0 commit comments

Comments
 (0)