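"""
自动编码的核心就是各种全连接的组合, 它是一种无监督的形式, 因为它的标签就是它自己.

(The autoencoder here is simply a stack of fully-connected layers. Training is
unsupervised: the target the network learns to reproduce is its own input.)
"""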
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D  # also registers the '3d' projection on older Matplotlib
from torch.autograd import Variable
# Hyperparameters
EPOCH = 10            # number of full passes over the training set
BATCH_SIZE = 64       # samples per mini-batch
LR = 0.005            # learning rate handed to the optimizer
# Download MNIST only when the local copy is missing or empty. With
# download=False and no files on disk, torchvision's MNIST raises an error
# instead of loading; with download=True it skips files that already exist.
DOWNLOAD_MNIST = not os.path.isdir('./mnist/') or not os.listdir('./mnist/')
N_TEST_IMG = 5        # number of images shown in the reconstruction preview
# MNIST training set. ToTensor() converts each image to a float tensor with
# values in [0, 1], which is why the decoder ends in a Sigmoid.
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)
# `.data` / `.targets` replace the deprecated `.train_data` / `.train_labels`.
print(train_data.data.size())     # (60000, 28, 28)
print(train_data.targets.size())  # (60000)

# Show one raw training example with its label as the title.
plt.imshow(train_data.data[2].numpy(), cmap='gray')
plt.title('%i' % train_data.targets[2].item())
plt.show()

# Split the training set into shuffled mini-batches.
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# Build the autoencoder network.
class AutoEncoder(nn.Module):
    """Symmetric fully-connected autoencoder with a small bottleneck.

    The encoder maps ``in_features -> 128 -> 64 -> 12 -> latent_dim`` with Tanh
    between layers. The decoder mirrors it back to ``in_features`` and ends in a
    Sigmoid, so every reconstructed value lies in (0, 1).

    Args:
        in_features: Length of each flattened input vector (default 28*28,
            one MNIST image).
        latent_dim: Size of the bottleneck code (default 3).
    """

    def __init__(self, in_features=28 * 28, latent_dim=3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, latent_dim),  # no activation: the code is left unbounded
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, in_features),
            nn.Sigmoid(),  # squash outputs into (0, 1) to match pixel targets in [0, 1]
        )

    def forward(self, x):
        """Encode and reconstruct ``x`` (last dimension must be ``in_features``).

        Returns:
            A tuple ``(encoded, decoded)``: the bottleneck code and the
            reconstruction of the input.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
# Model, optimizer and reconstruction loss (the target is the input itself).
autoencoder = AutoEncoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

# Preview figure: row 0 shows original images, row 1 is filled with their
# reconstructions during training.
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()  # interactive mode so the figure can refresh while training runs

# First N_TEST_IMG training images, flattened to 784-vectors and scaled from
# 0..255 to [0, 1]. `.data` replaces the deprecated `.train_data`; the old
# Variable wrapper is a no-op on current PyTorch and is dropped.
view_data = train_data.data[:N_TEST_IMG].view(-1, 28 * 28).float() / 255.
for i in range(N_TEST_IMG):
    a[0][i].imshow(view_data[i].view(28, 28).numpy(), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())
for epoch in range(EPOCH):
    # Labels are ignored: an autoencoder's target is its own input.
    for step, (x, _) in enumerate(train_loader):
        b_x = x.view(-1, 28 * 28)  # flatten each image to a 784-vector
        b_y = x.view(-1, 28 * 28)  # reconstruction target: the same pixels
        _, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()        # backpropagate to compute gradients
        optimizer.step()       # update the network parameters

        if step % 100 == 0:
            # `loss` is a 0-dim tensor: `.item()` extracts the float, whereas the
            # old `loss.data[0]` indexing fails on current PyTorch.
            print('Epoch:', epoch, '| train loss: %.4f' % loss.item())

            # Second row: reconstructions of the preview images. This pass is
            # for display only, so no autograd graph is recorded.
            with torch.no_grad():
                _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(decoded_data[i].view(28, 28).numpy(), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)

plt.ioff()
plt.show()
# 3-D visualization of the learned codes: each of the first 200 training images
# is drawn as its digit label at the point given by its first three code values.
view_data = train_data.data[:200].view(-1, 28 * 28).float() / 255.
with torch.no_grad():  # inference only; no autograd graph needed
    encoded_data, _ = autoencoder(view_data)

fig = plt.figure(2)
# add_subplot(projection='3d') attaches the axes to the figure. Constructing
# Axes3D(fig) directly no longer does so on recent Matplotlib, which leaves an
# empty window.
ax = fig.add_subplot(111, projection='3d')
X = encoded_data[:, 0].numpy()
Y = encoded_data[:, 1].numpy()
Z = encoded_data[:, 2].numpy()
values = train_data.targets[:200].numpy()  # `.targets` replaces deprecated `.train_labels`
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255 * s / 9))  # spread digits 0..9 across the rainbow colormap
    ax.text(x, y, z, s, backgroundcolor=c)
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()
# 来源 (Source): http://www.bubuko.com/infodetail-2942118.html