From 12c35c47a076abc499b92c69ca542cdfe1eda87c Mon Sep 17 00:00:00 2001 From: ClF3 Date: Wed, 16 Oct 2024 13:00:46 +0800 Subject: [PATCH] move the whole dataset to gpu --- CIFAR10_playground.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/CIFAR10_playground.py b/CIFAR10_playground.py index b189ac3..ef8c6f1 100755 --- a/CIFAR10_playground.py +++ b/CIFAR10_playground.py @@ -8,7 +8,7 @@ import torch.optim as optim import torchvision.transforms as transforms import matplotlib.pyplot as plt - +device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ################### 数据集初始化与读入 ################### train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), @@ -20,6 +20,10 @@ train_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=True,download=F test_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=False,download=False,transform=transforms.ToTensor()) train_loader = torch.utils.data.DataLoader(train_dset, batch_size=128, shuffle=True, num_workers=0) test_loader = torch.utils.data.DataLoader(test_dset, batch_size=128, shuffle=False, num_workers=0) +train_dset.train_data.to(device) +train_dset.train_labels.to(device) +test_dset.test_data.to(device) +test_dset.test_labels.to(device) ####################################################### @@ -223,7 +227,7 @@ class BnDeepNet(nn.Module): # model = DeepNet('tanh') model = BnDeepNet('relu') -device=torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) criterion = nn.CrossEntropyLoss() @@ -248,7 +252,7 @@ for epoch in range(n_epochs): valid_loss = 0.0 model.train() for idx,(img,label) in tqdm(enumerate(train_loader)): - img, label=img.to(device), label.to(device) + # img, label=img.to(device), label.to(device) optimizer.zero_grad() output = model(img) loss = criterion(output,label) @@ -260,7 +264,7 @@ for epoch in range(n_epochs): correct = 0 total = 0 for idx,(img,label) in tqdm(enumerate(test_loader)): - img, label=img.to(device), label.to(device) + # img, label=img.to(device), label.to(device) output = model(img) loss = criterion(output, label) valid_loss += loss.item() * img.shape[0]