import json

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm

class Seq2SeqInputExample(object):
    """A single raw (passage, answer, question) example."""

    def __init__(self, passage, answer, question):
        super(Seq2SeqInputExample, self).__init__()
        self.passage = passage
        self.answer = answer
        self.question = question

class Seq2SeqInputFeature(object):
    """A single tokenized example, ready to be batched."""

    def __init__(self, input_ids, token_type_ids, position_ids, attention_mask, answer_tag_ids=None):
        super().__init__()
        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.position_ids = position_ids
        self.attention_mask = attention_mask
        self.answer_tag_ids = answer_tag_ids

class Seq2SeqDataset(Dataset):
    """Defines how items are fetched for this particular dataset."""

    def __init__(self, features):
        # __init__ typically loads all of the data up front.
        super(Seq2SeqDataset, self).__init__()
        self.features = features

    def __getitem__(self, i):
        # Return a single example.
        return self.features[i]

    def __len__(self):
        return len(self.features)

def read_file(input_dir):
    """Load a JSON file and wrap each record in a Seq2SeqInputExample."""
    with open(input_dir, 'r', encoding='utf-8') as f:
        data = json.load(f)
    examples = []
    for part in data:
        examples.append(
            Seq2SeqInputExample(
                passage=part['passage'],
                answer=part['answer'],
                question=part['question'],
            )
        )
    return examples

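# For reference, read_file expects the input JSON to be an array of objects
# with 'passage', 'answer', and 'question' keys. An illustrative record
# (sample data only, not from the project's dataset):
#
# [
#     {
#         "passage": "Einstein was born in Ulm in 1879.",
#         "answer": "1879",
#         "question": "When was Einstein born?"
#     }
# ]
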
def seq2seq_convert_example_to_feature(examples, tokenizer, max_length, is_test=False):
    """Tokenize raw examples into Seq2SeqInputFeature objects."""
    features = []
    print("max length = ", max_length)
    print('example nums = ', len(examples))
    for example in tqdm(examples, desc='convert to features', leave=True):
        # At test time the question is withheld so the model must generate it.
        if is_test:
            input_ids, token_type_ids, answer_tag_ids = \
                tokenizer.encode_plus(passage=example.passage,
                                      answer=example.answer,
                                      maxlen=max_length)
        else:
            input_ids, token_type_ids, answer_tag_ids = \
                tokenizer.encode_plus(passage=example.passage,
                                      answer=example.answer,
                                      question=example.question,
                                      maxlen=max_length)
        position_ids = list(range(len(input_ids)))
        attention_mask = [1] * len(input_ids)
        features.append(
            Seq2SeqInputFeature(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                answer_tag_ids=answer_tag_ids,
            )
        )
    return features

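# Note: encode_plus(passage=..., answer=..., question=..., maxlen=...) is not
# the Hugging Face signature, so this module assumes a project-specific
# tokenizer. The stub below is a sketch of the interface it relies on (the
# name and docstring are assumptions, not the real implementation):
class _TokenizerInterface:
    def encode_plus(self, passage, answer, question=None, maxlen=512):
        """Return (input_ids, token_type_ids, answer_tag_ids), three equal-length
        lists of ints no longer than maxlen; answer_tag_ids flags which tokens
        belong to the answer span."""
        raise NotImplementedError
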
def collate_batch(features):
    """Merge a list of Seq2SeqInputFeature into padded LongTensor batches."""
    features = merge_dict(features)
    max_len = max(len(f) for f in features['input_ids'])
    batch = {}
    for k, v in features.items():
        if v is not None and not isinstance(v, str):
            values = padding(k, v, max_len)
            batch[k] = torch.tensor(values, dtype=torch.long)
    return batch

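# Example: for two features whose input_ids are [1, 2, 3] and [1, 2],
# collate_batch returns {'input_ids': tensor([[1, 2, 3], [1, 2, 0]]), ...},
# with every tensor shaped (batch_size, max_len_in_batch).
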
def merge_dict(features):
    """Transpose a list of feature objects into a dict of per-field lists."""
    first = features[0]
    batch = {}
    for k, v in vars(first).items():
        if v is not None:
            values = [getattr(f, k) for f in features]
        else:
            values = None
        batch[k] = values
    return batch

def padding(feature_name, features, max_len):
    """Right-pad every sequence in features to max_len."""
    values = list()
    for f in features:
        pad_len = max_len - len(f)
        if feature_name != 'position_ids':
            pad_part = [0] * pad_len
        else:
            # position_ids are padded by repeating the last position index.
            pad_part = f[-1:] * pad_len
        values.append(f + pad_part)
    return values

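# Worked example:
#   padding('input_ids',    [[5, 6, 7], [5, 6]], 3) -> [[5, 6, 7], [5, 6, 0]]
#   padding('position_ids', [[0, 1, 2], [0, 1]], 3) -> [[0, 1, 2], [0, 1, 1]]
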
def get_dataloader(input_dir,
                   tokenizer,
                   max_length=512,
                   batch_size=1,
                   shuffle=True):
    train_dataset = Seq2SeqDataset(
        seq2seq_convert_example_to_feature(
            examples=read_file(input_dir),
            tokenizer=tokenizer,
            max_length=max_length,
        )
    )
    # Use sequential order when shuffling is disabled.
    train_sampler = RandomSampler(train_dataset) if shuffle \
        else SequentialSampler(train_dataset)
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        collate_fn=collate_batch,
    )

def get_test_dataloader(input_dir,
                        tokenizer,
                        max_length=400,
                        batch_size=1):
    test_dataset = Seq2SeqDataset(
        seq2seq_convert_example_to_feature(
            examples=read_file(input_dir),
            tokenizer=tokenizer,
            max_length=max_length,
            is_test=True,
        ),
    )
    sequential_sampler = SequentialSampler(test_dataset)
    return DataLoader(
        test_dataset,
        batch_size=batch_size,
        sampler=sequential_sampler,
        collate_fn=collate_batch,
    )
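

# Minimal usage sketch (the file paths and the tokenizer object are
# assumptions; substitute your own project-specific tokenizer and data files):
#
# tokenizer = ...  # must expose encode_plus() as described above
# train_loader = get_dataloader('data/train.json', tokenizer, batch_size=8)
# test_loader = get_test_dataloader('data/test.json', tokenizer)
# for batch in train_loader:
#     input_ids = batch['input_ids']            # (batch_size, max_len)
#     attention_mask = batch['attention_mask']  # 1 for real tokens, 0 for pad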