DQN[1, 2] 是 Q-Learning 的一个扩展, 可以归类为改进价值函数的参数化表示 (使用神经网络来建模, 而非简单的线性非线性表达), 此外也更适用于 "大"(或无穷) 状态空间 -- 相比与基于表格型方法, 此外也简化了输入的表达形式(以 Atari game 为例)-- 使用连续 N 帧的 raw pixels 而非 handcrafted features(这里其实就是典型的深度学习套路了, 宛如 CNN 在 ImagNet 上的屠榜).
DQN 最早出现在 2013 年 [1], 也是我手写梯度传播的年份.
DQN[2] 使用如下方法更新模型参数:
这里分为两类模型参数 Target 和 Online , 下标 i 处理的不是很合理, 因为它们是独立不同步更新的, 准确的来说 Online 就是标准的深度学习梯度更新, 而 Target 只做周期性的拷贝 的 "更新", 这也是 Q-Learning 作为 off-policy 类方法的特点.
这里以 https://github.com/google/dopamine 里的 DQN https://github.com/openai/gym/wiki/CartPole-v0 为例, 解释其核心代码, CartPole-v0 是一个平衡类的游戏, 智能体需要根据环境的 Observation 作出对应的 Action, 环境给予 Reward 和新的 Observation:
截图自 https://github.com/openai/gym/wiki/CartPole-v0
- # 运行
- python -um dopamine.discrete_domains.train \
- --base_dir exp/dqn_cartpole \
- --gin_files dopamine/agents/dqn/configs/dqn_cartpole.gin
- # tensorboard 查看训练情况
- tensorboard --logdir exp
dopamine 是一个面向研究的框架, 没有分布式 RL 的实现, 它的实现就很简单没有太多封装, 它使用 https://github.com/google/gin-config 来管理模型训练.
DQN 核心代码(这里我们忽略模型网络结构, 专注在上图中公式的代码实现)
表达式 1
- # 位置 dopamine/agents/dqn/dqn_agent.py
- # 计算表达式 1
- def _build_target_q_op(self):
- """Build an op used as a target for the Q-value.
- Returns:
- target_q_op: An op calculating the Q-value.
- """
- # Get the maximum Q-value across the actions dimension.
- replay_next_qt_max = tf.reduce_max(
- self._replay_next_target_net_outputs.q_values, 1)
- # Calculate the Bellman target value.
- # Q_t = R_t + \gamma^N * Q'_t+1
- # where,
- # Q'_t+1 = \argmax_a Q(S_t+1, a)
- # (or) 0 if S_t is a terminal state,
- # and
- # N is the update horizon (by default, N=1).
- return self._replay.rewards + self.cumulative_gamma * replay_next_qt_max * (
- 1. - tf.cast(self._replay.terminals, tf.float32))
- # 获取目标函数也就是图片里完整的公式
- def _build_train_op(self):
- """Builds a training op.
- Returns:
- train_op: An op performing one step of training from replay data.
- """
- replay_action_one_hot = tf.one_hot(
- self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
- replay_chosen_q = tf.reduce_sum(
- self._replay_net_outputs.q_values * replay_action_one_hot,
- axis=1,
- name='replay_chosen_q')
- # 这里获取 表达式 1
- # 使用 stop_gradient() 是因为 target_network 有自己对立的更新策略 -- 从 online network 周期性 (每 C 步) 拷贝模型参数
- target = tf.stop_gradient(self._build_target_q_op())
- # Heber 是平方差的变体
- loss = tf.compat.v1.losses.huber_loss(
- target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)
- return self.optimizer.minimize(tf.reduce_mean(loss))
截图自 https://github.com/openai/gym/wiki/CartPole-v0
训练情况: 可见 Agent 获得了接近 200 的 Reward, 也就是几乎跑满了 200 步上限(即控制了杆子的稳定性没有过早倒掉结束)
下期见!
- [1] Mnih V, Kavukcuoglu K, Silver D, et al. Playing atari with deep reinforcement learning[J]. arXiv preprint arXiv:1312.5602, 2013.
- [2] Mnih V, Kavukcuoglu K, Silver D, et al. Human-level control through deep reinforcement learning[J]. nature, 2015, 518(7540): 529-533.
来源: https://zhuanlan.zhihu.com/p/341533851