前一节从宏观角度给大家介绍了 Spark ML 的设计框架 (链接: http://www.cnblogs.com/jicanghai/p/8570805.html), 本节我们将介绍, Spark ML 中, 机器学习问题从单机到分布式转换的核心方法
单机时代, 如果我们想解决一个机器学习的优化问题, 最重要的就是根据训练数据, 计算损失函数和梯度由于是单机环境, 什么都好说, 只要公式推导没错, 浮点数计算溢出问题解决好, 就好了但是, 当我们的训练数据量足够大, 大到单机根本存储不下的时候, 对分布式学习的需求就出现了比如电商数据, 动辄上亿的训练数据量, 单机望尘莫及, 只能求助于分布式计算
那么问题来了, 在分布式计算中, 怎样计算得到损失函数的值, 以及它的梯度值呢? 这就涉及到 Spark ml 的一个核心, 用八个字概括就是, 模型集中, 计算分布具体来说, 比如我们要学习一个逻辑回归模型, 它的训练数据可能是存储在成百上千台服务器上, 但具体的模型, 只集中于一台服务器上每次迭代时, 我们现在训练数据所在的服务器上, 并行的计算出, 每个服务器包含的训练数据, 所对应的损失函数值和梯度值, 然后把这些信息集中在模型所在的机器上, 进行合并, 总结出所有训练数据的损失函数值和梯度值, 然后对所学习的参数进行迭代, 并把参数分发给拥有训练数据的服务器, 并进入下一个迭代循环, 直到模型收敛
如此看来, 分布式机器学习也没有什么特别的, 核心问题就在于, 怎样把每个服务器上计算的损失函数值和地图值集中到模型所在的服务器上, 除此之外, 跟单机的机器学习问题并没有什么不同
这一步, 在 Spark ML 中是如何实现的呢? 这里要隆重介绍一个函数, treeAggregate, 在我看来, 这个函数是从单机到分布式机器学习的核心, 理解了这个函数, 分布式机器学习问题, 就理解大半了
treeAggregate 函数主要做什么呢? 它负责把每一台服务器上的信息进行聚合, 然后汇总给模型所在的服务器拥有训练数据的服务器, 可能动辄成千上万, 这么多数据怎样聚合起来呢? 其实函数名字已经有暗示了, 它用的是树形聚合方法假设我们有 32 台服务器, 如果使用线性聚合, 也就是说, 1 跟 2 合并, 结果再跟 3 合并, 这样一共需要进行 31 次合并, 而且每次合并还不能并行进行, 因此 treeAggregate 采用的方法是, 把 32 个节点分配到一颗二叉树的 32 个叶子节点, 然后从叶子节点开始一层一层的聚合, 这样只需要 5 次聚合就可以了
具体的, 使用 treeAggregate 函数需要定义两种运算, 分别是 seqOp 和 combOp, 前者的作用是, 把一个训练样本加入已有的统计, 即对损失函数值和梯度进行更新, 后者的作用是, 把两个统计信息合并起来, 可以这样理解, 前者主要在单机上的统计计算时起作用, 后者主要是在不同服务器进行数据合并时起作用
有了这些核心概念, 就可以进入 optim 目录去一探究竟了, optim 目录是 Spark ML 跟优化相关内容的代码库, 它主要包含三部分, 一是 aggregator 目录, 二是 loss 目录, 三是根目录, 下面我们逐一介绍
aggregator 目录下存放的是, 聚合相关的代码我们知道在机器学习任务中, 不同的任务需要聚合的信息是不一样的这里就为我们实现了几个最基本的聚合操作其中, DifferentiableLossAggregator 是基类, 顾名思义, 实现了最基本的可微损失函数的聚合, 实际上的聚合操作都是由它的子类完成的, 基类中定义了通用的 merge 操作, 具体的 add 操作由各子类自己定义, 代码实现都比较直接, 就不一一介绍了, 感兴趣的朋友可以直接读源码
loss 目录下存放的是, 损失函数相关的代码其实, 最一般性的损失函数是在 breeze 库中定义的, 这个等我们在介绍 breeze 库的时候再细说 loss 目录下有两个文件, 一个是 DifferentiableRegularization.scala, 这里是把正则也当作一种损失, 主要包含 L2 正则, 另一个是 RDDLossFunction.scala, 这个就非常重要了, 它就是应用 treeAggregate 函数, 从单机的损失 + 梯度, 汇总到分布式版的损失 + 梯度的函数, 它主要应用了 aggregate 目录下的聚合类实现分布式的聚合运算
根目录下主要包含了几个优化问题的解法, 最基础的是 NormalEquationSolver.scala, 它主要描述了一个最小二乘的标准解法, 也就是正规方程的解法, 其次是 WeightedLeastSquares.scala, 它解决了一个带权值的最小二乘问题, 利用了正规方程解法, 最后是 IterativelyReweightedLeastSquares.scala, 这是在解逻辑斯蒂回归等一大类一般性线性回归问题中常用的 IRLS 算法, 利用了带权值的最小二乘解法
来源: https://www.cnblogs.com/jicanghai/p/8638245.html