From 5721d4d8ccc795da8fb4fce427583c57815af203 Mon Sep 17 00:00:00 2001 From: "huy.nguyen" Date: Sun, 23 Nov 2025 21:32:21 +0100 Subject: [PATCH] fix cdtype default to float32 --- mnist_conv/vit_mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mnist_conv/vit_mnist.py b/mnist_conv/vit_mnist.py index 41561c1..1d891f1 100644 --- a/mnist_conv/vit_mnist.py +++ b/mnist_conv/vit_mnist.py @@ -233,7 +233,7 @@ def train(): valid_ratio = 0.1 batch_size = 6 epochs = 10 - cdtype = torch.complex64 + cdtype = torch.float32 # Dataloading train_valid_dataset = torchvision.datasets.MNIST( @@ -462,4 +462,4 @@ def lightning_train(version: int): if __name__ == "__main__": - lightning_train(0) \ No newline at end of file + train() \ No newline at end of file