with tf.variable_scope('encoder') as scope:
    # Stacked RNN cell for the encoder (dropout applied per FLAGS).
    self.encoder_stacked_cell = rnn_cell(FLAGS, self.dropout,
                                         scope=scope)

    # Create the input-embedding matrix inside this scope; rnn_inputs
    # presumably looks up / reuses "W_input" from the same scope —
    # TODO confirm against rnn_inputs' implementation.
    W_input = tf.get_variable("W_input",
                              [FLAGS.en_vocab_size, FLAGS.num_hidden_units])

    # Embed the encoder token ids (batch, time) -> (batch, time, hidden).
    self.embedded_encoder_inputs = rnn_inputs(FLAGS,
                                              self.encoder_inputs,
                                              FLAGS.en_vocab_size,
                                              scope=scope)

    # Run the encoder; sequence_length stops the RNN at each example's
    # true length so padded steps do not affect the final state.
    self.all_encoder_outputs, self.encoder_state = tf.nn.dynamic_rnn(
        cell=self.encoder_stacked_cell,
        inputs=self.embedded_encoder_inputs,
        sequence_length=self.en_seq_lens, time_major=False,
        dtype=tf.float32)

with tf.variable_scope('decoder') as scope:
    # The decoder starts from the encoder's final state.
    self.decoder_initial_state = self.encoder_state

    # Stacked RNN cell for the decoder.
    self.decoder_stacked_cell = rnn_cell(FLAGS, self.dropout,
                                         scope=scope)

    # Decoder-side embedding matrix (target vocabulary); same reuse
    # convention as the encoder's W_input above — TODO confirm.
    W_input = tf.get_variable("W_input",
                              [FLAGS.sp_vocab_size, FLAGS.num_hidden_units])

    # Embed the decoder (teacher-forced) token ids.
    self.embedded_decoder_inputs = rnn_inputs(FLAGS, self.decoder_inputs,
                                              FLAGS.sp_vocab_size,
                                              scope=scope)

    # Run the decoder seeded with the encoder's final state; dtype is
    # inferred from initial_state here.
    self.all_decoder_outputs, self.decoder_state = tf.nn.dynamic_rnn(
        cell=self.decoder_stacked_cell,
        inputs=self.embedded_decoder_inputs,
        sequence_length=self.sp_seq_lens, time_major=False,
        initial_state=self.decoder_initial_state)

    # Project every decoder output to vocabulary logits: flatten
    # (batch, time, hidden) -> (batch*time, hidden) and apply softmax
    # projection weights.
    self.decoder_outputs_flat = tf.reshape(self.all_decoder_outputs,
                                           [-1, FLAGS.num_hidden_units])
    self.logits_flat = rnn_softmax(FLAGS, self.decoder_outputs_flat,
                                   scope=scope)

    # Masked cross-entropy loss. NOTE: keyword arguments are required —
    # TF >= 1.0 raises if logits/labels are passed positionally.
    targets_flat = tf.reshape(self.targets, [-1])
    losses_flat = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=self.logits_flat, labels=targets_flat)

    # Zero out loss on padding: assumes PAD token id is 0 and real token
    # ids are positive, so sign() yields 0 for pad, 1 otherwise —
    # TODO confirm vocabulary reserves id 0 for padding.
    mask = tf.sign(tf.to_float(targets_flat))
    masked_losses = mask * losses_flat
    masked_losses = tf.reshape(masked_losses, tf.shape(self.targets))

    # Sum per sequence, then average over the batch.
    self.loss = tf.reduce_mean(
        tf.reduce_sum(masked_losses, reduction_indices=1))
Source: https://juejin.im/post/59fc1616f265da432b4a2d44