本文原作者: 于洋, 经授权后发布.
1. 开篇
通常, 我们在使用 Tensorflow 低级 API 编程时(非 Eager 模式), 一般有下面三个步骤:
使用 tensorflow python 侧的 API 构建图. 图通常包括了两部分: 正向计算图和反向计算图;
构建的关键字是: 新建的 tf.Operation(节点)和 tf.Tensor(边)对象并将它们添加到 tf.Graph 实例中. 例如, 典型添加 op 操作就是 tf.matmul.
创建 tf.Session 会话;
此步骤的关键字是: 创建默认本地会话
with tf.Session() as sess:
, 创建分布式会话
with tf.Session("grpc://example.org:2222"):
在 tf.Session 会话中, 初始化全局变量, 并批量运行图.
此步骤的关键语句是: sess.run(init_op), sess.run(train_op)
参考链接: 图和会话 https://www.tensorflow.org/guide/graphs?hl=zh_cn ,
线性回归例子 https://www.youtube.com/watch?v=Xiab2JhwzYY
众所周知, tensorflow 使用支持多种前端语言(python,JS,swift,go 等), 执行引擎为 C/C++ 后端实现.
那么, 在上述三个步骤中, 当用户 python 构建图, 以及运行的图的时候. C/C++ 后端有在执行哪些工作呢?
按照对应的三个步骤, 我们做如下拆解:
python 在构建图的过程中, 也是 C/C++ 构造图的过程.
即 python 在新增的 tf.Operation(节点)和 tf.Tensor(边)的同时, C/C++ 的后端也生成对应的节点和边, 从而构造后端的图.
图创建好后, python 调用 tf.Session 语句, C/C++ 端会根据参数创建对应本地 Session 运行图, 或者分布式 Session 运行图.
通过 sess.run 触发一次图的正向计算, 以及反向计算.
本次分享的设计模式, 就是在上述第二阶段时: 创建本地 session 和分布式 session 时, tensorflow 是怎样利用抽象工厂设计模式的?
2. 抽象工厂设计模式(Abstract Factory)
在《设计模式》中描述的 23 设计模式, 分为三类: 创建型, 结构型, 行为型. 其中, 抽象工厂设计模式属于创建性设计模式. 即是解决对象的创建需求. 关于抽象工厂模式我的理解是这样的:
调用者有创建不同对象的需求(对象有一定相似性, 例如轿车, 卡车), 调用者无需关注具体的实现类, 而是通过抽象类定义的接口, 就能创造不同对象.
当然, 个人抽象理解和描述还是很难理解的. 我们根据 GOF 书中, 抽象工厂的模式结构图 (图需要从右上角看起) 在来理解一下:
调用者 (Client) 有创建对象 ProductA1 或 ProductA2 的需求,
但是 Client 类没有直接调用实现类 CreateProductA1,CreateProductA2.
而是通过抽象工厂 AbstractFactory 的接口创建了不同的对象(即: 创建对象 ProductA1 或 ProductA2).
[ 抽象工厂的模式结构图 - 《设计模式》58 页 ]
有了上面粗浅的理解后, 我们看一下 tensorflow 是如何使用抽象工厂模式, 创建本地 session 和分布式 session?
首先, 我们看一下 python 创建 Session 调用栈:
- -> tf_session.TF_NewSessionRef
- -> TF_NewSession
- -> NewSession
NewSession 的代码如下:
- Status NewSession(const SessionOptions& options, Session** out_session) {
- SessionFactory* factory;
- Status s = SessionFactory::GetFactory(options, &factory);
- if (!s.ok()) {
- *out_session = nullptr;
- LOG(ERROR) <<s;
- return s;
- }
- s = factory->NewSession(options, out_session);
- if (!s.ok()) {
- *out_session = nullptr;
- }
- return s;
- }
代码很枯燥, 我们看一下上述代码的时序图(以创建 DirectSessione 为例).
上述代码对应着时序图的阶段 2 和阶段 3. 其中:
阶段 2 对应代码
- SessionFactory::GetFactory(options, &factory);
- ,
阶段 3 对应代码
factory->NewSession(options, out_session);
[ NewSession 的时序图 ]
看到这里, 我们温习一下抽象工厂的理解:
Client(NewSession)有创建 GrpcSession 或者 DirectSession 的需求;
但是, Client 没有直接调用 new DirectSession 或者 new GrpcSession 创建;
而是, 通过调用抽象工厂 (SessionFactory) 接口 GetFactory 找到
DirectSessionFactory
. 最终通过
DirectSessionFactory->NewSession
创建;
最终返回实例为 Session 型(多态可以到 GrpcSesion 或者 DirectSession 对象).
值得说明的是: Client 在整个过程中, 并不清楚里面不同的 Factory(GrpcSessionFactory 和 DirectSessionFactory), 也不清楚不同的 Session 类型(GrpcSession 和 DirectSession).
最后, 参考抽象工厂结构图, 大致画了如下 Session 的创建环节, 大家可以在回味一下该设计模式(图也是从右上角看起):
[ 抽象工厂模式创建 Session ]
至此, 创建 Session 的主题框架已经大致梳理出来了. 但是, 上面的时序图中的阶段 1 一直还没有说明吧?
好, 这部分涉及了单件设计模式.
后记: 按照下面的定义, 上述创建 Session 的模式 (因为只创建了一种 Session 产品) 是不是叫 "工厂方法" 会好一点?
简单工厂: 一个工厂类, 一个产品抽象类.
工厂方法: 多个工厂类, 一个产品抽象类.
抽象工厂: 多个工厂类, 多个产品抽象类.
说一下个人理解, tensorflow 在设计这段代码的时候, 做了很高程度的抽象, 具备完成多个产品抽象的能力. 我这里姑且认为应用的是抽象工厂模式.
大家也可以按照 "工厂方法" 模式理解上述代码, 宗旨是: 希望大家在学习 tensorflow 代码的过程, 能了解里面蕴含的设计模式.
3. 单件设计模式(Singleton)
NewSession 中有这样的代码, 不知道大家是否有注意到 SessionFactory::GetFactory(options, &factory);? 这段代码的含义也就是根据传递的 options 信息, 选择是 DirectSessionFactory 还是分布式 GrpcSessionFactory.
但是, 大家在看时候, 有没有这样的疑问: 不同的 SessionFactory 的是什么时候写入到 SessionFactory map 中的? 何况 tensorflow 这种没有 main 函数的程序? 这个问题曾经一直很困扰我, 在 gdb debug 后, 我发现了下面的小 trick.
诀窍在这行代码中 static DirectSessionRegistrar registrar;.
SessionFactory map 初始化的能量蕴含在这个 static 变量的构造函数. 下面的流程图揭示所有的秘密. 结合代码, 从图的左下角看起(下面的代码对应上面 NewSession 的时序图).
和全局变量一样, static 变量一直存储在程序的静态存储区. 当程序初始化 static 变量时, 通过 DirectSessionRegistrar 和 GrpcSessionFactory 的构造函数完成初始化, 将不同的 SessionFactory(工厂对象)写入到 SessionFactory map 中.
[ SessionFactory map 的初始化过程 ]
囧~~~, 扯了半天的代码和流程, 貌似一点都没有提及单件设计模式. 其实, 单件设计模式在还是比较简单的. GOF 中定义如下:
保证一个类仅有一个实例, 并提供一个访问它的全局访问点.
tensorflow 这里使用了单例中一种更灵活的模式: 单件注册表, 也就是使用的一个 Singleton 类的集合(从上图看到存储结构是 std::unordered_map),Singleton 类通过一个注册接口将自己的单件实例注册到集合中. 而这里的 tensorflow 是通过 DirectSessionRegistrar 和 GrpcSessionFactory 构造函数中的 SessionFactory::Register 接口完成注册.
4. 进阶
其实, 在 tensorflow 中, 上述模式还有很多资源管理的场景中使用. 如下给出代码指引, 感兴趣的同学可自行学习:
- DeviceFactory //Tensorflow 设备管理的代码
- ExecutorFactory //Tensorflow 图执行单元的代码
5. 参考
代码参考: tensorflow v1.12.0
画图: https://www.draw.io/?mode=github
更多优质内容请关注官方微信公众号
长按 / 识别关注我们
来源: https://www.qcloud.com/developer/article/1476567