
unable to run GTS on a custom dataset #8

Open
indranilByjus opened this issue Sep 27, 2021 · 4 comments
Labels: bug (Something isn't working)

Comments

@indranilByjus

The module seems to run fine with the provided datasets, but it throws an error when I include a custom dataset.

Error trace:

    MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py line 220, in train_tree
        current_num = current_nums_embeddings[idx, i - num_start].unsqueeze(0)
    IndexError: index 124 is out of bounds for dimension 1 with size 118
@LYH-YF
Owner

LYH-YF commented Sep 28, 2021

I'm not sure of the specific reason, but you could check the values of these variables:

    dataset.copy_nums
    dataset.num_start
    dataset.generate_list
    dataset.out_idx2symbol

They all concern the decoder's vocabulary of GTS (and other models). Please check whether they are correct.

Second, pay attention to the model's inputs (batch_data).

At the line that throws the error, current_nums_embeddings holds the embeddings of all numbers available at the current decoding step (generated numbers + copied numbers). Its size should be [batch_size, 118, hidden_size], where 118 is the sum of the generate size (static across batches) and the copy size (dynamic across batches, determined by max(batch_data["num size"])). Check whether batch_data["num size"] is correct.

Another point is batch_data["num stack"]. In GTS, when a number appears twice or more in the question sentence, one number has several positions and therefore several candidate symbols to generate; which symbol to choose is decided during decoding. The target token is replaced by UNK_token, and at decoding time the candidate with the maximal score is chosen as the target symbol. batch_data["num stack"] holds the candidate symbols for each UNK_token. If a UNK_token is not correctly replaced by one of its candidates, the index can go out of bounds, so please check whether batch_data["num stack"] is correct.
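
A quick sanity check along these lines may help locate the bad samples. This is only a sketch: `dataloader.load_data` and the "equation" batch key are assumptions about your setup, not guaranteed MWPToolkit API.

    # Sketch of a sanity check (assumed names: `dataloader.load_data` and
    # the "equation" batch key may differ in your MWPToolkit version).
    num_start = dataset.num_start
    generate_size = len(dataset.generate_list)
    for batch_data in dataloader.load_data("train"):
        copy_size = max(batch_data["num size"])
        limit = generate_size + copy_size  # dimension 1 of current_nums_embeddings
        for equation in batch_data["equation"]:
            for token in equation:
                if token >= num_start and token - num_start >= limit:
                    print("out-of-range target token:", token, "limit:", limit)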

@LYH-YF
Owner

LYH-YF commented Sep 28, 2021

Code for building the number stack:
MWPToolkit/mwptoolkit/data/dataset/abstactdataset.py line 192

    def _build_num_stack(self, equation, num_list):
        """Collect candidate number positions for each out-of-vocabulary token."""
        num_stack = []
        for word in equation:
            temp_num = []
            flag_not = True
            if word not in self.dataset.out_idx2symbol:
                flag_not = False
                # tokens like "NUM_k" encode the position directly
                if "NUM" in word:
                    temp_num.append(int(word[4:]))
                # otherwise match the literal number against the number list
                for i, j in enumerate(num_list):
                    if j == word:
                        temp_num.append(i)

            if not flag_not and len(temp_num) != 0:
                num_stack.append(temp_num)
            # unknown number: every position is a candidate
            if not flag_not and len(temp_num) == 0:
                num_stack.append([_ for _ in range(len(num_list))])
        num_stack.reverse()
        return num_stack
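
For intuition, here is a hedged worked example of what this returns; the "NUM_0" token format and all values below are assumed for illustration, not taken from the issue.

    # Assumed example values, for illustration only:
    equation = ["NUM_0", "*", "5.2"]
    num_list = ["3", "5.2", "5.2"]
    # Suppose out_idx2symbol contains "*" but neither "NUM_0" nor "5.2":
    #   "NUM_0" -> int("NUM_0"[4:]) == 0            -> entry [0]
    #   "5.2"   -> matches num_list positions 1, 2  -> entry [1, 2]
    # A number missing from num_list falls back to all positions [0, 1, 2].
    # After reverse(), .pop() in generate_tree_input yields entries in
    # equation order: [0] first, then [1, 2].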

Code for choosing the target symbol with the maximal score:
MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py line 357

    def generate_tree_input(self, target, decoder_output, nums_stack_batch, num_start, unk):
        # when the decoder input is a copied num that has several positions, choose the max
        target_input = copy.deepcopy(target)
        for i in range(len(target)):
            if target[i] == unk:
                # replace UNK with the candidate copy symbol that scored highest
                num_stack = nums_stack_batch[i].pop()
                max_score = -float("1e12")
                for num in num_stack:
                    if decoder_output[i, num_start + num] > max_score:
                        target[i] = num + num_start
                        max_score = decoder_output[i, num_start + num]
            # mask every number token in the decoder input
            if target_input[i] >= num_start:
                target_input[i] = 0
        return torch.LongTensor(target), torch.LongTensor(target_input)
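
To illustrate the replacement step with assumed values:

    # Assumed values, for illustration only:
    #   unk = 5, num_start = 10, target = [5], nums_stack_batch = [[[1, 2]]]
    #   decoder_output[0, 11] = 0.2, decoder_output[0, 12] = 0.9
    # The UNK at position 0 becomes num_start + 2 = 12, the candidate copy
    # symbol with the higher score; separately, target_input masks every
    # number token (index >= num_start) to 0.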

@indranilByjus
Author

So I looked into the variables; apparently the values are:

[screenshot of variable values omitted]

There are a few target tokens that exceed the specified number range. I'd previously patched them:

    target_t = torch.LongTensor([t if t < 118 else 117 for t in target_t])

But it's definitely breaking something during training.

@lijierui

lijierui commented Nov 4, 2021

I encountered a similar problem.
I think one possible reason is that the text of the equation/question has a different format, e.g. you are using "x=1+2" as the equation but the data loader expects "1+2".
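
If that is the case, a minimal normalization sketch could look like this; the `normalize_equation` helper and the "x=" prefix handling are assumptions about the custom data, not MWPToolkit code.

    def normalize_equation(eq: str) -> str:
        """Strip a leading 'x=' so 'x=1+2' becomes '1+2' (hypothetical helper)."""
        eq = eq.strip()
        if eq.startswith("x="):
            eq = eq[2:].strip()
        return eq

    assert normalize_equation("x=1+2") == "1+2"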

@LYH-YF added the bug (Something isn't working) label Nov 5, 2021