05 EM 算法 - 高斯混合模型 - GMM https://www.jianshu.com/p/c5b203ce7f6a
多元正态分布 - multivariate_normal API 参考链接:
- https://docs.scipy.org/doc/numpy-dev/genindex.html
- http://scipy.github.io/devdocs/
- http://scipy.github.io/devdocs/stats.html
很多我们运用到的函数都在 scipy 库中, 比如 numpy 是用一些 Python 基础语法, 然后加上 scipy 函数来写出来的.
常规操作:
- import numpy as np
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- from mpl_toolkits.mplot3d import Axes3D
- from scipy.stats import multivariate_normal# 多元正态分布
- from sklearn.mixture import GaussianMixture#GMM Gaussian Mixture Model
- from sklearn.metrics.pairwise import pairwise_distances_argmin
- # 解决中文显示问题
- mpl.rcParams['font.sans-serif'] = [u'SimHei']
- mpl.rcParams['axes.unicode_minus'] = False
- # 设置在 jupyter 中 matplotlib 的显示情况 (默认 inline 是内嵌显示, 通过设置为 tk 表示不内嵌显示)
- %matplotlib tk
高斯混合模型的库: sklearn.mixture.GaussianMixture
__PS:__在 0.18 版本以前是__sklearn.mixture.GMM__, 两者的参数基本类型, 这里主要介绍__GaussianMixture__的相关参数.
属性参数:
n_components: 混合组合的个数, 默认为 1, 可以理解为聚类 / 分类数量.$color{red}{有几个独立的高斯分布 k - 工作中有价值去调的参数就这一个.}$
covariance_type: $color{red}{第 k 分类的条件下样本的协方差 - ∑k}$ 给定协方差的类型, 可选: __full,tied,diag,spherical__.
默认为 full;
__full:__每个组件都有自己的公用的协防差矩阵. 即, 每一个协方差矩阵都不相等.$color{red}{ ∑1,∑2,...,∑k 都不相等.}$
__tied:__所有组件公用一个协方差矩阵. 即, 每一个协方差矩阵都相等.$color{red}{ ∑1 = ∑2 =,...,= ∑k}$
__diag:__每个组件都有自己的斜对角协方差矩阵.
__spherical:__每个组件都有相同的方差值;
__tol:__默认 1e-3, 收敛阈值, 如果在迭代过程中, 平均增益小于该值的时候, EM 算法结束.
__reg_covar:__协方差对角线上的非负正则化参数, 默认为 0 - 表示不用非负正则化.
max_iter: em 算法的最大迭代次数, 默认 100.
n_init: 默认值 1, 执行初始化操作数量, 该参数最好不要变动.
__init_params:__初始化权重值, 均值以及精度的方法, 参数可选:__kmeans,random__.
默认 kmeans;
__kmeans:__使用 kmeans 算法进行初始化操作.
__weights_init:__初始化权重列表, 如果没有给定, 那么使用__init_params__参数给定的方法来进行创建, 默认为 None.
__means_init:__初始化均值列表, 如果没有给定, 那么使用__init_params__参数给定的方法来进行创建, 默认为 None.
precisions_init: 初始化精度列表, 如果没有给定, 那么使用__init_params__参数给定的方法来进行创建, 默认为 None.
__warn_stat:__默认为 False, 当该值为 true 的时候, 在类似问题被多次训练的时候, 可以加快收敛速度.
1, 使用 scikit 携带的 EM 算法或者自己实现的 EM 算法
- def trainModel(style, x):
- if style == 'sklearn':
- print("sklearn")
- # 对象创建
- g = GaussianMixture(n_components=2, covariance_type='full',
- tol=1e-6, max_iter=1000, init_params='kmeans')
- # 模型训练
- g.fit(x)
- # 效果输出
- print('类别概率:\t', g.weights_[0])
- print('均值:\n', g.means_, '\n')
- print('方差:\n', g.covariances_, '\n')
- print('似然函数的值:\n', g.lower_bound_)
- mu1, mu2 = g.means_
- sigma1, sigma2 = g.covariances_
- # 返回数据
- return (mu1, mu2, sigma1, sigma2)
- else:
- ## 自己实现一个 EM 算法
- ## 迭代 100 次
- num_iter = 100
- n, d = data.shape
- # 初始化均值和方差正定矩阵 (sigma 叫做协方差矩阵)
- mu1 = data.min(axis=0)
- mu2 = data.max(axis=0)
- sigma1 = np.identity(d)
- sigma2 = np.identity(d)
- pi = 0.5 # 属于第一类高斯分布的概率
- print("随机初始的期望为:")
- print(mu1)
- print(mu2)
- print("随机初始的方差为:")
- print(sigma1)
- print(sigma2)
- print("随机初始的π为:")
- print([pi, 1-pi]) #1-pi 就是属于第二类高斯分布的概率
- # 实现 EM 算法
- for i in range(num_iter):
- # E Step
- # 1. 计算获得多元高斯分布的概率密度函数
- norm1 = multivariate_normal(mu1, sigma1)
- norm2 = multivariate_normal(mu2, sigma2)
- # 2. 计算概率值
- tau1 = pi * norm1.PDF(data)
- tau2 = (1 - pi) * norm2.PDF(data)
- # 3. 概率值均一化 (即公式中的 w)
- gamma = tau1 / (tau1 + tau2)
- # M Step
- # 1. 计算更新后的均值
- mu1 = np.dot(gamma, data) / np.sum(gamma)
- mu2 = np.dot((1 - gamma), data) / np.sum((1 - gamma))
- # 2. 计算更新后的方差
- sigma1 = np.dot(gamma * (data - mu1).T, data - mu1) / np.sum(gamma)
- sigma2 = np.dot((1 - gamma) * (data - mu2).T, data - mu2) / np.sum(1 - gamma)
- # 3. 计算更新后的π值
- pi = np.sum(gamma) / n
- # 输出信息
- j = i + 1
- if j % 10 == 0:
- print (j, ":\t", mu1, mu2)
- # 效果输出
- print ('类别概率:\t', pi)
- print ('均值:\t', mu1, mu2)
- print ('方差:\n', sigma1, '\n\n', sigma2, '\n')
- # 返回结果
- return (mu1, mu2, sigma1, sigma2)
2, 创建模拟数据
- # 创建模拟数据 (3 维数据)
- np.random.seed(28)
- N = 500
- M = 250
- ## 根据给定的均值和协方差矩阵构建数据
- mean1 = (0, 0, 0)
- cov1 = np.diag((1, 2, 3))
- ## 产生 400 条数据
- data1 = np.random.multivariate_normal(mean1, cov1, N)
- ## 产生一个数据分布不均衡的数据集, 100 条
- mean2 = (2, 2, 1)
- cov2 = np.array(((3, 1, 0), (1, 3, 0), (0, 0, 3)))
- data2 = np.random.multivariate_normal(mean2, cov2, M)
- ## 合并 data1 和 data2 这两个数据集
- data = np.vstack((data1, data2))
- ## 产生数据对应的 y 值
- y1 = np.array([True] * N + [False] * M)
- y2 = ~y1
- print(y1)
- print('---------')
- print(y2)
输出:
- [ True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False]
- ---------
- [False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False False False False False
- False False False False False False False False True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True True True True True True True
- True True True True True True]
3, 预测结果 (得到概率密度值)
- style = 'sklearn'
- style = 'self'
- mu1, mu2, sigma1, sigma2 = trainModel(style, data)
预测分类 (根据均值和方差对原始数据进行概率密度的推测)
- norm1 = multivariate_normal(mu1, sigma1)
- norm2 = multivariate_normal(mu2, sigma2)
- tau1 = norm1.PDF(data)
- tau2 = norm2.PDF(data)
随机初始的期望为:
- [-3.2408628 -3.85600655 -5.36300731]
- [8.18151162 6.20669356 5.719554 ]
随机初始的方差为:
- [[1. 0. 0.]
- [0. 1. 0.]
- [0. 0. 1.]]
- [[1. 0. 0.]
- [0. 1. 0.]
- [0. 0. 1.]]
随机初始的π为:
- [0.5, 0.5]
- 10 : [0.15631385 0.06809082 0.04857579] [2.44502406 2.45578419 0.9091761 ]
- 20 : [0.14397974 0.04785923 0.04506974] [2.28240941 2.30685014 0.84484619]
- 30 : [0.14088163 0.02935465 0.03209345] [2.15917602 2.21594906 0.82923206]
- 40 : [0.13934104 0.01693114 0.0228729 ] [2.08010249 2.1564362 0.8189457 ]
- 50 : [0.13836058 0.00937934 0.01770358] [2.03246294 2.11908375 0.81057717]
- 60 : [0.13774293 0.00491933 0.01495515] [2.0042425 2.09629406 0.80446921]
- 70 : [0.13736483 0.00229644 0.01348018] [1.98756296 2.08257617 0.80038644]
- 80 : [0.13713691 0.00075078 0.01266793] [1.97769407 2.07437102 0.79779292]
- 90 : [ 0.13700051 -0.00016251 0.01220949] [1.97184694 2.06947843 0.79619181]
- 100 : [ 0.13691916 -0.00070326 0.0119459 ] [1.96837906 2.06656572 0.79521918]
类别概率: 0.6899093466771541
均值: [ 0.13691916 -0.00070326 0.0119459 ] [1.96837906 2.06656572 0.79521918]
方差:
- [[ 0.95205447 0.1172476 -0.03011033]
- [ 0.1172476 2.17736238 0.01510168]
- [-0.03011033 0.01510168 2.69115809]]
- [[ 3.94127407 1.31328462 -0.26576838]
- [ 1.31328462 3.4491304 0.12676854]
- [-0.26576838 0.12676854 2.94238901]]
4, 计算均值的距离, 然后根据距离得到分类情况
- dist = pairwise_distances_argmin([mean1, mean2], [mu1, mu2], metric='euclidean')
- print ("距离:", dist)
- if dist[0] == 0:
- c1 = tau1> tau2
- else:
- c1 = tau1 < tau2
- c2 = ~c1
计算准备率
- acc = np.mean(y1 == c1)
- print (u'准确率:%.2f%%' % (100*acc))
距离: [0 1]
准确率: 85.07%
5, 画图
- fig = plt.figure(figsize=(12, 6), facecolor='w')
- ## 添加一个子图, 设置为 3d 的
- ax = fig.add_subplot(121, projection='3d')
- ## 点图
- ax.scatter(data[y1, 0], data[y1, 1], data[y1, 2], c='r', s=30, marker='o', depthshade=True)
- ax.scatter(data[y2, 0], data[y2, 1], data[y2, 2], c='g', s=30, marker='^', depthshade=True)
- ## 标签
- ax.set_xlabel('X')
- ax.set_ylabel('Y')
- ax.set_zlabel('Z')
- ## 标题
- ax.set_title(u'原始数据', fontsize=16)
- ## 添加一个子图, 设置为 3d
- ax = fig.add_subplot(122, projection='3d')
- # 画点
- ax.scatter(data[c1, 0], data[c1, 1], data[c1, 2], c='r', s=30, marker='o', depthshade=True)
- ax.scatter(data[c2, 0], data[c2, 1], data[c2, 2], c='g', s=30, marker='^', depthshade=True)
- # 设置标签
- ax.set_xlabel('X')
- ax.set_ylabel('Y')
- ax.set_zlabel('Z')
- # 设置标题
- ax.set_title(u'EM 算法分类', fontsize=16)
- # 设置总标题
- plt.suptitle(u'EM 算法的实现, 准备率:%.2f%%' % (acc * 100), fontsize=20)
- plt.subplots_adjust(top=0.90)
- plt.tight_layout()
- plt.show()
来源: https://yq.aliyun.com/articles/684415