import torch
import torchvision
from matplotlib import pyplot as plt
from torch import nn
from torch import optim
from torch.nn import functional as F
# Utility helpers
def plot_curve(data):
    """Plot a sequence of scalar values against their step index.

    Args:
        data: Sequence of numbers (for example per-step training losses).
            It must support ``len()``; the x axis is ``range(len(data))``.
    """
    plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    # pyplot has no ``tlabel``; the y axis is labelled with ``ylabel``.
    plt.ylabel('value')
    plt.show()
def plot_image(img, label, name):
    """Show up to the first six images of a batch in a 2x3 grid.

    Args:
        img: Batch indexed as ``img[i][0]`` to obtain the 2-D image drawn in
            cell ``i``. Pixel values are de-normalized with
            ``* 0.3081 + 0.1307`` before display.
        label: Indexable collection whose items expose ``.item()``; item
            ``i`` is shown in the title of cell ``i``.
        name: Prefix placed before the label in every title.
    """
    plt.figure()
    # Never index past the end of a batch that holds fewer than six images.
    count = min(6, len(img))
    for i in range(count):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        # Undo the (mean=0.1307, std=0.3081) normalization for display.
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray',
                   interpolation='none')
        plt.title("{}:{}".format(name, label[i].item()))
        # Hide the tick marks on both axes.
        plt.xticks([])
        plt.yticks([])
    plt.show()
def one_hot(label, depth=10):
    """Convert a tensor of class indices into one-hot encoded rows.

    Args:
        label: Tensor of integer class indices with ``label.size(0)``
            entries. Every value must lie in ``[0, depth)``.
        depth: Number of classes, i.e. the width of each one-hot row.

    Returns:
        Tensor of shape ``(label.size(0), depth)`` in the default float
        dtype, on the same device as ``label``, holding 1 at each row's
        class index and 0 elsewhere.
    """
    out = torch.zeros(label.size(0), depth, device=label.device)
    # ``.long()`` converts any integer dtype to the int64 index that
    # ``scatter_`` requires; the legacy ``torch.LongTensor(label)``
    # constructor rejects tensors that are not already int64.
    idx = label.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out
# Number of images loaded per batch.
batch_size = 512

# Step 1: load the dataset.
# Both loaders share one preprocessing pipeline: convert each image to a
# tensor, then normalize it with mean 0.1307 and std 0.3081.
_mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split, reshuffled on every pass.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=_mnist_transform),
    batch_size=batch_size, shuffle=True)

# Test split, kept in its stored order.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=_mnist_transform),
    batch_size=batch_size, shuffle=False)
# Network definition.
class Net(nn.Module):
    """Three-layer fully connected network: 784 -> 256 -> 64 -> 10."""

    def __init__(self):
        super().__init__()
        # Each layer computes x @ w + b.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Run the forward pass on already-flattened inputs.

        Args:
            x: Tensor whose last dimension is 28 * 28 = 784. Image batches
                of shape [b, 1, 28, 28] must be flattened to [b, 784] by
                the caller before being passed in.

        Returns:
            Tensor whose last dimension is 10. No activation is applied
            after the final layer.
        """
        # h1 = relu(x @ w1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1 @ w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2 @ w3 + b3
        x = self.fc3(x)
        return x
net = Net()
# Parameters handed to the optimizer: [w1, b1, w2, b2, w3, b3].
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# Loss value recorded after every optimization step.
train_loss = []

# Training: three passes over the training loader.
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # x: [b, 1, 28, 28], y: [b]
        # Flatten [b, 1, 28, 28] --> [b, 784] for the linear layers.
        x = x.view(x.size(0), 28 * 28)
        # --> [b, 10]
        out = net(x)
        # --> [b, 10]
        y_onehot = one_hot(y)
        # Mean squared error between the outputs and the one-hot targets.
        loss = F.mse_loss(out, y_onehot)

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()
        # Compute fresh gradients.
        loss.backward()
        # Apply the SGD-with-momentum parameter update.
        optimizer.step()

        loss_value = loss.item()
        train_loss.append(loss_value)
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss_value)

plot_curve(train_loss)
# Training is done; the network now holds fitted [w1, b1, w2, b2, w3, b3].
# Measure classification accuracy on the test loader.
total_correct = 0
# Gradients are not needed for evaluation, so skip autograd bookkeeping.
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # out: [b, 10] --> pred: [b]
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)
# Visual sanity check: predict the first test batch and display the images
# with their predicted labels.
x, y = next(iter(test_loader))
# Inference only, so no autograd graph is needed.
with torch.no_grad():
    out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
# Source: http://www.bubuko.com/infodetail-3415160.html