move the whole dataset to gpu
This commit is contained in:
		
							parent
							
								
									768cdb1acc
								
							
						
					
					
						commit
						12c35c47a0
					
				| 
						 | 
				
			
			@ -8,7 +8,7 @@ import torch.optim as optim
 | 
			
		|||
import torchvision.transforms as transforms
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
			
		||||
################### 数据集初始化与读入 ###################
 | 
			
		||||
train_transform = transforms.Compose([
 | 
			
		||||
    transforms.RandomHorizontalFlip(),
 | 
			
		||||
| 
						 | 
				
			
			@ -20,6 +20,10 @@ train_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=True,download=F
 | 
			
		|||
test_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=False,download=False,transform=transforms.ToTensor())
 | 
			
		||||
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=128, shuffle=True, num_workers=0)
 | 
			
		||||
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=128, shuffle=False, num_workers=0)
 | 
			
		||||
train_dset.train_data.to(device)
 | 
			
		||||
train_dset.train_labels.to(device)
 | 
			
		||||
test_dset.test_data.to(device)
 | 
			
		||||
test_dset.test_labels.to(device)
 | 
			
		||||
#######################################################
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -223,7 +227,7 @@ class BnDeepNet(nn.Module):
 | 
			
		|||
# model = DeepNet('tanh')
 | 
			
		||||
model = BnDeepNet('relu')
 | 
			
		||||
 | 
			
		||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
			
		||||
 | 
			
		||||
model.to(device)
 | 
			
		||||
criterion = nn.CrossEntropyLoss()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -248,7 +252,7 @@ for epoch in range(n_epochs):
 | 
			
		|||
    valid_loss = 0.0
 | 
			
		||||
    model.train()
 | 
			
		||||
    for idx,(img,label) in tqdm(enumerate(train_loader)):
 | 
			
		||||
        img, label=img.to(device), label.to(device)
 | 
			
		||||
        # img, label=img.to(device), label.to(device)
 | 
			
		||||
        optimizer.zero_grad()
 | 
			
		||||
        output = model(img)
 | 
			
		||||
        loss = criterion(output,label)
 | 
			
		||||
| 
						 | 
				
			
			@ -260,7 +264,7 @@ for epoch in range(n_epochs):
 | 
			
		|||
    correct = 0
 | 
			
		||||
    total = 0
 | 
			
		||||
    for idx,(img,label) in tqdm(enumerate(test_loader)):
 | 
			
		||||
        img, label=img.to(device), label.to(device)
 | 
			
		||||
        # img, label=img.to(device), label.to(device)
 | 
			
		||||
        output = model(img)
 | 
			
		||||
        loss = criterion(output, label)
 | 
			
		||||
        valid_loss += loss.item() * img.shape[0]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue