深入浅出Embedding:原理解析与应用实践
上QQ阅读APP看书,第一时间看更新

2.2 使用TensorFlow 2.0的Embedding Layer

PyTorch平台对词嵌入有对应的网络层,同样TensorFlow平台也有对应的网络层。下面详细介绍TensorFlow中的Embedding Layer的使用。

2.2.1 语法格式

TensorFlow 2.0中Embedding Layer的语法格式如下:

tf.keras.layers.Embedding(
input_dim, output_dim, embeddings_initializer='uniform',
embeddings_regularizer=None, activity_regularizer=None,
embeddings_constraint=None, mask_zero=False, input_length=None, **kwargs
)

1)主要参数说明如下:

  • input_dim, int>0。词汇表大小,即共有多少个不相同的词,对应PyTorch的num_embeddings参数。
  • output_dim, int > 0。词向量的维度,对应PyTorch的embedding_dim参数。
  • embeddings_initializer。Embeddings矩阵(即查询表)的初始化方法。
  • mask_zero。如果mask_zero设置为True,则填充值为0,此时在词汇表中就不能使用索引0了。
  • input_length。输入序列的长度,如果需要连接Flatten层再连接Dense层,这个参数是必须要有的,否则将报错。

2)输入说明如下:

输入一般为2维张量,其形状为(batch_size, input_length)。

3)输出说明如下:

输出一般为3维张量,其形状为(batch_size, input_length, output_dim)。

2.2.2 简单实例

下面我们来看一个简单实例,具体步骤如下。

1)定义一个语料,代码如下:

corpus=[
    ["The", "weather", "will", "be", "nice", "tomorrow"],
    ["How", "are", "you", "doing", "today"],
    ["Hello", "world", "!"]
]

2)导入需要的模块。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding
import numpy as np

3)生成一个字典。

#获取语料不同单词,并过滤"!"
word_set=set([i for item in corpus for i in item if i!='!'])
word_dicts={}
#索引从1开始,0用来填充
j=1
for i in word_set:
    word_dicts[i]=j
    j=j+1

4)用索引表示语料。

raw_inputs=[]
for i in range(len(corpus)):
    raw_inputs.append([word_dicts[j]  for j in corpus[i] if j!="!"])
padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs,
    padding='post')
print(padded_inputs)

5)构建网络。

model = Sequential()
model.add(Embedding(20, 4, input_length=6,mask_zero=True))
model.compile('rmsprop', 'mse')
output_array = model.predict(padded_inputs)
output_array.shape

6)查看运行结果。

output_array[1]
array([[ 0.03433469,  0.0206447 , -0.03389787, -0.00570253],
       [ 0.00114531,  0.03147959, -0.02087148, -0.00851966],
       [-0.01190972, -0.02093003,  0.02987151, -0.04057767],
       [ 0.01103591, -0.01805868, -0.00409973,  0.01246386],
       [ 0.02508983,  0.04906926, -0.02865715, -0.00525292],
       [ 0.03823281,  0.01339761,  0.01344738, -0.03699453]],
dtype=float32)

更多使用方法可参考TensorFlow官网(https://www.tensorflow.org/tutorials/text/word_embeddings)。