简介
使用 Keras 实现 Siamese Network 并进行语句相似度的计算
原理
Siamese Network 是指网络中包含两个或以上完全相同的子网络, 多应用于语句相似度计算, 人脸匹配, 签名鉴别等任务上
语句相似度计算: 输入两句话, 判断是否是一个意思
人脸匹配: 输入两张人脸, 判断是否是同一个人
签名鉴别: 输入两个签名, 判断是否是同一个人所写
以语句相似度计算为例, 两边的子网络从 Embedding 层到 LSTM 层等都是完全相同的, 整个模型称作 MaLSTM(Manhattan LSTM)
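通过 LSTM 层的最后输出得到两句话的固定长度表示 $h^{(left)}$ 和 $h^{(right)}$, 再使用以下公式 (负曼哈顿距离的指数) 计算两者的相似度, 相似度在 0 至 1 之间: $$\mathrm{sim} = \exp\left(-\lVert h^{(left)} - h^{(right)} \rVert_1\right)$$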
数据
使用 Kaggle 上的 Quora 问题对数据, Quora 对应外国的知乎, https://www.kaggle.com/c/quora-question-pairs
训练集和测试集分别有 404290 和 3563475 条数据, 每条数据包括以下字段, 但测试集不包括 is_duplicate 字段
id: 问题对的 id
qid1: 问题 1 的 id
qid2: 问题 2 的 id
question1: 问题 1 的文本
question2: 问题 2 的文本
is_duplicate: 两个问题是不是意思一样, 0 或 1
实现
加载库
- # -*- coding: utf-8 -*-
- from keras.preprocessing.sequence import pad_sequences
- from keras.models import Model
- from keras.layers import Input, Embedding, LSTM, Lambda
- import keras.backend as K
- from keras.optimizers import Adam
- import pandas as pd
- import numpy as np
- from gensim.models import KeyedVectors
- from nltk.corpus import stopwords
- from sklearn.model_selection import train_test_split
- import matplotlib.pyplot as plt
- %matplotlib inline
- import re
- from tqdm import tqdm
- import pickle
加载训练集和测试集
# Read the labelled training pairs and the unlabelled test pairs.
# NOTE(review): the upper-case '.CSV' extension only resolves on a
# case-insensitive filesystem unless the files are really named that way.
train_df, test_df = (pd.read_csv(path) for path in ('train.CSV', 'test.CSV'))
print(len(train_df), len(test_df))
train_df.head()
加载 nltk(Natural Language Toolkit) 中的停用词, 并定义一个文本预处理函数
# If nltk reports that the stopwords corpus is missing, download it once:
#     import nltk
#     nltk.download('stopwords')
english_stopwords = stopwords.words('english')
stops = set(english_stopwords)  # set membership is O(1) per token
# Ordered (compiled pattern, replacement) rewrite rules used by preprocess().
# Order matters: contractions are expanded before the bare apostrophe is
# dropped, and "-" is padded with spaces before "e - mail" is collapsed.
_PREPROCESS_RULES = [(re.compile(pattern), replacement) for pattern, replacement in (
    # Drop every character outside the kept alphabet. The "-" is escaped so
    # that "+-=" is not read as a character range (which silently kept ";"
    # and "<"); ":" is listed explicitly because a later rule handles it.
    (r"[^A-Za-z0-9^,!.\/'+\-=:]", " "),
    # Expand contractions.
    (r"what's", "what is "),
    (r"'s", " is "),
    (r"'ve", " have "),
    (r"can't", "cannot "),
    (r"n't", " not "),
    (r"\bi'm", "i am "),
    (r"'re", " are "),
    (r"'d", " would "),
    (r"'ll", " will "),
    # Commas, full stops and slashes become separators.
    (r",", " "),
    (r"\.", " "),
    # Kept symbols are padded so they become stand-alone tokens.
    (r"!", " ! "),
    (r"/", " "),
    (r"\^", " ^ "),
    (r"\+", " + "),
    (r"-", " - "),
    (r"=", " = "),
    (r"'", " "),
    # "30k" -> "30000"; the word boundary keeps "5kg" intact.
    (r"\b(\d+)k\b", r"\g<1>000"),
    (r":", " : "),
    # Re-join abbreviations that the punctuation removal split apart.
    # Word boundaries stop these from firing inside ordinary words
    # (e.g. "you see" must not become "yoamericanee").
    (r"\be g\b", "eg"),
    (r"\bb g\b", "bg"),
    (r"\bu s\b", "american"),
    # "1990s" -> "1990". The former pattern "\0s" was a NUL escape and
    # could never match.
    (r"0s\b", "0"),
    (r"\b9 11\b", "911"),
    (r"\be - mail\b", "email"),
    (r"\bj k\b", "jk"),
)]


def preprocess(text):
    """Normalise a raw question and split it into lower-case tokens.

    The value is converted with ``str()`` (so missing values such as NaN or
    None become ordinary words), lower-cased, rewritten by the ordered rules
    in ``_PREPROCESS_RULES`` and finally split on whitespace.

    Args:
        text: The raw question; any object accepted by ``str()``.

    Returns:
        A list of token strings; empty when nothing survives the cleaning.

    Example:
        >>> preprocess('Hello are you ok?')
        ['hello', 'are', 'you', 'ok']
    """
    text = str(text).lower()
    for pattern, replacement in _PREPROCESS_RULES:
        text = pattern.sub(replacement, text)
    # split() without arguments already collapses runs of whitespace.
    return text.split()
加载 Google 预训练好的 300 维词向量
# Load the pre-trained GoogleNews word2vec vectors from the gzipped binary
# file; the file must already be present in the working directory.
word2vec = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)
整理词典, 一共有 58564 个词, 将文本替换成整数序列表示, 获得词向量映射矩阵
# Build the vocabulary and replace every question by a list of word ids.
# Id 0 is reserved for padding, so real words are numbered from 1:
# vocabulary[k] is the word whose id is k + 1.
vocabulary = []
word2id = {}
id2word = {}
for df in [train_df, test_df]:
    for i in tqdm(range(len(df))):
        row = df.iloc[i]
        for column in ['question1', 'question2']:
            q2n = []
            for word in preprocess(row[column]):
                # Skip stop words and words without a pre-trained vector.
                # NOTE(review): `.vocab` is the gensim < 4 attribute; newer
                # releases expose `key_to_index` instead - confirm version.
                if word in stops or word not in word2vec.vocab:
                    continue
                # Dict lookup instead of scanning the vocabulary list, which
                # made this loop quadratic in the vocabulary size.
                if word not in word2id:
                    new_id = len(vocabulary) + 1
                    word2id[word] = new_id
                    id2word[new_id] = word
                    vocabulary.append(word)
                q2n.append(word2id[word])
            # Positional index i doubles as the label here; this relies on
            # the default RangeIndex produced by read_csv.
            df.at[i, column] = q2n

embedding_dim = 300
# Row 0 stays all-zero for the padding id; every other row is filled below,
# so no random initialisation is needed.
embeddings = np.zeros((len(vocabulary) + 1, embedding_dim))
# Ids start at 1. Enumerating from 0 would overwrite the padding row and
# give every word the vector of its neighbour.
for index, word in enumerate(vocabulary, start=1):
    embeddings[index] = word2vec.word_vec(word)
del word2vec  # release the large pre-trained model before training
print(len(vocabulary))
分割训练集和验证集, 将整数序列 padding 到统一长度
# The common padded length is the longest encoded question found in either
# dataset.
maxlen = max(
    frame[column].map(len).max()
    for frame in (train_df, test_df)
    for column in ('question1', 'question2')
)

# Hold out a fixed number of labelled pairs for validation.
valid_size = 40000
train_size = len(train_df) - valid_size

X = train_df[['question1', 'question2']]
Y = train_df['is_duplicate']
X_train, X_valid, Y_train, Y_valid = train_test_split(X, Y, test_size=valid_size)


def _pad_question_pair(frame):
    """Return the two question columns of ``frame`` as padded id arrays.

    pad_sequences pads and truncates at the front of each sequence by
    default, so every row ends up exactly ``maxlen`` ids long.
    """
    return {
        side: np.array(pad_sequences(frame[column].values, maxlen=maxlen))
        for side, column in (('left', 'question1'), ('right', 'question2'))
    }


X_train = _pad_question_pair(X_train)
X_valid = _pad_question_pair(X_valid)

# Labels become column vectors to match the model's (batch, 1) output.
Y_train = np.expand_dims(Y_train.values, axis=-1)
Y_valid = np.expand_dims(Y_valid.values, axis=-1)

print(X_train['left'].shape, X_train['right'].shape)
print(X_valid['left'].shape, X_valid['right'].shape)
print(Y_train.shape, Y_valid.shape)
定义模型并训练
# Training hyper-parameters for the MaLSTM model.
epochs = 20
batch_size = 64
hidden_size = 128              # units in the shared LSTM encoder
gradient_clipping_norm = 1.25  # passed to Adam as clipnorm
def exponent_neg_manhattan_distance(args):
    """Return exp(-L1 distance) between two encoded sentences.

    Args:
        args: A pair ``(left, right)`` of backend tensors; the absolute
            difference is summed over axis 1.

    Returns:
        A tensor with axis 1 reduced to size 1 (keepdims). The value is 1
        for identical inputs and approaches 0 as the distance grows.
    """
    left, right = args
    l1_distance = K.sum(K.abs(left - right), axis=1, keepdims=True)
    return K.exp(-l1_distance)
# Two integer-id inputs, one per question. The time dimension is left open
# (None) so the model also accepts sequences padded to other lengths.
left_input = Input(shape=(None,), dtype='int32')
right_input = Input(shape=(None,), dtype='int32')

# A single frozen embedding layer, initialised from the pre-trained
# vectors, serves both branches.
embedding_layer = Embedding(
    len(embeddings),
    embedding_dim,
    weights=[embeddings],
    input_length=maxlen,
    trainable=False,
)
embedded_left = embedding_layer(left_input)
embedded_right = embedding_layer(right_input)

# The same LSTM instance encodes both sides, so its weights are shared.
shared_lstm = LSTM(hidden_size)
left_output = shared_lstm(embedded_left)
right_output = shared_lstm(embedded_right)

# Similarity head: exp(-L1 distance) of the two sentence encodings.
similarity_layer = Lambda(exponent_neg_manhattan_distance, output_shape=(1,))
malstm_distance = similarity_layer([left_output, right_output])

malstm = Model([left_input, right_input], malstm_distance)
optimizer = Adam(clipnorm=gradient_clipping_norm)
malstm.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])

history = malstm.fit(
    [X_train['left'], X_train['right']],
    Y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=([X_valid['left'], X_valid['right']], Y_valid),
)
绘制训练过程中的正确率曲线和损失函数曲线
def _plot_history(keys, title, ylabel, legend_loc):
    """Plot the training and validation curves stored in ``history``.

    Args:
        keys: Names of the ``history.history`` entries to draw, training
            curve first and validation curve second.
        title: Figure title.
        ylabel: Label of the y axis.
        legend_loc: Matplotlib legend location string.
    """
    for key in keys:
        plt.plot(history.history[key])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc=legend_loc)
    plt.show()


# NOTE(review): newer Keras releases record the metric as 'accuracy' /
# 'val_accuracy' rather than 'acc' / 'val_acc' - confirm against the
# installed version.
_plot_history(('acc', 'val_acc'), 'Model Accuracy', 'Accuracy', 'upper left')
_plot_history(('loss', 'val_loss'), 'Model Loss', 'Loss', 'upper right')
训练集损失不断降低, 但验证集损失趋于平缓, 说明模型泛化能力还不够
训练集正确率提升到了 86% 以上, 而验证集正确率维持在 80% 左右, 模型有待进一步改进
保存模型, 以便后续使用
# Persist the trained network together with the word/id lookup tables that
# inference needs to encode new questions.
malstm.save('malstm.h5')
lookup_tables = {'word2id': word2id, 'id2word': id2word}
with open('data.pkl', 'wb') as pickle_file:
    pickle.dump(lookup_tables, pickle_file)
在单机上使用训练好的模型做个简单测试, 从训练集中随机拿出一些样本, 观察模型分类的结果是否和标签一致, 主要是熟悉下如何应用模型进行推断
- # -*- coding: utf-8 -*-
- from keras.preprocessing.sequence import pad_sequences
- from keras.models import Model, load_model
- import pandas as pd
- import numpy as np
- from nltk.corpus import stopwords
- import re
- import pickle
# Restore the word/id lookup tables written by the training script.
# pickle must only be used on trusted files such as this self-produced one.
with open('data.pkl', 'rb') as pickle_file:
    data = pickle.load(pickle_file)
word2id, id2word = data['word2id'], data['id2word']

# The demo below samples question pairs from the labelled training data.
train_df = pd.read_csv('train.CSV')
english_stopwords = stopwords.words('english')
stops = set(english_stopwords)
# Ordered (compiled pattern, replacement) rewrite rules used by preprocess().
# Order matters: contractions are expanded before the bare apostrophe is
# dropped, and "-" is padded with spaces before "e - mail" is collapsed.
_PREPROCESS_RULES = [(re.compile(pattern), replacement) for pattern, replacement in (
    # Drop every character outside the kept alphabet. The "-" is escaped so
    # that "+-=" is not read as a character range (which silently kept ";"
    # and "<"); ":" is listed explicitly because a later rule handles it.
    (r"[^A-Za-z0-9^,!.\/'+\-=:]", " "),
    # Expand contractions.
    (r"what's", "what is "),
    (r"'s", " is "),
    (r"'ve", " have "),
    (r"can't", "cannot "),
    (r"n't", " not "),
    (r"\bi'm", "i am "),
    (r"'re", " are "),
    (r"'d", " would "),
    (r"'ll", " will "),
    # Commas, full stops and slashes become separators.
    (r",", " "),
    (r"\.", " "),
    # Kept symbols are padded so they become stand-alone tokens.
    (r"!", " ! "),
    (r"/", " "),
    (r"\^", " ^ "),
    (r"\+", " + "),
    (r"-", " - "),
    (r"=", " = "),
    (r"'", " "),
    # "30k" -> "30000"; the word boundary keeps "5kg" intact.
    (r"\b(\d+)k\b", r"\g<1>000"),
    (r":", " : "),
    # Re-join abbreviations that the punctuation removal split apart.
    # Word boundaries stop these from firing inside ordinary words
    # (e.g. "you see" must not become "yoamericanee").
    (r"\be g\b", "eg"),
    (r"\bb g\b", "bg"),
    (r"\bu s\b", "american"),
    # "1990s" -> "1990". The former pattern "\0s" was a NUL escape and
    # could never match.
    (r"0s\b", "0"),
    (r"\b9 11\b", "911"),
    (r"\be - mail\b", "email"),
    (r"\bj k\b", "jk"),
)]


def preprocess(text):
    """Normalise a raw question and split it into lower-case tokens.

    The cleaning must mirror what was applied when the model was trained so
    that the resulting tokens map onto the ids stored in ``word2id``.

    The value is converted with ``str()``, lower-cased, rewritten by the
    ordered rules in ``_PREPROCESS_RULES`` and finally split on whitespace.

    Args:
        text: The raw question; any object accepted by ``str()``.

    Returns:
        A list of token strings; empty when nothing survives the cleaning.

    Example:
        >>> preprocess('Hello are you ok?')
        ['hello', 'are', 'you', 'ok']
    """
    text = str(text).lower()
    for pattern, replacement in _PREPROCESS_RULES:
        text = pattern.sub(replacement, text)
    # split() without arguments already collapses runs of whitespace.
    return text.split()
# Load the trained model and spot-check it on a few random training pairs.
malstm = load_model('malstm.h5')

num_cases = 5  # number of random pairs to check; also the accuracy divisor
correct = 0
for case_index in range(num_cases):
    print('Testing Case:', case_index + 1)
    random_sample = dict(train_df.iloc[np.random.randint(len(train_df))])
    left = random_sample['question1']
    right = random_sample['question2']
    print('Origin Questions...')
    print('==', left)
    print('==', right)

    left = preprocess(left)
    right = preprocess(right)
    print('Preprocessing...')
    print('==', left)
    print('==', right)

    # Words that were never given an id during training are dropped.
    left = [word2id[w] for w in left if w in word2id]
    right = [word2id[w] for w in right if w in word2id]
    print('To ids...')
    print('==', left, [id2word[word_id] for word_id in left])
    print('==', right, [id2word[word_id] for word_id in right])

    # Add a batch dimension, then pad both sides to a common length.
    left = np.expand_dims(left, 0)
    right = np.expand_dims(right, 0)
    # At least one time step is kept so that a pair whose words were all
    # dropped does not feed a zero-length sequence into the LSTM.
    maxlen = max(left.shape[-1], right.shape[-1], 1)
    left = pad_sequences(left, maxlen=maxlen)
    right = pad_sequences(right, maxlen=maxlen)
    print('Padding...')
    print('==', left.shape)
    print('==', right.shape)

    # The model outputs a similarity in (0, 1]; threshold it at 0.5.
    pred = malstm.predict([left, right])
    pred = 1 if pred[0][0] > 0.5 else 0
    print('True:', random_sample['is_duplicate'])
    print('Pred:', pred)
    if pred == random_sample['is_duplicate']:
        correct += 1

print(correct / num_cases)
参考
- How to predict Quora Question Pairs using Siamese Manhattan LSTM (博客文章)
- Siamese Recurrent Architectures for Learning Sentence Similarity (Mueller & Thyagarajan, AAAI 2016)
视频讲解课程
深度有趣 (一)
来源: https://juejin.im/post/5ba4ddca6fb9a05d011cdcd5