博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
RuntimeError: output with shape [1, 28, 28] doesnt match the broadcast shape [3, 28, 28]
阅读量:4145 次
发布时间:2019-05-25

本文共 3470 字,大约阅读时间需要 11 分钟。

pytorch执行MNIST源码# Import things like usual%matplotlib inline%config InlineBackend.figure_format = 'retina'import numpy as npimport torchimport helperimport matplotlib.pyplot as pltfrom torchvision import datasets, transforms# Define a transform to normalize the datatransform = transforms.Compose([transforms.ToTensor(),                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),                             ])# Download and load the training datatrainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)# Download and load the test datatestset = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)dataiter = iter(trainloader)images, labels = dataiter.next()#报错如下RuntimeError                              Traceback (most recent call last)
in
1 dataiter = iter(trainloader)----> 2 images, labels = dataiter.next()C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self) 558 if self.num_workers == 0: # same-process loading 559 indices = next(self.sample_iter) # may raise StopIteration--> 560 batch = self.collate_fn([self.dataset[i] for i in indices]) 561 if self.pin_memory: 562 batch = _utils.pin_memory.pin_memory_batch(batch)C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in
(.0) 558 if self.num_workers == 0: # same-process loading 559 indices = next(self.sample_iter) # may raise StopIteration--> 560 batch = self.collate_fn([self.dataset[i] for i in indices]) 561 if self.pin_memory: 562 batch = _utils.pin_memory.pin_memory_batch(batch)C:\ProgramData\Anaconda3\lib\site-packages\torchvision\datasets\mnist.py in __getitem__(self, index) 93 94 if self.transform is not None:---> 95 img = self.transform(img) 96 97 if self.target_transform is not None:C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, img) 59 def __call__(self, img): 60 for t in self.transforms:---> 61 img = t(img) 62 return img 63 C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, tensor) 162 Tensor: Normalized Tensor image. 163 """--> 164 return F.normalize(tensor, self.mean, self.std, self.inplace) 165 166 def __repr__(self):C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\functional.py in normalize(tensor, mean, std, inplace) 206 mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device) 207 std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device)--> 208 tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) 209 return tensor 210 RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

 

解决办法:

将三通道的标准化改为1通道的,因为使用的图片集是1通道的,如下

#transform = transforms.Compose([transforms.ToTensor(),

#                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#                             ])
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,),(0.5,)),
                             ])

#解决

转载地址:http://lpfti.baihongyu.com/

你可能感兴趣的文章
[学习笔记]人工智能-感知器分类算法
查看>>
Java基础 05 For循环-99乘法口诀
查看>>
Android 系统获取 CPU 位数信息
查看>>
Python基础-标准数据类型
查看>>
获取手机品牌信息的Build类
查看>>
Python基础-成员运算符和身份运算符
查看>>
Python基础-数字(Number)
查看>>
[学习笔记]人工智能-数据解析和可视化
查看>>
飞行模式关闭modem改善待机功耗
查看>>
Android系统的SDK Version code
查看>>
[学习笔记]人工智能-神经网络对数据进行分类,构建二维矩阵
查看>>
[学习笔记]人工智能-神经网络对数据进行分类-可视化
查看>>
手机平台信息(高通或者MTK平台)
查看>>
[学习笔记]适应行线性神经元基本原理
查看>>
Android系统 getProp信息导出
查看>>
Android获取总内存和可用内存
查看>>
Android 获取系统信息获取
查看>>
[学习笔记]适应行线性神经元基本原理
查看>>
Android系统获取GPU、屏幕信息
查看>>
Android 系统 CPU 基本信息
查看>>