## Original code from https://mccormickml.com/2020/03/10/question-answering-with-a-fine-tuned-BERT/
## Modifications made by Deniz Askin
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
import textwrap
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
def answer_question(question, answer_text):
    '''
    Takes a `question` string and an `answer_text` string (which contains the
    answer), and identifies the words within the `answer_text` that are the
    answer. Prints them out.
    '''
    # ======== Tokenize ========
    # Apply the tokenizer to the input text, treating the two strings as a text pair.
    input_ids = tokenizer.encode(question, answer_text)
    # Report how long the input sequence is.
    ## print('Query has {:,} tokens.\n'.format(len(input_ids)))
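    # For a text pair, BERT's tokenizer produces the layout
    # [CLS] <question tokens> [SEP] <answer_text tokens> [SEP],
    # so `input_ids` holds both segments in a single sequence.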
    # ======== Set Segment IDs ========
    # Search the input_ids for the first instance of the `[SEP]` token.
    sep_index = input_ids.index(tokenizer.sep_token_id)
    # The number of segment A tokens includes the [SEP] token itself.
    num_seg_a = sep_index + 1
    # The remainder are segment B.
    num_seg_b = len(input_ids) - num_seg_a
    # Construct the list of 0s and 1s.
    segment_ids = [0]*num_seg_a + [1]*num_seg_b
    # There should be a segment_id for every input token.
    assert len(segment_ids) == len(input_ids)
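    # Illustrative (hypothetical) sizes: a 12-token question segment and a
    # 200-token context segment would give segment_ids == [0]*12 + [1]*200.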
    # ======== Evaluate ========
    # Run our example question through the model. Recent versions of
    # `transformers` return a model output object rather than a
    # (start_scores, end_scores) tuple, so read the logits by attribute.
    with torch.no_grad():
        outputs = model(torch.tensor([input_ids]),                   # The tokens representing our input text.
                        token_type_ids=torch.tensor([segment_ids]))  # The segment IDs differentiate question from answer_text.
    start_scores, end_scores = outputs.start_logits, outputs.end_logits
    # ======== Reconstruct Answer ========
    # Find the tokens with the highest `start` and `end` scores.
    answer_start = torch.argmax(start_scores)
    answer_end = torch.argmax(end_scores)
    # Get the string versions of the input tokens.
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # Start with the first token.
    answer = tokens[answer_start]
    # Select the remaining answer tokens and join them with whitespace.
    for i in range(answer_start + 1, answer_end + 1):
        # If it's a subword token, recombine it with the previous token.
        if tokens[i][0:2] == '##':
            answer += tokens[i][2:]
        # Otherwise, add a space and then the token.
        else:
            answer += ' ' + tokens[i]
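    # For example, a word such as "bidirectional" may be tokenized into
    # WordPieces like ['bid', '##ire', '##ction', '##al'] (the exact split
    # depends on the vocabulary); the loop above rejoins them without spaces.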
    print('Answer: "' + answer + '"')
# Wrap text to 80 characters.
wrapper = textwrap.TextWrapper(width=80)
# Enter the text to be queried
bert_abstract = "We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation models (Peters et al., 2018a; Radford et al., 2018), BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context in all layers. As a result, the pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering and language inference, without substantial task-specific architecture modifications. BERT is conceptually simple and empirically powerful. It obtains new state-of-the-art results on eleven natural language processing tasks, including pushing the GLUE score to 80.5% (7.7% point absolute improvement), MultiNLI accuracy to 86.7% (4.6% absolute improvement), SQuAD v1.1 question answering Test F1 to 93.2 (1.5 point absolute improvement) and SQuAD v2.0 Test F1 to 83.1 (5.1 point absolute improvement)."
print("Text:",wrapper.fill(bert_abstract)+"\n")
# Enter the question
question = "What does the 'B' in BERT stand for?"
print(question + "\n")
answer_question(question, bert_abstract)
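
## For comparison, a minimal sketch of the same lookup using the higher-level
## `pipeline` API from `transformers`, reusing the model and tokenizer that
## were already loaded above.
from transformers import pipeline
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
result = qa_pipeline(question=question, context=bert_abstract)
print('Pipeline answer: "{}" (score {:.3f})'.format(result["answer"], result["score"]))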