This commit is contained in:
ClF3 2024-10-16 13:44:10 +08:00
parent 3376c8a0d5
commit bb7b0cd7cd
1 changed files with 2 additions and 6 deletions

View File

@ -20,10 +20,6 @@ train_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=True,download=F
test_dset = torchvision.datasets.CIFAR10(root='./CIFAR10',train=False,download=False,transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=128, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=128, shuffle=False, num_workers=0)
train_dset.data.to(device)
train_dset.target.to(device)
test_dset.data.to(device)
test_dset.target.to(device)
#######################################################
@ -252,7 +248,7 @@ for epoch in range(n_epochs):
valid_loss = 0.0
model.train()
for idx,(img,label) in tqdm(enumerate(train_loader)):
# img, label=img.to(device), label.to(device)
img, label=img.to(device), label.to(device)
optimizer.zero_grad()
output = model(img)
loss = criterion(output,label)
@ -264,7 +260,7 @@ for epoch in range(n_epochs):
correct = 0
total = 0
for idx,(img,label) in tqdm(enumerate(test_loader)):
# img, label=img.to(device), label.to(device)
img, label=img.to(device), label.to(device)
output = model(img)
loss = criterion(output, label)
valid_loss += loss.item() * img.shape[0]