move the whole dataset to gpu
parent 768cdb1acc
commit 12c35c47a0
@@ -8,7 +8,7 @@ import torch.optim as optim
 import torchvision.transforms as transforms
 import matplotlib.pyplot as plt
 
-
+device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ################### Dataset initialization and loading ###################
 train_transform = transforms.Compose([
     transforms.RandomHorizontalFlip(),
@@ -20,6 +20,10 @@ train_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=True,download=F
 test_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=False,download=False,transform=transforms.ToTensor())
 train_loader = torch.utils.data.DataLoader(train_dset, batch_size=128, shuffle=True, num_workers=0)
 test_loader = torch.utils.data.DataLoader(test_dset, batch_size=128, shuffle=False, num_workers=0)
+train_dset.train_data.to(device)
+train_dset.train_labels.to(device)
+test_dset.test_data.to(device)
+test_dset.test_labels.to(device)
 #######################################################
 
 
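A note on the four added lines above: as far as I can tell, `Tensor.to(device)` returns a new tensor rather than moving the original in place, and depending on the torchvision version `CIFAR10.train_data` may be a NumPy array (recent releases expose `data` / `targets` instead), so these calls on their own may not leave anything on the GPU. Below is a minimal sketch of one way to keep the whole training set GPU-resident; `to_gpu_tensors`, `train_x`, and `train_y` are hypothetical names, not part of this commit.

import torch
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical helper (not part of this commit): build GPU-resident tensors
# from a CIFAR10 dataset. Assumes recent torchvision, where .data is a uint8
# NumPy array of shape (N, 32, 32, 3) and .targets is a list of ints.
def to_gpu_tensors(dset):
    imgs = torch.from_numpy(dset.data).permute(0, 3, 1, 2).float().div_(255.0)
    labels = torch.tensor(dset.targets, dtype=torch.long)
    # .to() returns a copy; the returned tensors are what live on the GPU.
    return imgs.to(device), labels.to(device)

train_dset = torchvision.datasets.CIFAR10(root='./CIFAR10', train=True, download=False)
train_x, train_y = to_gpu_tensors(train_dset)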
@@ -223,7 +227,7 @@ class BnDeepNet(nn.Module):
 # model = DeepNet('tanh')
 model = BnDeepNet('relu')
 
-device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 model.to(device)
 criterion = nn.CrossEntropyLoss()
 
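One asymmetry this hunk relies on (my reading of the PyTorch API, not something the commit states): `nn.Module.to(device)` moves the module's parameters in place, so `model.to(device)` works without reassignment, whereas `Tensor.to(device)` only returns a copy. A small illustration:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2)
model.to(device)          # modules move in place; parameters now live on `device`

x = torch.randn(8, 4)
x = x.to(device)          # tensors do not; keep the returned copy
print(model(x).device)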
@@ -248,7 +252,7 @@ for epoch in range(n_epochs):
     valid_loss = 0.0
     model.train()
    for idx,(img,label) in tqdm(enumerate(train_loader)):
-        img, label=img.to(device), label.to(device)
+        # img, label=img.to(device), label.to(device)
         optimizer.zero_grad()
         output = model(img)
         loss = criterion(output,label)
@@ -260,7 +264,7 @@ for epoch in range(n_epochs):
     correct = 0
     total = 0
     for idx,(img,label) in tqdm(enumerate(test_loader)):
-        img, label=img.to(device), label.to(device)
+        # img, label=img.to(device), label.to(device)
         output = model(img)
         loss = criterion(output, label)
         valid_loss += loss.item() * img.shape[0]
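With the per-batch `img.to(device)` / `label.to(device)` calls commented out in the two hunks above, batches produced by the existing `DataLoader` are still CPU tensors (the `ToTensor` transform builds them on the CPU), so the GPU-resident model would likely raise a device-mismatch error when CUDA is available. If the dataset really is pre-loaded as GPU tensors, one option is to iterate over those tensors directly instead of going through the `DataLoader`. A sketch under that assumption, reusing the hypothetical `train_x` / `train_y` from the earlier note:

import torch

# Hypothetical batch iterator (not part of this commit), assuming train_x and
# train_y are GPU-resident tensors built as in the earlier sketch.
def gpu_batches(x, y, batch_size=128, shuffle=True):
    n = x.shape[0]
    order = torch.randperm(n, device=x.device) if shuffle else torch.arange(n, device=x.device)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]   # slices stay on the GPU; no per-batch .to(device)

# Usage in place of the DataLoader loop:
# for img, label in gpu_batches(train_x, train_y):
#     optimizer.zero_grad()
#     output = model(img)
#     loss = criterion(output, label)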