Skip to content

Commit 0b401e2

Browse files
author
yunfan
committed
[fix] star-transformer position embedding
1 parent 568dc66 commit 0b401e2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

fastNLP/modules/encoder/star_transformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def norm_func(f, x):
6969
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)
7070

7171
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1
72-
if self.pos_emb and False:
72+
if self.pos_emb:
7373
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
7474
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1
7575
embs = embs + P

0 commit comments

Comments
 (0)