Problem after upgrading bert4torch from 0.2.8 to 0.3.4 #157
Hello, with bert4torch 0.2.8 I could run task_seq2seq_autotitle_csl_mt5 and similar examples without problems, but after upgrading to 0.3.4 a problem appeared: in the loss below, `outputs` now returns two values.

```python
class CrossEntropyLoss(nn.CrossEntropyLoss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
```

If I drop one of the two values, the `return` in the code below raises an error. How can I resolve this?

```python
class AutoTitle(AutoRegressiveDecoder):
    """seq2seq decoder"""
    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, states):
        # inputs contains [decoder_ids, encoder_hidden_state, encoder_attention_mask]
        # keep only the last position
        return model.decoder.predict([output_ids] + inputs)[-1][:, -1, :]
```

---

Hello, this is left over from an earlier refactor: the example was not updated. Changing it as below should fix it. You can also upgrade to the latest 0.3.7, which no longer needs the weight-conversion step; the model can be loaded with just bert4torch_config.json.

```python
def forward(self, outputs, y_true):
    y_pred = outputs[-1]
    y_pred = y_pred.reshape(-1, y_pred.shape[-1])
    return super().forward(y_pred, y_true)

@AutoRegressiveDecoder.wraps(default_rtype='logits')
def predict(self, inputs, output_ids, states):
    res = model.decoder.predict([output_ids] + inputs)
    # keep only the last position, whether the decoder returns a list or a single tensor
    return res[-1][:, -1, :] if isinstance(res, list) else res[:, -1, :]
```

---

The problem is fixed, thanks!
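The core of the fix is tolerating both return shapes of `model.decoder.predict`: a single logits tensor, or a list whose last element is the logits. A minimal sketch of that selection logic, using NumPy arrays in place of torch tensors and made-up shapes purely for illustration:

```python
import numpy as np

def last_step_logits(res):
    """Return [batch, vocab] logits for the final decoding step, given either
    a [batch, seq_len, vocab] array or a list whose last element is one."""
    logits = res[-1] if isinstance(res, list) else res
    return logits[:, -1, :]  # keep only the last position

batch, seq_len, vocab = 2, 5, 7
tensor_out = np.random.rand(batch, seq_len, vocab)
# e.g. [hidden_states, logits], as returned after the 0.3.x refactor
list_out = [np.random.rand(batch, seq_len, 16), tensor_out]

print(last_step_logits(tensor_out).shape)  # (2, 7)
print(np.array_equal(last_step_logits(list_out), tensor_out[:, -1, :]))  # True
```

Both call styles yield the same last-position logits, which is why the one-line `isinstance(res, list)` check is enough to keep the example working across versions.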