move the whole dataset to gpu

This commit is contained in:
ClF3 2024-10-16 13:00:46 +08:00
parent 768cdb1acc
commit 12c35c47a0
1 changed files with 8 additions and 4 deletions

View File

@ -8,7 +8,7 @@ import torch.optim as optim
import torchvision.transforms as transforms import torchvision.transforms as transforms
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
################### 数据集初始化与读入 ################### ################### 数据集初始化与读入 ###################
train_transform = transforms.Compose([ train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
@ -20,6 +20,10 @@ train_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=True,download=F
test_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=False,download=False,transform=transforms.ToTensor()) test_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=False,download=False,transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=128, shuffle=True, num_workers=0) train_loader = torch.utils.data.DataLoader(train_dset, batch_size=128, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=128, shuffle=False, num_workers=0) test_loader = torch.utils.data.DataLoader(test_dset, batch_size=128, shuffle=False, num_workers=0)
train_dset.train_data.to(device)
train_dset.train_labels.to(device)
test_dset.test_data.to(device)
test_dset.test_labels.to(device)
####################################################### #######################################################
@ -223,7 +227,7 @@ class BnDeepNet(nn.Module):
# model = DeepNet('tanh') # model = DeepNet('tanh')
model = BnDeepNet('relu') model = BnDeepNet('relu')
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device) model.to(device)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
@ -248,7 +252,7 @@ for epoch in range(n_epochs):
valid_loss = 0.0 valid_loss = 0.0
model.train() model.train()
for idx,(img,label) in tqdm(enumerate(train_loader)): for idx,(img,label) in tqdm(enumerate(train_loader)):
img, label=img.to(device), label.to(device) # img, label=img.to(device), label.to(device)
optimizer.zero_grad() optimizer.zero_grad()
output = model(img) output = model(img)
loss = criterion(output,label) loss = criterion(output,label)
@ -260,7 +264,7 @@ for epoch in range(n_epochs):
correct = 0 correct = 0
total = 0 total = 0
for idx,(img,label) in tqdm(enumerate(test_loader)): for idx,(img,label) in tqdm(enumerate(test_loader)):
img, label=img.to(device), label.to(device) # img, label=img.to(device), label.to(device)
output = model(img) output = model(img)
loss = criterion(output, label) loss = criterion(output, label)
valid_loss += loss.item() * img.shape[0] valid_loss += loss.item() * img.shape[0]