From 768cdb1acc2861723aa075e2bc3dad7f3714d624 Mon Sep 17 00:00:00 2001 From: ClF3 Date: Wed, 16 Oct 2024 10:45:03 +0800 Subject: [PATCH] bug fix --- CIFAR10_playground.py | 1 + 1 file changed, 1 insertion(+) diff --git a/CIFAR10_playground.py b/CIFAR10_playground.py index 4f01afa..b189ac3 100755 --- a/CIFAR10_playground.py +++ b/CIFAR10_playground.py @@ -260,6 +260,7 @@ for epoch in range(n_epochs): correct = 0 total = 0 for idx,(img,label) in tqdm(enumerate(test_loader)): + img, label=img.to(device), label.to(device) output = model(img) loss = criterion(output, label) valid_loss += loss.item() * img.shape[0]