@@ -100,3 +100,121 @@ def train_one_batch(self, enc_inputs, dec_inputs, dec_outputs, pad):

    def set_optimizer(self, opt):
        self.opt = opt


class TransformerDecoder(layer.Layer):
106+ """TransformerDecoder is a stack of N decoder layers
107+ Args:
108+ tgt_n_token: the size of target vocab
109+ d_model: the number of expected features in the decoder inputs (default=512).
110+ n_head: the number of heads in the multi head attention models (default=8).
111+ dim_feedforward: the dimension of the feedforward network model (default=2048).
112+ n_layers: the number of sub-decoder-layers in the decoder (default=6).
113+ """

    def __init__(self, tgt_n_token, d_model=512, n_head=8, dim_feedforward=2048, n_layers=6):
        super(TransformerDecoder, self).__init__()
        self.tgt_n_token = tgt_n_token
        self.d_model = d_model
        self.n_head = n_head
        self.dim_feedforward = dim_feedforward
        self.n_layers = n_layers

        # target_emb / pos_emb / n-layers
        self.target_emb = layer.Embedding(input_dim=tgt_n_token, output_dim=d_model)
        self.target_pos_emb = layer.Embedding(input_dim=tgt_n_token, output_dim=d_model)
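        # target_pos_emb is loaded with the frozen sinusoidal table in forward(),
        # so its weights act as fixed positional encodings rather than trained parameters.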
        self.layers = []
        for _ in range(n_layers):
            self.layers.append(
                TransformerDecoderLayer(d_model=d_model, n_head=n_head, dim_feedforward=dim_feedforward))

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        """
        Args:
            dec_inputs: [batch_size, tgt_len]
            enc_inputs: [batch_size, src_len]
            enc_outputs: [batch_size, src_len, d_model]
        Returns:
            dec_outputs: [batch_size, tgt_len, d_model]
            dec_self_attns: list of [batch_size, n_heads, tgt_len, tgt_len]
            dec_enc_attns: list of [batch_size, n_heads, tgt_len, src_len]
        """

        # [batch_size, tgt_len, d_model]
        tgt_word_emb = self.target_emb(dec_inputs)
        self.target_pos_emb.initialize(dec_inputs)
        self.target_pos_emb.from_pretrained(
            W=TransformerDecoder._get_sinusoid_encoding_table(self.tgt_n_token, self.d_model),
            freeze=True)
        # [batch_size, tgt_len, d_model]
        tgt_pos_emb = self.target_pos_emb(dec_inputs)
        # [batch_size, tgt_len, d_model]
        dec_outputs = autograd.add(tgt_word_emb, tgt_pos_emb)

        # dec_self_attn_pad_mask: [batch_size, tgt_len, tgt_len]
        dec_self_attn_pad_mask = TransformerDecoder._get_attn_pad_mask(dec_inputs, dec_inputs)
        # [batch_size, tgt_len, tgt_len]
        dec_self_attn_subsequent_mask = TransformerDecoder._get_attn_subsequence_mask(dec_inputs)

        # dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_self_attn_mask = tensor.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)

        # dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        dec_enc_attn_mask = TransformerDecoder._get_attn_pad_mask(dec_inputs, enc_inputs)

        dec_self_attns, dec_enc_attns = [], []

        for dec_layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model],
            # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len],
            # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = dec_layer(dec_outputs, enc_outputs,
                                                                 dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns
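
    # Example shape trace for forward(), using illustrative sizes batch_size=1,
    # tgt_len=4, src_len=3, d_model=64, n_head=4:
    #   tgt_word_emb, tgt_pos_emb, dec_outputs : [1, 4, 64]
    #   dec_self_attn_mask                     : [1, 4, 4]
    #   dec_enc_attn_mask                      : [1, 4, 3]
    #   each entry of dec_self_attns           : [1, 4, 4, 4]
    #   each entry of dec_enc_attns            : [1, 4, 4, 3]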

    @staticmethod
    def _get_attn_pad_mask(seq_q, seq_k):
        """
        Args:
            seq_q: [batch_size, len_q]
            seq_k: [batch_size, len_k]
        Returns:
            [batch_size, len_q, len_k], 1 where the key position is padding (token id 0)
        """

        batch_size, len_q = seq_q.shape
        batch_size, len_k = seq_k.shape
        seq_k_np = tensor.to_numpy(seq_k)
        # mark key positions whose token id is 0 (the padding id) with 1
        pad_attn_mask_np = np.where(seq_k_np == 0, 1, 0)
        pad_attn_mask_np = pad_attn_mask_np.astype(np.int32)
        pad_attn_mask_np = np.expand_dims(pad_attn_mask_np, axis=1)
        pad_attn_mask_np = np.broadcast_to(pad_attn_mask_np, (batch_size, len_q, len_k))
        return tensor.from_numpy(pad_attn_mask_np)
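
    # Toy example with made-up values: for seq_k = [[4, 6, 0]] (0 being the padding id)
    # and len_q = 3, the mask returned for that batch element is
    #   [[0, 0, 1],
    #    [0, 0, 1],
    #    [0, 0, 1]]
    # i.e. every query position is told to ignore the padded third key position.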

    @staticmethod
    def _get_attn_subsequence_mask(seq):
        """
        Args:
            seq: [batch_size, tgt_len]

        Returns:
            [batch_size, tgt_len, tgt_len], 1 where a query position would attend to a
            later (future) position
        """
        attn_shape = [seq.shape[0], seq.shape[1], seq.shape[1]]

        # generate the upper triangular matrix, [batch_size, tgt_len, tgt_len]
        subsequence_mask = np.triu(np.ones(attn_shape), k=1)
        subsequence_mask = subsequence_mask.astype(np.int32)
        subsequence_mask = tensor.from_numpy(subsequence_mask)
        return subsequence_mask
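
    # Toy example: for tgt_len = 3 the per-batch subsequence mask is
    #   [[0, 1, 1],
    #    [0, 0, 1],
    #    [0, 0, 0]]
    # forward() adds this to the padding mask and applies tensor.gt(..., 0), so a
    # position is masked if it is either padding or lies in the future.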
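    # The table below is the standard sinusoidal positional encoding from
    # "Attention Is All You Need":
    #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    # Even columns hold the sine terms, odd columns the cosine terms, and the
    # resulting tensor is created with requires_grad=False so it is never trained.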
    @staticmethod
    def _get_sinusoid_encoding_table(n_position, d_model):
        def cal_angle(position, hid_idx):
            return position / np.power(10000, 2 * (hid_idx // 2) / d_model)

        def get_posi_angle_vec(position):
            return [cal_angle(position, hid_j) for hid_j in range(d_model)]

        sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)], np.float32)
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even dimensions use sine
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # odd dimensions use cosine
        return tensor.Tensor(data=sinusoid_table, requires_grad=False)
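

# Illustrative usage sketch, assuming the imports at the top of this file
# (numpy as np, SINGA's tensor / layer / autograd) and the TransformerDecoderLayer
# defined elsewhere in this repo; sizes and token ids below are made up.
#
#   decoder = TransformerDecoder(tgt_n_token=32, d_model=64, n_head=4,
#                                dim_feedforward=256, n_layers=2)
#   dec_inputs = tensor.from_numpy(np.array([[1, 5, 7, 0]], dtype=np.int32))  # [1, tgt_len=4], 0 = pad
#   enc_inputs = tensor.from_numpy(np.array([[2, 3, 0]], dtype=np.int32))     # [1, src_len=3], 0 = pad
#   enc_outputs = ...  # [1, 3, 64], produced by the matching encoder
#   dec_out, self_attns, enc_attns = decoder(dec_inputs, enc_inputs, enc_outputs)
#   # dec_out: [1, 4, 64]; self_attns / enc_attns: one attention map per layer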