diff --git a/CIFAR10_playground.py b/CIFAR10_playground.py index 404637e..4f01afa 100755 --- a/CIFAR10_playground.py +++ b/CIFAR10_playground.py @@ -220,8 +220,8 @@ class BnDeepNet(nn.Module): ################### 训练前准备 ################### # model = Net('tanh') # model = BnNet('relu') -model = DeepNet('tanh') -# model = BnDeepNet('relu') +# model = DeepNet('tanh') +model = BnDeepNet('relu') device=torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device)