We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2b3a909 commit 48a26deCopy full SHA for 48a26de
src/model_constructor/convmixer.py
@@ -5,16 +5,16 @@
5
from collections import OrderedDict
6
from typing import Callable, List, Optional, Union
7
8
+import torch
9
import torch.nn as nn
-from torch import TensorType
10
11
12
class Residual(nn.Module):
13
- def __init__(self, fn: Callable[[TensorType], TensorType]):
+ def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
14
super().__init__()
15
self.fn = fn
16
17
- def forward(self, x: TensorType) -> TensorType:
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
18
return self.fn(x) + x
19
20
0 commit comments