Skip to content

Commit 48a26de

Browse files
committed
typing convmixer
1 parent 2b3a909 commit 48a26de

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/model_constructor/convmixer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
from collections import OrderedDict
66
from typing import Callable, List, Optional, Union
77

8+
import torch
89
import torch.nn as nn
9-
from torch import TensorType
1010

1111

1212
class Residual(nn.Module):
13-
def __init__(self, fn: Callable[[TensorType], TensorType]):
13+
def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
1414
super().__init__()
1515
self.fn = fn
1616

17-
def forward(self, x: TensorType) -> TensorType:
17+
def forward(self, x: torch.Tensor) -> torch.Tensor:
1818
return self.fn(x) + x
1919

2020

0 commit comments

Comments
 (0)