From bb7b0cd7cd5ce751fa025018cda1b4a451fbee29 Mon Sep 17 00:00:00 2001 From: ClF3 Date: Wed, 16 Oct 2024 13:44:10 +0800 Subject: [PATCH] fix --- CIFAR10_playground.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/CIFAR10_playground.py b/CIFAR10_playground.py index c725731..171140a 100755 --- a/CIFAR10_playground.py +++ b/CIFAR10_playground.py @@ -20,10 +20,6 @@ train_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=True,download=F test_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=False,download=False,transform=transforms.ToTensor()) train_loader = torch.utils.data.DataLoader(train_dset, batch_size=128, shuffle=True, num_workers=0) test_loader = torch.utils.data.DataLoader(test_dset, batch_size=128, shuffle=False, num_workers=0) -train_dset.data.to(device) -train_dset.target.to(device) -test_dset.data.to(device) -test_dset.target.to(device) ####################################################### @@ -252,7 +248,7 @@ for epoch in range(n_epochs): valid_loss = 0.0 model.train() for idx,(img,label) in tqdm(enumerate(train_loader)): - # img, label=img.to(device), label.to(device) + img, label=img.to(device), label.to(device) optimizer.zero_grad() output = model(img) loss = criterion(output,label) @@ -264,7 +260,7 @@ for epoch in range(n_epochs): correct = 0 total = 0 for idx,(img,label) in tqdm(enumerate(test_loader)): - # img, label=img.to(device), label.to(device) + img, label=img.to(device), label.to(device) output = model(img) loss = criterion(output, label) valid_loss += loss.item() * img.shape[0]