Commit 33ad565

Merge pull request #1350 from calmdown539/dev-postgresql
Add the implementations for the transformer decoder layer
2 parents 5ed4ccc + ed71da2 commit 33ad565

File tree

1 file changed: +118 -0 lines changed
  • examples/singa_peft/examples/model


examples/singa_peft/examples/model/trans.py

Lines changed: 118 additions & 0 deletions
@@ -100,3 +100,121 @@ def train_one_batch(self, enc_inputs, dec_inputs, dec_outputs, pad):

    def set_optimizer(self, opt):
        self.opt = opt


class TransformerDecoder(layer.Layer):
    """TransformerDecoder is a stack of N decoder layers.

    Args:
        tgt_n_token: the size of the target vocabulary
        d_model: the number of expected features in the decoder inputs (default=512)
        n_head: the number of heads in the multi-head attention models (default=8)
        dim_feedforward: the dimension of the feedforward network model (default=2048)
        n_layers: the number of sub-decoder-layers in the decoder (default=6)
    """

    def __init__(self, tgt_n_token, d_model=512, n_head=8, dim_feedforward=2048, n_layers=6):
        super(TransformerDecoder, self).__init__()
        self.tgt_n_token = tgt_n_token
        self.d_model = d_model
        self.n_head = n_head
        self.dim_feedforward = dim_feedforward
        self.n_layers = n_layers

        # target word embedding / positional embedding / stack of n_layers decoder layers
        self.target_emb = layer.Embedding(input_dim=tgt_n_token, output_dim=d_model)
        self.target_pos_emb = layer.Embedding(input_dim=tgt_n_token, output_dim=d_model)
        self.layers = []
        for _ in range(n_layers):
            self.layers.append(
                TransformerDecoderLayer(d_model=d_model, n_head=n_head, dim_feedforward=dim_feedforward))

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        """
        Args:
            dec_inputs: [batch_size, tgt_len]
            enc_inputs: [batch_size, src_len]
            enc_outputs: [batch_size, src_len, d_model]
        """
        # target word embedding: [batch_size, tgt_len, d_model]
        tgt_word_emb = self.target_emb(dec_inputs)
        # positional embedding backed by a frozen sinusoid table
        self.target_pos_emb.initialize(dec_inputs)
        self.target_pos_emb.from_pretrained(
            W=TransformerDecoder._get_sinusoid_encoding_table(self.tgt_n_token, self.d_model),
            freeze=True)
        # [batch_size, tgt_len, d_model]
        tgt_pos_emb = self.target_pos_emb(dec_inputs)
        # [batch_size, tgt_len, d_model]
        dec_outputs = autograd.add(tgt_word_emb, tgt_pos_emb)

        # dec_self_attn_pad_mask: [batch_size, tgt_len, tgt_len]
        dec_self_attn_pad_mask = TransformerDecoder._get_attn_pad_mask(dec_inputs, dec_inputs)
        # dec_self_attn_subsequent_mask: [batch_size, tgt_len, tgt_len]
        dec_self_attn_subsequent_mask = TransformerDecoder._get_attn_subsequence_mask(dec_inputs)

        # dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_self_attn_mask = tensor.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)

        # dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        dec_enc_attn_mask = TransformerDecoder._get_attn_pad_mask(dec_inputs, enc_inputs)

        dec_self_attns, dec_enc_attns = [], []

        for dec_layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model]
            # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
            # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = dec_layer(dec_outputs, enc_outputs,
                                                                 dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

    @staticmethod
    def _get_attn_pad_mask(seq_q, seq_k):
        """
        Args:
            seq_q: [batch_size, seq_len]
            seq_k: [batch_size, seq_len]
        Returns:
            [batch_size, seq_len, seq_len]
        """
        batch_size, len_q = seq_q.shape
        batch_size, len_k = seq_k.shape
        seq_k_np = tensor.to_numpy(seq_k)
        # mark key positions that hold the padding token (id 0) with 1
        pad_attn_mask_np = np.where(seq_k_np == 0, 1, 0)
        pad_attn_mask_np = pad_attn_mask_np.astype(np.int32)
        pad_attn_mask_np = np.expand_dims(pad_attn_mask_np, axis=1)
        pad_attn_mask_np = np.broadcast_to(pad_attn_mask_np, (batch_size, len_q, len_k))
        pad_attn_mask = tensor.from_numpy(pad_attn_mask_np)
        return pad_attn_mask

    @staticmethod
    def _get_attn_subsequence_mask(seq):
        """
        Args:
            seq: [batch_size, tgt_len]

        Returns:
            [batch_size, tgt_len, tgt_len]
        """
        attn_shape = [seq.shape[0], seq.shape[1], seq.shape[1]]

        # generate the upper triangular matrix, [batch_size, tgt_len, tgt_len]
        subsequence_mask = np.triu(np.ones(attn_shape), k=1)
        subsequence_mask = subsequence_mask.astype(np.int32)
        subsequence_mask = tensor.from_numpy(subsequence_mask)
        return subsequence_mask

    @staticmethod
    def _get_sinusoid_encoding_table(n_position, d_model):
        def cal_angle(position, hid_idx):
            return position / np.power(10000, 2 * (hid_idx // 2) / d_model)

        def get_posi_angle_vec(position):
            return [cal_angle(position, hid_j) for hid_j in range(d_model)]

        sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)], np.float32)
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even dimensions use sine
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # odd dimensions use cosine
        return tensor.Tensor(data=sinusoid_table, requires_grad=False)
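The subtlest part of the new code is how forward() combines the padding mask with the causal (subsequent-position) mask before handing dec_outputs to each decoder layer. The numpy-only sketch below reproduces that logic on a hypothetical toy batch, with token id 0 taken as padding exactly as _get_attn_pad_mask assumes; it is an illustration, not part of the committed file.

# Toy reproduction of the decoder-mask construction in forward(); illustration only.
import numpy as np

dec_inputs = np.array([[5, 7, 2, 0, 0]])   # [batch_size=1, tgt_len=5]; the last two positions are padding
batch_size, tgt_len = dec_inputs.shape

# padding mask: 1 wherever the key position is a pad token -> [batch_size, tgt_len, tgt_len]
pad_mask = np.broadcast_to((dec_inputs == 0).astype(np.int32)[:, None, :],
                           (batch_size, tgt_len, tgt_len))
# subsequent (causal) mask: 1 strictly above the diagonal blocks attention to future positions
subsequent_mask = np.triu(np.ones((batch_size, tgt_len, tgt_len), dtype=np.int32), k=1)
# combined self-attention mask, the numpy equivalent of tensor.gt(pad + subsequent, 0)
dec_self_attn_mask = ((pad_mask + subsequent_mask) > 0).astype(np.int32)

print(dec_self_attn_mask[0])
# [[0 1 1 1 1]
#  [0 0 1 1 1]
#  [0 0 0 1 1]
#  [0 0 0 1 1]
#  [0 0 0 1 1]]

A 1 marks a key position the query row must not attend to, which is what each TransformerDecoderLayer receives as dec_self_attn_mask.

Similarly, a shrunken numpy-only version of _get_sinusoid_encoding_table (hypothetical sizes n_position=4 and d_model=6, chosen only so the table is easy to inspect) shows the even/odd sine/cosine interleaving the positional embedding relies on:

# Vectorized reproduction of _get_sinusoid_encoding_table at a toy size; illustration only.
import numpy as np

n_position, d_model = 4, 6
pos = np.arange(n_position)[:, None]                      # [n_position, 1]
idx = np.arange(d_model)[None, :]                         # [1, d_model]
angles = pos / np.power(10000, 2 * (idx // 2) / d_model)  # same angle formula as cal_angle
# columns 0, 2, 4 hold sin(angle); columns 1, 3, 5 hold cos(angle),
# matching the 0::2 / 1::2 assignments in the commit
table = np.where(idx % 2 == 0, np.sin(angles), np.cos(angles)).astype(np.float32)
print(table.round(3))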
