
implement bidirectional lstm #1035

Merged
merged 23 commits into tracel-ai:main from bidirectional-lstm on Apr 26, 2024
Conversation

wcshds
Contributor

@wcshds wcshds commented Dec 1, 2023

I need a bidirectional LSTM in a CRNN model.

Checklist

  • Confirmed that the `run-checks all` script has been executed.
  • Made sure the book is up to date with changes in this PR.

Comment on lines 53 to 63
if self.bidirectional {
    (input_gate_bw, forget_gate_bw, output_gate_bw, cell_gate_bw) = (
        Some(new_gate()),
        Some(new_gate()),
        Some(new_gate()),
        Some(new_gate()),
    );
} else {
    (input_gate_bw, forget_gate_bw, output_gate_bw, cell_gate_bw) =
        (None, None, None, None);
}
Member

Even if it's a boolean condition, I feel like a match might be easier to read, but using new_gate is a big win.
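
For reference, the match form being suggested could look like this (reusing self.bidirectional and new_gate from the snippet above; illustrative only, not the PR's code):

let (input_gate_bw, forget_gate_bw, output_gate_bw, cell_gate_bw) = match self.bidirectional {
    true => (
        Some(new_gate()),
        Some(new_gate()),
        Some(new_gate()),
        Some(new_gate()),
    ),
    false => (None, None, None, None),
};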

@@ -110,7 +124,7 @@ impl<B: Backend> Lstm<B> {
/// Parameters:
/// batched_input: The input tensor of shape [batch_size, sequence_length, input_size].
/// state: An optional tuple of tensors representing the initial cell state and hidden state.
-/// Each state tensor has shape [batch_size, hidden_size].
+/// Each state tensor has shape [num_directions, batch_size, hidden_size].
Member

I feel like it would be better to have a BiDirectionalLstm module instead of one module that does both. The bi-directional module might just have two Lstm modules instead of duplicating all the gates. We could also make it more Rusty by using Kind types:

pub trait LstmKind<B: Backend> {
    fn forward(...) -> ...;
}

#[derive(Module)]
pub struct UniDirectional<B: Backend> {
    input_gate: GateController<B>,
    forget_gate: GateController<B>,
    output_gate: GateController<B>,
    cell_gate: GateController<B>,
}

#[derive(Module)]
pub struct BiDirectional<B: Backend> {
    forward: Lstm<B, UniDirectional<B>>,
    backward: Lstm<B, UniDirectional<B>>,
}

pub struct Lstm<B: Backend, K: LstmKind<B> = UniDirectional<B>> {
    kind: K,
}

impl<B: Backend, K: LstmKind<B>> Lstm<B, K> {
    pub fn forward(...) -> ... {
        self.kind.forward(...) // static dispatch to the right forward pass depending on the kind
    }
}

impl<B: Backend> LstmKind<B> for UniDirectional<B> {
    fn forward(...) -> ... {
        // uni-directional forward pass
    }
}

impl<B: Backend> LstmKind<B> for BiDirectional<B> {
    fn forward(...) -> ... {
        // bi-directional forward pass
    }
}

This is equivalent to having two different modules for the uni-directional and bi-directional LSTM, but with syntactic sugar so users can use Lstm<B> for uni-directional or Lstm<B, BiDirectional> when they want both directions. This is the same pattern we use for the Tensor API.
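
As a self-contained illustration of the default-type-parameter pattern described above (plain Rust, not burn's actual types or names), the static dispatch works like this:

// Toy illustration: `Wrapper<K>` dispatches statically to its kind, and the
// default type parameter lets `Wrapper` alone mean "uni-directional".
trait Kind {
    fn forward(&self, input: f32) -> f32;
}

struct UniDirectional;
struct BiDirectional;

impl Kind for UniDirectional {
    fn forward(&self, input: f32) -> f32 {
        input // stand-in for the uni-directional pass
    }
}

impl Kind for BiDirectional {
    fn forward(&self, input: f32) -> f32 {
        input * 2.0 // stand-in for running both directions
    }
}

struct Wrapper<K: Kind = UniDirectional> {
    kind: K,
}

impl<K: Kind> Wrapper<K> {
    fn forward(&self, input: f32) -> f32 {
        self.kind.forward(input) // static dispatch, no runtime branching
    }
}

fn main() {
    // `Wrapper` in type position means `Wrapper<UniDirectional>` thanks to the default.
    let uni: Wrapper = Wrapper { kind: UniDirectional };
    let bi = Wrapper { kind: BiDirectional };
    println!("{} {}", uni.forward(1.0), bi.forward(1.0)); // prints "1 2"
}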

Member

This would also make it backward compatible in terms of code, but not in terms of state (Recorder).

Member

@agelas, what are your thoughts?

Contributor

@nathanielsimard I agree that this pattern fits better with the overall approach that is usually taken. I think composing LSTM modules like this is a bit more elegant than having various toggles in the forward pass, and we can reuse most of the prior implementation that way.

Contributor Author

@nathanielsimard The problem I have is that forward() for UniDirectional and BiDirectional would have different function signatures due to the inconsistency in the shape of the state. This doesn't seem easy to resolve, so it's probably best to keep the LSTM implementation unchanged and let users implement bidirectional LSTM or multi-layer LSTM on their own.

Member

Not sure how often bidirectional LSTM is used, but could we have an implementation totally separate from the unidirectional LSTM and just offer both? I do agree that messing with function signatures might get a bit bothersome.

Member

The pattern I'm proposing creates two different types, so there are no problems with the type signature. It's just a way for both types to use the LSTM nomenclature. It's probably easier to start with two different types, then we can add the pattern afterward if it makes sense.

Contributor Author

@nathanielsimard This pattern is a bit tricky for me, so I think it's best for me to create a new type for bidirectional LSTM. I sincerely hope you can improve it if possible. Thank you very much!

Some((cell_state, hidden_state)) => (cell_state, hidden_state),
let [batch_size, seq_length, _] = batched_input.shape().dims;
let mut batched_cell_state =
    Tensor::zeros_device([batch_size, seq_length, self.d_hidden * num_directions], &device);
Contributor

@nathanielsimard Is this necessary? The implementation is already tied to a device via the backend, right?

Contributor Author

@wcshds wcshds Dec 15, 2023

In my understanding, Module::fork only moves the tensors stored in the module's struct fields to the given device. So when creating a new Tensor in forward(), is it necessary to specify the device?

By the way, now I feel that bidirectional LSTM may not be a common requirement for everyone. If someone needs a bidirectional LSTM or a multi-layer LSTM, it's probably best for them to implement it themselves. Additionally, keeping the implementation of Lstm unchanged is essential to avoid breaking compatibility.

Member

The device is necessary when creating a new tensor; otherwise, the new tensor will be created on the default device, but not necessarily on the same device as the module.
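
As a minimal sketch (names and shapes are illustrative, not the PR's code), the device can be read from an input tensor inside forward() so that newly created tensors end up on the same device as the data rather than the default device:

use burn::tensor::{backend::Backend, Tensor};

// Sketch: create a buffer on the same device as the input tensor.
fn make_buffer<B: Backend>(batched_input: &Tensor<B, 3>) -> Tensor<B, 3> {
    let device = batched_input.device();
    let [batch_size, seq_length, d_hidden] = batched_input.dims();
    // Passing `&device` ties the new tensor to the right device; a
    // device-less default would fall back to the default device.
    Tensor::zeros([batch_size, seq_length, d_hidden], &device)
}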


codecov bot commented Dec 20, 2023

Codecov Report

Attention: Patch coverage is 99.45504% with 2 lines in your changes missing coverage. Please review.

Project coverage is 86.51%. Comparing base (2f294c5) to head (56f6d7c).
Report is 3 commits behind head on main.

Files Patch % Lines
crates/burn-core/src/nn/rnn/lstm.rs 99.45% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1035      +/-   ##
==========================================
+ Coverage   86.41%   86.51%   +0.09%     
==========================================
  Files         696      696              
  Lines       81131    81499     +368     
==========================================
+ Hits        70112    70508     +396     
+ Misses      11019    10991      -28     


Comment on lines 212 to 219
input_gate: GateController<B>,
forget_gate: GateController<B>,
output_gate: GateController<B>,
cell_gate: GateController<B>,
input_gate_reverse: GateController<B>,
forget_gate_reverse: GateController<B>,
output_gate_reverse: GateController<B>,
cell_gate_reverse: GateController<B>,
Member

I guess it wasn't possible to use two Lstm modules here?
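
For reference, a composition along those lines could look roughly like the sketch below. It reuses the existing Lstm module and its (cell states, hidden states) return tuple as documented above, and reverses the sequence dimension with arange + select. Module paths, the Module derive, and the reversal helper are assumptions for illustration, not the PR's actual implementation.

use burn::module::Module;
use burn::nn::Lstm;
use burn::tensor::{backend::Backend, Int, Tensor};

/// Illustrative only: a bidirectional LSTM composed of two `Lstm` modules.
#[derive(Module, Debug)]
pub struct BiLstmSketch<B: Backend> {
    fwd: Lstm<B>,
    bwd: Lstm<B>,
}

impl<B: Backend> BiLstmSketch<B> {
    /// Returns hidden states of shape `[batch_size, seq_length, 2 * hidden_size]`.
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let [_, seq_length, _] = input.dims();
        let reversed = reverse_seq(input.clone(), seq_length);

        // `Lstm::forward` returns (cell states, hidden states); keep the hidden states.
        let (_, hidden_fwd) = self.fwd.forward(input, None);
        let (_, hidden_bwd) = self.bwd.forward(reversed, None);

        // Flip the backward outputs back so both directions are aligned per
        // time step, then concatenate along the feature dimension.
        Tensor::cat(vec![hidden_fwd, reverse_seq(hidden_bwd, seq_length)], 2)
    }
}

/// Reverses dimension 1 (the sequence dimension) using `arange` + `select`.
fn reverse_seq<B: Backend>(x: Tensor<B, 3>, seq_length: usize) -> Tensor<B, 3> {
    let device = x.device();
    // Indices seq_length - 1, ..., 1, 0.
    let indices = Tensor::<B, 1, Int>::arange(0..seq_length as i64, &device)
        .mul_scalar(-1)
        .add_scalar(seq_length as i64 - 1);
    x.select(1, indices)
}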

@wcshds wcshds reopened this Dec 21, 2023
@wcshds wcshds marked this pull request as draft January 12, 2024 10:32
@wcshds
Contributor Author

wcshds commented Jan 12, 2024

I think the current implementation of bidirectional LSTM works during inference, but it cannot update the model's parameters during backpropagation due to #1098. Using Tensor::cat instead of Tensor::slice_assign would solve this problem, but it seems to have a negative performance impact, so I think it's best to wait for the Tensor::slice_assign bug to be resolved for now.
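
To make the trade-off concrete, here is a rough, illustrative sketch of the two ways of collecting per-time-step hidden states (names and shapes are mine, not the PR's code): writing into a preallocated buffer with slice_assign versus concatenating with Tensor::cat.

use burn::tensor::{backend::Backend, Tensor};

/// Collects per-step outputs by writing into a preallocated buffer
/// (the `slice_assign` approach affected by the autodiff issue above).
fn collect_with_slice_assign<B: Backend>(
    steps: Vec<Tensor<B, 2>>, // each step: [batch_size, hidden_size]
    batch_size: usize,
    hidden_size: usize,
) -> Tensor<B, 3> {
    let device = steps[0].device();
    let seq_length = steps.len();
    let mut buffer: Tensor<B, 3> =
        Tensor::zeros([batch_size, seq_length, hidden_size], &device);
    for (t, step) in steps.into_iter().enumerate() {
        buffer = buffer.slice_assign(
            [0..batch_size, t..t + 1, 0..hidden_size],
            step.unsqueeze_dim(1), // [batch_size, 1, hidden_size]
        );
    }
    buffer
}

/// Collects per-step outputs by concatenating along the sequence dimension.
fn collect_with_cat<B: Backend>(steps: Vec<Tensor<B, 2>>) -> Tensor<B, 3> {
    let unsqueezed: Vec<Tensor<B, 3>> = steps
        .into_iter()
        .map(|step| step.unsqueeze_dim(1))
        .collect();
    Tensor::cat(unsqueezed, 1)
}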

@antimora antimora added the enhancement Enhance existing features label Jan 31, 2024
@antimora antimora added the stale The issue or pr has been open for too long label Feb 24, 2024
@antimora
Collaborator

Closing this ticket and linking it to #1537 so someone else can pick it up.

@antimora antimora closed this Mar 26, 2024
@wcshds
Contributor Author

wcshds commented Apr 16, 2024

@antimora Could you reopen this pull request? The issue regarding Autodiff has been resolved in #1575, and I think it's time to proceed with implementing the Bidirectional LSTM.

@antimora antimora reopened this Apr 16, 2024
@wcshds wcshds marked this pull request as ready for review April 16, 2024 14:17
@wcshds
Contributor Author

wcshds commented Apr 16, 2024

The wgpu test failed, but I don't know why...

@antimora antimora requested review from louisfd and agelas April 16, 2024 16:32
@agelas
Contributor

agelas commented Apr 16, 2024

What was the error in wgpu?

@agelas
Contributor

agelas commented Apr 16, 2024

Hm, ok, well on ubuntu-22.04 everything seems fine when I run the tests locally. Given that Ubuntu and Windows work, maybe it hints at possible compatibility issues or differences in how Metal handles resource management compared to Vulkan (i.e., Linux) or DirectX (Windows)?

@github-actions github-actions bot removed the stale The issue or pr has been open for too long label Apr 17, 2024
Member

@nathanielsimard nathanielsimard left a comment

I like it! I have a few comments, but nothing major. Would also like @laggui to review this.

Comment on lines +44 to +52
let new_gate = || {
    GateController::new(
        self.d_input,
        d_output,
        self.bias,
        self.initializer.clone(),
        device,
    )
};
Member

Yesss!

/// ## Returns:
/// A tuple of tensors, where the first tensor represents the cell states and
/// the second tensor represents the hidden states for each sequence element.
/// Both output tensors have the shape `[batch_size, sequence_length, hidden_size]`.
pub fn forward(
    &self,
    batched_input: Tensor<B, 3>,
    state: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
Member

Even though that's a breaking change, I think it might be beneficial to create a type for the state.

pub struct LstmState<B: Backend> {
    pub cell: Tensor<B, 2>,
    pub hidden: Tensor<B, 2>,
}

We can remove the optional and implement Default instead. It also gives us a space to document what each element of the state is used for.

We could also do the same for the return type, where it's easy to make a mistake (choosing the wrong returned tensor).

What do you think?

Member

I quite like that idea, especially for the return type!

Contributor Author

@wcshds wcshds Apr 17, 2024

If breaking changes are acceptable, I'd like to follow PyTorch's return values output, (h_n, c_n).

[Screenshot of the PyTorch LSTM documentation showing the return values output, (h_n, c_n).]

In that case, we can return (output, state). For BiLSTM, output is a concatenated tensor, and the cell state and hidden state can be provided in state to the user without modification.

Contributor Author

If you agree to the modifications according to the PyTorch outputs, batched_cell_state is unnecessary. For the state, we only need to return the hidden state and cell state of the last time step.
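
A hedged sketch of what that could look like (illustrative only, not necessarily the exact merged API):

use burn::tensor::{backend::Backend, Tensor};

/// Cell and hidden state of the last time step, as discussed above.
pub struct LstmState<B: Backend, const D: usize> {
    /// Cell state, e.g. [batch_size, hidden_size] (uni-directional)
    /// or [2, batch_size, hidden_size] (bidirectional).
    pub cell: Tensor<B, D>,
    /// Hidden state, with the same shape as `cell`.
    pub hidden: Tensor<B, D>,
}

// Illustrative forward signatures following PyTorch's `output, (h_n, c_n)` convention:
//
//   Lstm:   fn forward(&self, input: Tensor<B, 3>, state: Option<LstmState<B, 2>>)
//               -> (Tensor<B, 3>, LstmState<B, 2>);  // output: [batch, seq, hidden]
//   BiLstm: fn forward(&self, input: Tensor<B, 3>, state: Option<LstmState<B, 3>>)
//               -> (Tensor<B, 3>, LstmState<B, 3>);  // output: [batch, seq, 2 * hidden]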

Comment on lines 104 to 106
let mut batched_cell_state = Tensor::zeros([batch_size, seq_length, self.d_hidden], device);
let mut batched_hidden_state =
    Tensor::zeros([batch_size, seq_length, self.d_hidden], device);
Member

Not linked to the current change, but is zero necessary here, or would empty be enough?
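
For comparison, a small sketch of the two options (assuming a Tensor::empty constructor with the same shape-and-device signature as Tensor::zeros; helper name is hypothetical):

use burn::tensor::{backend::Backend, Tensor};

fn allocate_buffers<B: Backend>(batch: usize, seq: usize, hidden: usize, device: &B::Device) {
    // Zero-initialized buffer: what the code above does.
    let _zeroed = Tensor::<B, 3>::zeros([batch, seq, hidden], device);
    // Uninitialized buffer: potentially cheaper, and sufficient when every
    // position is overwritten (e.g. via slice_assign across the full
    // sequence) before it is ever read.
    let _uninit = Tensor::<B, 3>::empty([batch, seq, hidden], device);
}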

input_biases: [f32; D1],
hidden_weights: [[f32; D1]; D1],
hidden_biases: [f32; D1],
device: &<TestBackend as Backend>::Device,
Member

Just a detail, but we can use &burn_tensor::Device<TestBackend> instead of this notation; it's just a bit prettier, especially when importing burn_tensor::Device.
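
i.e., something like the following (hypothetical helper just to show the alias; Device<B> resolves to <B as Backend>::Device):

use burn_tensor::{backend::Backend, Device};

// `&Device<TestBackend>` reads better than `&<TestBackend as Backend>::Device`,
// but they are the same type.
fn lstm_test_params<B: Backend>(device: &Device<B>) {
    let _ = device;
}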

Member

@laggui laggui left a comment

Great rewrite to add the BiLstm! I don't see any issues 🙂 Just a minor comment regarding the state as pointed out by Nath.

/// ## Returns:
/// A tuple of tensors, where the first tensor represents the cell states and
/// the second tensor represents the hidden states for each sequence element.
/// Both output tensors have the shape `[batch_size, sequence_length, hidden_size]`.
pub fn forward(
    &self,
    batched_input: Tensor<B, 3>,
    state: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
Member

I quite like that idea, especially for the return type!

@wcshds
Contributor Author

wcshds commented Apr 18, 2024

The outputs of both Lstm and BiLstm are now aligned with PyTorch.

The script I used to generate tests for Bidirectional LSTM can be found here.

The tests for the wgpu backend still failed, possibly due to some data race issues in the wgpu backend?

@wcshds
Contributor Author

wcshds commented Apr 18, 2024

@nathanielsimard @louisfd I changed the batch size to 2, and the test passed. This seems to be a bug with the Wgpu backend, so I've opened an issue #1656.

Member

@nathanielsimard nathanielsimard left a comment

No more comments from my side, great job! We will investigate the wgpu test problem, but it's quite unrelated.

@antimora antimora requested a review from laggui April 26, 2024 16:50
@antimora
Collaborator

We will merge it once we have approval from @laggui.

Member

@laggui laggui left a comment

Good job! 😄 I like the addition of LstmState and the returned output.

Only a few minor changes, and then we can merge 🎉

Comment on lines 81 to 82
/// returns hidden state for each element in a sequence (i.e., across `seq_length`) and a final state,
/// producing 3-dimensional tensors where the dimensions represent `[batch_size, sequence_length, hidden_size]`.
Member

@laggui laggui Apr 26, 2024

Since we document the return types, I think we can simplify the docstring here to:

Applies the forward pass on the input tensor. This LSTM implementation returns the state for each element in a sequence (i.e., across seq_length) and a final state.

(This also removes the ambiguity of "returns hidden state"; the state contains both the hidden and cell states.)

Comment on lines 215 to 217
/// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation
/// returns hidden state for each element in a sequence (i.e., across `seq_length`) and a final state,
/// producing 3-dimensional tensors where the dimensions represent `[batch_size, sequence_length, hidden_size * 2]`.
Member

Same as with the Lstm forward docstring. We can simplify to:

Applies the forward pass on the input tensor. This Bidirectional LSTM implementation returns the state for each element in a sequence (i.e., across seq_length) and a final state.

.select(0, Tensor::arange(0..1, &device))
.squeeze(0);
cell_state
// let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None);
Member

Dead comment, we can remove 🙂

Contributor Author

All done.

Member

@laggui laggui left a comment

🚀 🚀

@antimora antimora merged commit b387829 into tracel-ai:main Apr 26, 2024
14 checks passed
@wcshds wcshds deleted the bidirectional-lstm branch April 27, 2024 04:38