前言
首先, 如果你现在已经很熟悉 tf.data+estimator 了, 可以把文章 x 掉了╮(~▽~"")╭
但是! 如果现在还是在进行 session.run(..)的话! 尤其是苦恼于 GPU 显存都塞满了利用率却上不去的童鞋, 这篇文章或许可以给你打开新世界的大门噢(~~)
如果发现经过一系列改良后训练效率大大提高了, 记得回来给小夕发小红包(~~)
不过, 这并不是一篇怒贴一堆代码, 言 (三) 简(言)意 (两) 赅(语)就结束的 CSDN 文风的文章... 所以伸手党们也可以 X 掉了╮(~▽~"")╭
缘起
很早很早之前, 在小夕刚接触 tensorflow 和使用 GPU 加速计算的时候, 就产生过一个疑惑. 为什么显卡的显存都快满了, GPU 利用率还显示这么低呢? 好浪费呀, 但是又无可奈何. 当时 GPU 利用率 100% 的情况基本是仅存于一块显卡塞 4,5 个不费显存的小任务的情况.
在比较极端的情况下, 甚至 GPU 的利用率会降到 10% 以下, 就像这样:
而大部分情况下写出来的代码 train 起来后是这样的:
可以看到, 虽然显卡的显存都塞满了, 但是显卡功率 (最左边那一栏, 114W 和 69W) 和利用率 (最右边那一栏, 35% 和 38%) 却远远没有达到极限. 大部分人的想法是, 算了算了这不重要, 我去做实验了再见[wei 笑]
然而! 如果你在做大型实验, train 一次跑几天呢? 这个细节会极大的影响你的实验效率和 DDL 到来前的实验次数! 想一下, 完全一样的 model 和设置, 你的代码要 train 一周, 然而隔壁老王只需要 train 三天╮(~▽~"")╭
路人甲: 我有 256 张显卡
小夕: 好了这篇文章你可以 X 掉了
那么, 我们有没有可能一直这样呢:
是不是这功率和利用率看起来不可思议! 不要怀疑这是 PS 的图! 这只是小夕的日常截图! tricks 用的好 GPU 利用率掉不下来 99%, 然鹅代码写的足够蠢, 也可以上不去 5%!
那么问题来了, 到底是什么导致的这个差异呢?
不要急, 我们来放大一下那些 gpu 利用率只有 30% 几的代码在训练时的 gpu 利用率的变化情况(好像句子有点长
watch -n 0.1 nvidia-smi
ps:(可能掉帧太严重了看着不连贯╮(~▽~"")╭, 建议在自己的机器上试一下, 会直观的多~)
看! 是不是一下子就发现问题啦? 可以看到, 其实 gpu 利用率并不是一直在比较低的水平, 而是很有规律的周期性的从 0 涨到接近 100 再跌到 0, 再重新涨到 100 再跌回 0. 如果同时开着打印日志的窗口, 你就会发现这个周期恰好跟每个训练 step 的时长一致! 也就是说, 在每个 step, 其实有一些时间并没有花在 GPU 里, 那当然就是花在 CPU 里啦.
那在 CPU 里干什么的呢? 当然就是 load 下一个 batch, 预处理这个 batch 以及在 gpu 上跑出结果后打印日志, 后处理, 写 summary 甚至保存模型等, 这一系列的花销都要靠 CPU 去完成. 回顾一下我们常写的代码:
- create_graph()
- create_model_saver()
- create_summary_writer()
- create_session()
- do_init()
- for i in range(num_train_steps):
- load_batch(...) # CPU
- preprocess(...) # CPU
- feed_dict = {...} # CPU
- fetch_list = [...] # CPU
- buf = session.run(fetch_list, feed_dict) # gpu
- postprocess(buf) # CPU
- print(...) # CPU
- if i % x == 0:
- summary_writer.write(...) # CPU
- if i % xx == 0:
- model_saver.save(...) # CPU
看, 尤其是 preprocess(...)任务比较重的话就容易导致代码在 CPU 里也要跑好一段时间, gpu 利用率自然就会上不去而且呈现周期性变化啦.
那么有没有什么办法降低 CPU 时间, 提高 gpu 时间呢?
一个很自 (愚) 然(蠢)的想法就是把一切训练代码都用 tf 的 API 重写不就好啦, 甚至最外层的那个 for i in range(num_train_steps)其实都可以用 tf.while_loop 重写呀. 嗯, 小夕还真的这么尝试过, 然后发现
TF API 这特喵的都是些什么鬼! 各种跟 numpy 和 python 内置函数重名却行为不一致是什么鬼! 卧槽这个 API 少了个参数我该怎么办? python 里就一行代码就能搞定的事情我为什么写了几十行??
所以除了函数式编程的大牛, 小夕极力的不建议重蹈覆辙! 尤其是我们这些遇到汇编会哭, 看到 Lisp 会崩溃的 90 后小仙女!
所以没办法把整个 train loop 都描述进计算图了?
别怕别怕, 好在后来其实 tensorflow 已经封装了一个特别好 (多) 用(坑)的上层 API 来把整个 train loop 都能轻松的封装在计算图中, 从而实现超级高的 GPU 利用率和训练效率!
Estimator
不用管它为啥叫 Estimator, 只需要知道, 它把我们刚才想做的事情基本都给封装好了就行. 把刚才的那个经典的写法搬过来
- create_model()
- create_model_saver()
- create_summary_writer()
- create_session()
- do_init()
- for i in range(num_train_steps):
- load_batch(...) # CPU
- preprocess(...) # CPU
- feed_dict = {...} # CPU
- fetch_list = [...] # CPU
- buf = session.run(fetch_list, feed_dict) # gpu
- postprocess(buf) # CPU
- print(...) # CPU
- if i % x == 0:
- summary_writer.write(...) # CPU
- if i % xx == 0:
- model_saver.save(...) # CPU
1-5 行在 estimator 中都封装好啦, 你只需要把相关配置塞进 estimator 的 RunConfig 就可以啦~
7-9 行也封装好啦, 你只需要把数据集载入和预处理的相关代码的函数塞给 estimator.train 的 input_fn~
第 10 行也封装好啦, 你只需要把要 fetch 的 loss,train_op 丢进 estimator 的 EstimatorSpec~
第 11 行也封装好啦, 你只需要把描述模型计算图的函数塞给 estimator 的 model_fn~
第 12-13 行不用操心细节了, global_step 和 loss 自动完成了, 剩下的丢给 tf.Print 和 LoggingTensorHook 吧~
第 14-17 行不用你写了, 自动完成了
╮(╯▽╰)╭
经过这么一顿折腾, 我们发现 GPU 利用率大大提高啦~直逼 80% 甚至 90%. 那么还有没有可以压榨的空间呢?
其实这时仔细一分析就会发现虽然 estimator 把大部分的代码写进计算图里了, 但是从数据的载入和预处理依然是在 CPU 里串行进行呀, 而且比如一个 batch 有 128 个样本, 那么 estimaor 内部在 run 每个 step 的时候还是要等着这 128 个样本串行的处理完才行. 这显然就是最后的瓶颈啦! 有没有办法消除掉呢?. 当然有, 那就是
tf.data
TF 的 dataset API 可以说让人又爱又恨了, 它确实看似提供了一种把整个预处理都搬进计算图进行并行化处理的途径, 但是! 如果你真的完全用 tensorflow API 来做复杂的预处理的话, 真的会让人疯掉的 QAQ 因此, 这里在用 tf.data 之前, 小夕极力的建议先把数据集尽可能的 transform 成预处理后的样子, 包括做分词, 做截断, 做 word2id 等, 不过 padding 和 input_mask 可以留在 TF 里面做, 毕竟都只需要一行.
那做完这些预处理后, 数据该怎么存储会更方便后续的读取和处理呢? 最最最建议的方式还是使用 tf.records 来存储, 磁盘, 内存的存储和 IO 效率都会相比传统方式更快一些, x 和 y 也不用分开了. 当然这样的唯一的坏处就是不能直接打开看数据集╮(~▽~"")╭毕竟数据集被做成了二进制文件.
但是实在比较懒不想用 tf.record 的话, 那么小夕极力建议把 x 和 y 分开存储, 并且尽量让 tf.data 在读取数据的时候做完上面的那些必要的预处理, 以避开难用的字符串基础操作 API 并且减轻训练时的 CPU 和内存压力.
tf.data 还有一个很大的好处就是可以很天然的支持以 streaming 的方式读取数据, 这样在面对大数据集时就不会发生数据 load 完后发现显卡被占的尴尬事件了╮(~▽~"")╭
好像讲了这么久, 还是没讲怎么用 tf.data 加速 QAQ, 来来来进入正题啦.
想想哈, 没用 tf.data 的时候, 我们写出来的代码实际跑起来就是这个样子的:
这也是文章开头小夕解释的为什么 gpu 利用率上不去并且周期性变化的重要原因. 那么我们可以不可以消除 idle, 像下面这样让 prepare 和 train 的过程并行进行呢?
当然可以! 那就是
prefetch
从 prefetch 的意思就可以理解, 那就是预先获取下一个 step 要 load 的 batch. 使用 tf.data 里面的叫做 prefetch 的神奇 API 就可以轻松完成啦, 这个 API 里的参数 buffer_size 就是讲的是额外的 fetch 多少份, 比如 buffer_size=1, 然后我们要 prefetch 的是 batch 的话, 那么模型每次 prepare 完一个 batch 后, 就会自动再额外的 prepare 一个 batch, 这样下一个 train step 到来的时候就可以直接从内存中取走这个事先 prepare 好的 batch 啦.(详情见后面)
等下, 看上图的话, 有木有发现, 如果 prepare 一个 batch 耗时很短的话确实两全齐美, 但是如果耗时比较久, 尤其一下子 prefetch 好几个 batch 的话, 一旦 prepare 的用时超过了 train 一个 step 的用时, 那么每个 train step 的性能就会受限于 prepare 的效率啦. 放大一下这个问题的话如下图所示
看, prepare 用时太久反而会导致 train 完一个 step 后 gpu 空闲了(虽然其实下个 step 的 batch 可能已经 prepare 好了)
那么能不能确保 prepare 阶段的用时小于 train 阶段的用时呢?
parallel mapping
一个很简单的想法当然就是让样本并行处理啦~如果 batch size 是 128,prefetch size=1, 那么准备一个 batch 要串行的跑 128*2=256 次的预处理, 但是如果我们开 4 个线程去跑, 是不是就看起来快多啦. 幸运的是我们也不用自己手撸多线程了, tf.data.Dataset 在 map(预处理)函数里有一个参数 num_parallel_calls, 给这个参数赋值就可以并行 parse 啦. 如图,
这样的话只要 prefetch 的 buffer_size 和 map 的 num_parrellel_calls 取得合适, 基本就可以实现不间断的 train 啦, 也就是几乎达到 100% 的 GPU 利用率!
好啦, 思想明白了, 代码就容易理解啦. 不使用 tf.record, 直接从预处理好的纯文本格式的数据集 load 数据时的典型过程如下
- def build_input(..):
- x = tf.data.XXDataset(..)
- x = x.map(..., num_parallel_calls=N) # parellel
- y = tf.data.XXDataset(..)
- y = y.map(..., num_parallel_calls=N)
- dataset = tf.data.Dataset.zip((x, y))
- dataset = dataset.repeat(num_epochs)
- if is_train:
- dataset = dataset.shuffle(..)
- dataset = dataset.batch(batch_size)
- dataset = dataset.prefetch(buffer_size=1) # prefetch
- iterator = dataset.make_xx_iterator()
- return iterator.get_next()
当然, 如果用上 tf.record 后, 就不用分别从 x 和 y 俩文件中读数据啦, 感兴趣的童鞋可自行去了解一下.
补充福利
当然, 刚从传统的代码迁移到 tf.data+estimator 的时候可能会不太适应, 最主要的还是 debug 的方式, 不能像之前一样直接 session.run(debug_tensor)了, 那怎么办呢?
一般来说我们打印 tensor 有两种情况, 一种是计算图出错时需要打印一次或几次来定位问题, 一种是像 global_step,loss 等需要周期性 check. 对于这两种情况, 之前是习惯 session.run 的时候把要打印的 tensor 也 run 出来, 而现在这两种情况可以区分对待啦.
对于第一种, 小夕感觉最高效的还是直接在计算图里插 tf.Print(..), 使用非常方便, debug 能力很强大! 如果打印还需要配合 global step, 加一条 tf.cond 就搞定啦. 对于第二种, 其实 global step 和 loss 的话 estimator 默认就会打印出来, 如果是其他需要周期性打印的 tensor, 那么就用 tf.train.LoggingTensorHook 包装一下然后丢进 estimator.train 里吧~习惯之后竟然还感觉挺方便的 m(__)m
最后, 愿天下没有空闲的显卡
关注[OpenCV 与 AI 深度学习]
长按或者扫描下面二维码即可关注
来源: https://www.cnblogs.com/stq054188/p/11832001.html