-
Notifications
You must be signed in to change notification settings - Fork 3
/
bert_loader.py
33 lines (25 loc) · 746 Bytes
/
bert_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from torch import nn
from torch.autograd import Variable
from holder import *
from util import *
# The BERT loader.
# It computes nothing from its forward() inputs (they are ignored); the
# encodings are read from cached, pre-computed embeddings in shared.res_map.
class BertLoader(torch.nn.Module):
    """Serve pre-computed BERT encodings as if they were a model layer.

    No encoder runs here: ``forward`` ignores its arguments and returns the
    tensor cached under ``shared.res_map['bert_concated']``, optionally moved
    to a GPU and cast to half precision.

    Args:
        opt: options object; reads ``opt.gpuid`` (``-1`` keeps the tensor
            where it is, any other value moves it to that CUDA device) and
            ``opt.fp16`` (``1`` casts the result to float16).
        shared: shared-state object exposing a ``res_map`` mapping that holds
            the cached encoding under the key ``'bert_concated'``.
    """

    def __init__(self, opt, shared):
        super(BertLoader, self).__init__()
        self.opt = opt
        self.shared = shared

    def forward(self, concated, char_concated, bert_pack):
        """Return the cached BERT encoding.

        ``concated``, ``char_concated`` and ``bert_pack`` are unused;
        presumably they are accepted so callers can invoke this module the
        same way as other encoders (confirm against callers).

        Returns:
            The cached tensor, detached from any autograd graph, on
            ``cuda:<opt.gpuid>`` when ``opt.gpuid != -1``, and in float16
            when ``opt.fp16 == 1``.

        Raises:
            KeyError: if ``'bert_concated'`` is not in ``shared.res_map``.
        """
        # ``.detach()`` replaces the deprecated
        # ``Variable(x, requires_grad=False)``: the cached embeddings are fixed
        # inputs, so no gradient may flow back into them. The result shares
        # storage with the cached tensor (no copy is made).
        bert_enc = self.shared.res_map['bert_concated'].detach()
        if self.opt.gpuid != -1:
            bert_enc = bert_enc.cuda(self.opt.gpuid)
        if self.opt.fp16 == 1:
            bert_enc = bert_enc.half()
        return bert_enc

    def begin_pass(self):
        """Pass-start hook; this loader keeps no per-pass state."""
        pass

    def end_pass(self):
        """Pass-end hook; this loader keeps no per-pass state."""
        pass