Incorporating copy mechanism in decoder #309

Open
roemmele opened this issue Apr 22, 2020 · 2 comments
Labels: enhancement (New feature or request), topic: modules (Issue about built-in Texar modules)

Comments

@roemmele

I'm really enjoying this library, thanks for your work. Just curious, are there any plans to implement some sort of copying mechanism for decoding, e.g. CopyNet (https://arxiv.org/abs/1603.06393)?

@huzecong
Collaborator

Thank you for using Texar! Unfortunately, we don't currently have plans to add built-in support for a copying mechanism.

That said, given what's already in the library, it shouldn't be too difficult to implement a copying mechanism on top of AttentionRNNDecoder; you would only need to override the initialize(), step(), and next_inputs() methods:

def initialize(  # type: ignore
        self, helper: Helper,
        inputs: Optional[torch.Tensor],
        sequence_length: Optional[torch.LongTensor],
        initial_state: Optional[MaybeList[MaybeTuple[torch.Tensor]]]) -> \
        Tuple[torch.ByteTensor, torch.Tensor,
              Optional[AttentionWrapperState]]:
    initial_finished, initial_inputs = helper.initialize(
        self.embed_tokens, inputs, sequence_length)
    if initial_state is None:
        state = None
    else:
        tensor = utils.get_first_in_structure(initial_state)
        assert tensor is not None
        tensor: torch.Tensor
        state = self._cell.zero_state(batch_size=tensor.size(0))
        state = state._replace(cell_state=initial_state)
    return initial_finished, initial_inputs, state

def step(self, helper: Helper, time: int, inputs: torch.Tensor,
         state: Optional[AttentionWrapperState]) -> \
        Tuple[AttentionRNNDecoderOutput, AttentionWrapperState]:
    wrapper_outputs, wrapper_state = self._cell(
        inputs, state, self.memory, self.memory_sequence_length)
    # Essentially the same as in BasicRNNDecoder.step()
    logits = self._output_layer(wrapper_outputs)
    sample_ids = helper.sample(time=time, outputs=logits)
    attention_scores = wrapper_state.alignments
    attention_context = wrapper_state.attention
    outputs = AttentionRNNDecoderOutput(
        logits, sample_ids, wrapper_outputs,
        attention_scores, attention_context)
    next_state = wrapper_state
    return outputs, next_state

def next_inputs(self, helper: Helper, time: int,
                outputs: AttentionRNNDecoderOutput) -> \
        Tuple[torch.Tensor, torch.ByteTensor]:
    finished, next_inputs = helper.next_inputs(
        self.embed_tokens, time, outputs.logits, outputs.sample_id)
    return next_inputs, finished

If you have a working implementation, you're more than welcome to create a pull request and contribute to the library!

@huzecong added the enhancement (New feature or request) and topic: modules (Issue about built-in Texar modules) labels on Apr 23, 2020
@roemmele
Author

Ok, I may try to do that at some point. Thanks!
