Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BART、T5和GPT2的模型输入问题 #57

Open
YunweiDai opened this issue Nov 24, 2023 · 0 comments
Open

BART、T5和GPT2的模型输入问题 #57

YunweiDai opened this issue Nov 24, 2023 · 0 comments
Labels
question Further information is requested

Comments

@YunweiDai
Copy link

大佬您好!我在学习您代码的时候,对BART、T5和GPT2在
outputs = model(**inputs)
这一步的这个inputs产生了一些疑问。

第一个问题是,对于BART我们在得到

source_ids, source_mask, y = (
                batch["source_ids"],
                batch["source_mask"],
                batch["target_ids"],
            )

后,需要进行以下操作

y_ids = y[:, :-1].contiguous()
labels = y[:, 1:].clone()
labels[y[:, 1:] == pad_token_id] = -100

来得到inputs的"decoder_input_ids"和"labels":

inputs = {
                "input_ids": source_ids.to(device),
                "attention_mask": source_mask.to(device),
                "decoder_input_ids": y_ids.to(device), 
                "labels": labels.to(device),
            }

但对于T5,我们只需
labels[labels == self.tokenizer.pad_token_id] = -100
也就是BART中的第3个操作,即可得到inputs:

inputs = {
                "input_ids": input_ids,
                "attention_mask": attention_mask, 
                "labels": labels,
            }

我检查了BART和T5的模型结构以及它们的Huggingface文档,发现对于BART并无明确的说明,但T5的文档中提到无需使用"decoder_input_ids"或对labels进行截取(我指的是BART中的labels = y[:, 1:].clone()),直接输入encode后的input_ids、attention_mask和labels即可,这与您的T5代码一致。
我也检查了BART和T5的源码modeling_bart.py和modeling_t5.py,发现在不使用"decoder_input_ids"情况下两者的操作几乎是一致的,但实际运行过程中发现若BART不使用"decoder_input_ids"且不对labels进行截取,直接

inputs = {
                "input_ids": source_ids.to(device),
                "attention_mask": source_mask.to(device),
                "labels": y.to(device),
            }

后得到的outputs与您的BART代码不同(已保证batch完全一致),请问这是什么原因呢?BART实际使用时用哪个比较好呢?

第二个问题由三个小问题组成,
第1个是对于GPT2做对联生成任务时,已知上联是src,下联是trg,那么将上联和下联拼在一起生成input_ids时除了首尾的[CLS]和[SEP]外还需要在src和trg中间加上[SEP]吗?ChatGPT的说法是不需要的,但您的任务中加上了,有些疑惑。
第2个是在计算loss时,是只计算trg部分的loss,还是src+trg的loss都计算比较好?
第3个是对于GPT2的forward:outputs = model(inputs, labels=labels),不需要输入attention_mask了吗?即使我们用pad_sequence将inputs填充到每个batch的最大长度?

理解不是很深,麻烦您了!

@YunweiDai YunweiDai added the question Further information is requested label Nov 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
question Further information is requested
Projects
None yet
Development

No branches or pull requests

1 participant