master decoding performance issue #588

Open
jiangxiluning opened this issue Aug 21, 2023 · 1 comment
jiangxiluning commented Aug 21, 2023

targets = ops.zeros((N, 1), ms.int32)

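This code has two performance problems: 1. The length of targets changes at every step, which triggers graph compilation each time and lengthens inference. With many steps it becomes very slow, which makes debugging difficult. 2. probs appends a large tensor at every step, so with many steps it uses, and wastes, a large amount of device memory.
It can be changed to this: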

            # Preallocate fixed-shape buffers so the shapes stay constant across steps.
            targets = ops.fill(ms.int32, (N, num_steps+1), self.padding_symbol)
            targets[:, 0] = 0 # <GO>
            probs = ops.zeros((N, num_steps, self.out_channels), dtype=inputs.dtype)

            for i in range(num_steps):
                target_mask = self._generate_target_mask(targets)
                probs_step = self._decode(inputs, targets, target_mask=target_mask)
                next_input = self.argmax(probs_step)
                # Write step i's prediction and scores into the buffers in place.
                targets[:, i+1] = next_input[:, i]
                probs[:, i] = probs_step[:, i]

            probs = ops.softmax(probs, axis=-1)
            return probs
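To illustrate the first point, here is a minimal sketch that uses NumPy as a stand-in for MindSpore and counts how many distinct `targets` shapes the decode step would see. It assumes that each new input shape corresponds to one graph compilation, as described above. The function names `count_shapes_growing` and `count_shapes_fixed` and the `pad` value are hypothetical and are not part of the repository.

```python
import numpy as np


def count_shapes_growing(n, num_steps):
    """Grow `targets` by one column per step and count the distinct shapes seen."""
    targets = np.zeros((n, 1), dtype=np.int32)
    shapes = set()
    for _ in range(num_steps):
        # The decode step would receive `targets` with this shape.
        shapes.add(targets.shape)
        # Placeholder for the predicted token; a real model would supply it.
        next_token = np.zeros((n, 1), dtype=np.int32)
        targets = np.concatenate([targets, next_token], axis=1)
    return len(shapes)


def count_shapes_fixed(n, num_steps, pad=-1):
    """Preallocate `targets` once and write each step's token in place."""
    targets = np.full((n, num_steps + 1), pad, dtype=np.int32)
    targets[:, 0] = 0  # <GO>
    shapes = set()
    for i in range(num_steps):
        # The shape never changes, so only one entry is ever recorded.
        shapes.add(targets.shape)
        targets[:, i + 1] = 0  # placeholder for the predicted token
    return len(shapes)
```

With the growing buffer the number of distinct shapes equals the number of steps, while the preallocated buffer keeps a single shape for the whole loop, which is what the rewrite above relies on.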
panshaowu (Collaborator) commented

Thank you for your feedback. We will arrange for a development engineer to test the code you provided and then merge it.
