上QQ阅读APP看书,第一时间看更新
2.2 使用TensorFlow 2.0的Embedding Layer
PyTorch平台对词嵌入有对应的网络层,同样TensorFlow平台也有对应的网络层。下面详细介绍TensorFlow中的Embedding Layer的使用。
2.2.1 语法格式
TensorFlow 2.0中Embedding Layer的语法格式如下:
tf.keras.layers.Embedding( input_dim, output_dim, embeddings_initializer='uniform', embeddings_regularizer=None, activity_regularizer=None, embeddings_constraint=None, mask_zero=False, input_length=None, **kwargs )
1)主要参数说明如下:
- input_dim, int>0。词汇表大小,即共有多少个不相同的词,对应PyTorch的num_embeddings参数。
- output_dim, int > 0。词向量的维度,对应PyTorch的embedding_dim参数。
- embeddings_initializer。Embeddings矩阵(即查询表)的初始化方法。
- mask_zero。如果mask_zero设置为True,则填充值为0,此时在词汇表中就不能使用索引0了。
- input_length。输入序列的长度,如果需要连接Flatten层再连接Dense层,这个参数是必须要有的,否则将报错。
2)输入说明如下:
输入一般为2维张量,其形状为(batch_size, input_length)。
3)输出说明如下:
输出一般为3维张量,其形状为(batch_size, input_length, output_dim)。
2.2.2 简单实例
下面我们来看一个简单实例,具体步骤如下。
1)定义一个语料,代码如下:
corpus=[ ["The", "weather", "will", "be", "nice", "tomorrow"], ["How", "are", "you", "doing", "today"], ["Hello", "world", "!"] ]
2)导入需要的模块。
import tensorflow as tf from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Embedding import numpy as np
3)生成一个字典。
#获取语料不同单词,并过滤"!" word_set=set([i for item in corpus for i in item if i!='!']) word_dicts={} #索引从1开始,0用来填充 j=1 for i in word_set: word_dicts[i]=j j=j+1
4)用索引表示语料。
raw_inputs=[] for i in range(len(corpus)): raw_inputs.append([word_dicts[j] for j in corpus[i] if j!="!"]) padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs, padding='post') print(padded_inputs)
5)构建网络。
model = Sequential() model.add(Embedding(20, 4, input_length=6,mask_zero=True)) model.compile('rmsprop', 'mse') output_array = model.predict(padded_inputs) output_array.shape
6)查看运行结果。
output_array[1] array([[ 0.03433469, 0.0206447 , -0.03389787, -0.00570253], [ 0.00114531, 0.03147959, -0.02087148, -0.00851966], [-0.01190972, -0.02093003, 0.02987151, -0.04057767], [ 0.01103591, -0.01805868, -0.00409973, 0.01246386], [ 0.02508983, 0.04906926, -0.02865715, -0.00525292], [ 0.03823281, 0.01339761, 0.01344738, -0.03699453]], dtype=float32)
更多使用方法可参考TensorFlow官网(https://www.tensorflow.org/tutorials/text/word_embeddings)。