diff --git a/mnist_conv/vit_mnist.py b/mnist_conv/vit_mnist.py index 41561c1..1d891f1 100644 --- a/mnist_conv/vit_mnist.py +++ b/mnist_conv/vit_mnist.py @@ -233,7 +233,7 @@ def train(): valid_ratio = 0.1 batch_size = 6 epochs = 10 - cdtype = torch.complex64 + cdtype = torch.float32 # Dataloading train_valid_dataset = torchvision.datasets.MNIST( @@ -462,4 +462,4 @@ def lightning_train(version: int): if __name__ == "__main__": - lightning_train(0) \ No newline at end of file + train() \ No newline at end of file