本文共 3470 字,大约阅读时间需要 11 分钟。
# PyTorch: loading MNIST with torchvision.
# This cell reproduces the error discussed below: the Normalize transform is
# given three-channel statistics, but MNIST images are single-channel.

# Jupyter-only magics; uncomment them when running inside a notebook.
# %matplotlib inline
# %config InlineBackend.figure_format = 'retina'

# Import things like usual
import numpy as np
import torch
import helper  # NOTE(review): course-local module, not used in this snippet
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Define a transform to normalize the data.
# BUG (kept on purpose to reproduce the error): after ToTensor() an MNIST image
# is a [1, 28, 28] tensor, but three mean/std values are supplied, so Normalize
# broadcasts against [3, 28, 28] and raises the RuntimeError shown below.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ])

# Download and load the training data
trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data.
# Evaluation does not need a random order, so the test loader is not shuffled.
testset = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

dataiter = iter(trainloader)
# Use the built-in next(): newer PyTorch releases dropped the Python-2-style
# .next() alias on DataLoader iterators, while next(it) works everywhere.
images, labels = next(dataiter)

# The error is as follows (captured with the original `dataiter.next()` call):
#
# RuntimeError                              Traceback (most recent call last)
# in
#       1 dataiter = iter(trainloader)
# ----> 2 images, labels = dataiter.next()
#
# C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
#     558         if self.num_workers == 0:  # same-process loading
#     559             indices = next(self.sample_iter)  # may raise StopIteration
# --> 560             batch = self.collate_fn([self.dataset[i] for i in indices])
#     561             if self.pin_memory:
#     562                 batch = _utils.pin_memory.pin_memory_batch(batch)
#
# C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in (.0)
#     558         if self.num_workers == 0:  # same-process loading
#     559             indices = next(self.sample_iter)  # may raise StopIteration
# --> 560             batch = self.collate_fn([self.dataset[i] for i in indices])
#     561             if self.pin_memory:
#     562                 batch = _utils.pin_memory.pin_memory_batch(batch)
#
# C:\ProgramData\Anaconda3\lib\site-packages\torchvision\datasets\mnist.py in __getitem__(self, index)
#      93
#      94         if self.transform is not None:
# ---> 95             img = self.transform(img)
#      96
#      97         if self.target_transform is not None:
#
# C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, img)
#      59     def __call__(self, img):
#      60         for t in self.transforms:
# ---> 61             img = t(img)
#      62         return img
#      63
#
# C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, tensor)
#     162             Tensor: Normalized Tensor image.
#     163         """
# --> 164         return F.normalize(tensor, self.mean, self.std, self.inplace)
#     165
#     166     def __repr__(self):
#
# C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\functional.py in normalize(tensor, mean, std, inplace)
#     206     mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device)
#     207     std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device)
# --> 208     tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
#     209     return tensor
#     210
#
# RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
解决办法:
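将三通道的标准化改为单通道的。MNIST 是灰度图像，经过 ToTensor() 之后每张图片的形状是 [1, 28, 28]；而原代码的 Normalize 为 3 个通道分别提供了 mean 和 std，广播时要求形状为 [3, 28, 28]，因此报错。Normalize 的 mean 和 std 的元素个数必须与图片的通道数一致，修改如下：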
# Original transform: three mean/std values, one per RGB channel. MNIST
# images are single-channel ([1, 28, 28] after ToTensor), so this fails:
# transform = transforms.Compose([transforms.ToTensor(),
#                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#                                 ])

# Fix: give Normalize exactly one mean and one std value, matching the single
# channel. (x - 0.5) / 0.5 maps ToTensor's [0, 1] output range to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,)),
                                ])
转载地址:http://lpfti.baihongyu.com/