【Deep Learning Notes】Analyzing BERT's Parameters

In this article we take BERT apart and look closely at its structure and the number of parameters in each layer.

We use bert-base (hidden size 768) as the example:

Plain BERT:

The model summary is as follows (most of the repeated layers are omitted):

Model: "model"

__________________________________________________________________________________________________

Layer (type) Output Shape Param # Connected to

==================================================================================================

Input-Token (InputLayer) [(None, None)] 0

__________________________________________________________________________________________________

Input-Segment (InputLayer) [(None, None)] 0

__________________________________________________________________________________________________

Embedding-Token (Embedding) multiple 16226304 Input-Token[0][0]

MLM-Norm[0][0]

__________________________________________________________________________________________________

Embedding-Segment (Embedding) (None, None, 768) 1536 Input-Segment[0][0]

__________________________________________________________________________________________________

Embedding-Token-Segment (Add) (None, None, 768) 0 Embedding-Token[0][0]

Embedding-Segment[0][0]

__________________________________________________________________________________________________

Embedding-Position (PositionEmb (None, None, 768) 393216 Embedding-Token-Segment[0][0]

__________________________________________________________________________________________________

Embedding-Norm (LayerNormalizat (None, None, 768) 1536 Embedding-Position[0][0]

__________________________________________________________________________________________________

Embedding-Dropout (Dropout) (None, None, 768) 0 Embedding-Norm[0][0]

__________________________________________________________________________________________________

Transformer-0-MultiHeadSelfAtte (None, None, 768) 2362368 Embedding-Dropout[0][0]

Embedding-Dropout[0][0]

Embedding-Dropout[0][0]

__________________________________________________________________________________________________

Transformer-0-MultiHeadSelfAtte (None, None, 768) 0 Transformer-0-MultiHeadSelfAttent

__________________________________________________________________________________________________

Transformer-0-MultiHeadSelfAtte (None, None, 768) 0 Embedding-Dropout[0][0]

Transformer-0-MultiHeadSelfAttent

__________________________________________________________________________________________________

Transformer-0-MultiHeadSelfAtte (None, None, 768) 1536 Transformer-0-MultiHeadSelfAttent

__________________________________________________________________________________________________

Transformer-0-FeedForward (Feed (None, None, 768) 4722432 Transformer-0-MultiHeadSelfAttent

__________________________________________________________________________________________________

Transformer-0-FeedForward-Dropo (None, None, 768) 0 Transformer-0-FeedForward[0][0]

__________________________________________________________________________________________________

Transformer-0-FeedForward-Add ( (None, None, 768) 0 Transformer-0-MultiHeadSelfAttent

Transformer-0-FeedForward-Dropout

__________________________________________________________________________________________________

Transformer-0-FeedForward-Norm (None, None, 768) 1536 Transformer-0-FeedForward-Add[0][

__________________________________________________________________________________________________

Transformer-1-MultiHeadSelfAtte (None, None, 768) 2362368 Transformer-0-FeedForward-Norm[0]

Transformer-0-FeedForward-Norm[0]

Transformer-0-FeedForward-Norm[0]

__________________________________________________________________________________________________

Transformer-1-MultiHeadSelfAtte (None, None, 768) 0 Transformer-1-MultiHeadSelfAttent

__________________________________________________________________________________________________

Transformer-1-MultiHeadSelfAtte (None, None, 768) 0 Transformer-0-FeedForward-Norm[0]

Transformer-1-MultiHeadSelfAttent

__________________________________________________________________________________________________

Transformer-1-MultiHeadSelfAtte (None, None, 768) 1536 Transformer-1-MultiHeadSelfAttent

__________________________________________________________________________________________________

Transformer-1-FeedForward (Feed (None, None, 768) 4722432 Transformer-1-MultiHeadSelfAttent

__________________________________________________________________________________________________

Transformer-1-FeedForward-Dropo (None, None, 768) 0 Transformer-1-FeedForward[0][0]

__________________________________________________________________________________________________

Transformer-1-FeedForward-Add ( (None, None, 768) 0 Transformer-1-MultiHeadSelfAttent

Transformer-1-FeedForward-Dropout

__________________________________________________________________________________________________

Transformer-1-FeedForward-Norm (None, None, 768) 1536 Transformer-1-FeedForward-Add[0][

__________________________________________________________________________________________________

Below we briefly break down the parameters of each part:

First, the input:

In the embedding part, BERT uses three components: the token embedding, the token-type embedding (used to distinguish the two sentences), and the position embedding.

The token embedding (here taking a vocabulary size of 21128 as the example) is:

vocab size * embedding size = 21128 * 768 = 16226304.

__________________________________________________________________________________________________

Embedding-Token (Embedding) multiple 16226304 Input-Token[0][0]

MLM-Norm[0][0]

token type:

Sentences are marked with 0 and 1 (for example, to distinguish the two sentences in the NSP task):

2 * 768 = 1536.

__________________________________________________________________________________________________

Embedding-Segment (Embedding) (None, None, 768) 1536 Input-Segment[0][0]

position embedding:

max length * embedding size = 512 * 768 = 393216.

__________________________________________________________________________________________________

Embedding-Position (PositionEmb (None, None, 768) 393216 Embedding-Token-Segment[0][0]


BERT's embedding part also includes a Layer Normalization, which adds another 2 * 768 parameters (the scale γ and the shift β).

__________________________________________________________________________________________________

Embedding-Norm (LayerNormalizat (None, None, 768) 1536 Embedding-Position[0][0]
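As a sanity check, here is a minimal Python sketch (assuming the sizes quoted above: vocabulary 21128, hidden size 768, 2 segment types, maximum position 512) that reproduces these embedding parameter counts:

# Sketch: reproduce the embedding parameter counts discussed above.
vocab_size, hidden_size, type_vocab_size, max_position = 21128, 768, 2, 512

token_emb    = vocab_size * hidden_size        # 21128 * 768 = 16226304
segment_emb  = type_vocab_size * hidden_size   # 2 * 768 = 1536
position_emb = max_position * hidden_size      # 512 * 768 = 393216
embedding_ln = 2 * hidden_size                 # gamma + beta = 1536

print(token_emb, segment_emb, position_emb, embedding_ln)
# 16226304 1536 393216 1536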

Having worked out the embedding parameters, we can move on to the Transformer parameters. For simplicity, only one layer is described here.

First, multi-head attention:

bert-base uses 12 attention heads with a per-head Q/K/V dimension of 64, and at the end an output matrix O is needed to combine the 12 heads.

The total parameter count is therefore: embedding size * head num * qkv size * 3 (the three projection matrices) + (head num * qkv size) * embedding size (the projection applied after concatenating the heads) + Q/K/V/O biases = 768*12*64*3 + 12*64*768 + 768*4 = 2362368

(The final 768*4 corresponds to the biases of the Q, K and V matrices plus the bias of the final O matrix.)

__________________________________________________________________________________________________

Transformer-0-MultiHeadSelfAtte (None, None, 768) 2362368 Embedding-Dropout[0][0]

Embedding-Dropout[0][0]

Embedding-Dropout[0][0]
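The same arithmetic as a minimal Python sketch (assuming 12 heads, a per-head size of 64, and hidden size 768):

num_heads, head_size, hidden_size = 12, 64, 768

qkv_weights = 3 * hidden_size * (num_heads * head_size)  # Q, K, V projection matrices
out_weights = (num_heads * head_size) * hidden_size      # output matrix O applied after concatenating the heads
biases      = 4 * hidden_size                            # one bias vector each for Q, K, V, O

print(qkv_weights + out_weights + biases)  # 2362368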

Next is the LayerNorm of the multi-head attention block: 2 * 768 = 1536.

__________________________________________________________________________________________________

Transformer-0-MultiHeadSelfAtte (None, None, 768) 1536 Transformer-0-MultiHeadSelfAttent

Next comes the feed-forward (fully connected) layer:

BERT uses the conventional intermediate size of 4 * input size, i.e. 4 * 768 = 3072.

The parameters of this part are therefore: embedding size * hidden size + bias + hidden size * embedding size + bias = 768*3072 + 3072 + 3072*768 + 768 = 4722432

__________________________________________________________________________________________________

Transformer-0-FeedForward (Feed (None, None, 768) 4722432 Transformer-0-MultiHeadSelfAttent
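Again as a minimal sketch (assuming the 768 -> 3072 -> 768 feed-forward shape described above):

hidden_size = 768
intermediate_size = 4 * hidden_size  # 3072

ffn_params = (hidden_size * intermediate_size + intermediate_size   # up-projection weights + bias
              + intermediate_size * hidden_size + hidden_size)      # down-projection weights + bias

print(ffn_params)  # 4722432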

Then the LayerNorm: 2 * 768 = 1536.

__________________________________________________________________________________________________

Transformer-0-FeedForward-Norm (None, None, 768) 1536 Transformer-0-FeedForward-Add[0][

After that comes the next Transformer layer, and so on.
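Putting the pieces together, a minimal sketch of the total parameter count of this model (using the figures derived above, 12 Transformer layers, and ignoring any pooler or MLM head):

embedding_params = 16226304 + 1536 + 393216 + 1536  # token + segment + position embeddings + LayerNorm
per_layer_params = 2362368 + 1536 + 4722432 + 1536  # attention + LayerNorm + feed-forward + LayerNorm

print(embedding_params + 12 * per_layer_params)     # 101677056, roughly 102M parameters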

BERT with Conditional Layer Normalization:

After switching to Conditional Layer Normalization, each LayerNormalization in BERT grows to 198144 parameters.

__________________________________________________________________________________________________

Embedding-Norm (LayerNormalizat (None, None, 768) 198144 Embedding-Position[0][0]

reshape[0][0]

Since β and γ themselves are unchanged and still account for 1536 parameters, let's work out where the extra 196608 parameters come from.

Because the same transformation is applied to both β and γ, each contributes the same number of parameters, so we only need to explain 98304 of them.

As mentioned earlier, the 128-dimensional condition vector c has to be projected up to 768 dimensions; ignoring the bias and using only a matrix transformation (a single dense layer without bias is exactly a matrix transformation), that is precisely 128 * 768 = 98304 parameters.

See Conditional Layer Normalization for details.
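A minimal sketch of this count (assuming, as described above, a 128-dimensional condition vector projected to 768 dimensions by two bias-free dense layers, one for γ and one for β):

hidden_size, cond_size = 768, 128

plain_ln_params  = 2 * hidden_size               # the original gamma and beta: 1536
cond_projections = 2 * cond_size * hidden_size   # two bias-free 128 -> 768 projections: 196608

print(plain_ln_params + cond_projections)        # 198144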

BERT's MLM task:

__________________________________________________________________________________________________

Transformer-11-FeedForward-Norm (None, None, 768) 1536 Transformer-11-FeedForward-Add[0]

__________________________________________________________________________________________________

MLM-Dense (Dense) (None, None, 768) 590592 Transformer-11-FeedForward-Norm[0

__________________________________________________________________________________________________

MLM-Norm (LayerNormalization) (None, None, 768) 1536 MLM-Dense[0][0]

__________________________________________________________________________________________________

MLM-Bias (BiasAdd) (None, None, 21128) 21128 Embedding-Token[1][0]

__________________________________________________________________________________________________

MLM-Activation (Activation) (None, None, 21128) 0 MLM-Bias[0][0]

__________________________________________________________________________________________________

cross_entropy (CrossEntropy) (None, None, 21128) 0 Input-Token[0][0]

MLM-Activation[0][0]

==================================================================================================

The main additions here are MLM-Dense, MLM-Norm and MLM-Bias.

MLM-Dense has 768*768 + 768 = 590592 parameters.

The other two are straightforward: MLM-Norm is an ordinary LayerNorm (1536 parameters) and MLM-Bias adds one output bias per vocabulary entry (21128 parameters).
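A minimal sketch of the MLM-head counts (assuming hidden size 768 and vocabulary 21128; the output projection itself reuses the Embedding-Token weights, as the summary's Connected to column shows, so only a per-token bias is added):

hidden_size, vocab_size = 768, 21128

mlm_dense = hidden_size * hidden_size + hidden_size  # 768*768 + 768 = 590592
mlm_norm  = 2 * hidden_size                          # gamma + beta = 1536
mlm_bias  = vocab_size                               # one output bias per vocabulary entry = 21128

print(mlm_dense, mlm_norm, mlm_bias)  # 590592 1536 21128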

See also: how BERT's MLM task is implemented.
