[FEATURE]: Add option to make component order in multicomponent not matter #806
We could have options for both invariance and equivariance, to cover cases where the user wants the prediction to be the same regardless of order, as well as the cases where they want the prediction to be the negative if the order is switched. |
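(As a toy sketch of the two options, using hypothetical helper functions rather than anything in Chemprop: an invariant fusion returns the same value under swapped inputs, while an equivariant/antisymmetric one negates it.)

```python
import torch

def invariant_fuse(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Invariant: the output is identical regardless of input order."""
    return x + y

def antisymmetric_fuse(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Equivariant (antisymmetric): swapping the inputs negates the output."""
    return x - y

x, y = torch.randn(4), torch.randn(4)
assert torch.allclose(invariant_fuse(x, y), invariant_fuse(y, x))
assert torch.allclose(antisymmetric_fuse(x, y), -antisymmetric_fuse(y, x))
```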
You would probably want to achieve this using some module:

```python
class Fusion(nn.Module):
    ...

    def forward(self, *Xs: Tensor) -> Tensor:
        """Fuse the input tensors into a single output tensor"""
```

where you would have different implementations for the operators.

re: equivariance: @kevingreenman, where are you getting the negative from, i.e., why would permuting the order of the inputs negate the original output? Given some embedding matrix of multiple molecules
@davidegraff good point. I'm referring to the specific case of
But in the case of permuting an input of
Could taking the difference or cross product between the embeddings (instead of a sum or average in the case of invariance) achieve the antisymmetry? For context, I'm trying to think of whether there's a way for us to satisfy some of the "guarantees" from approaches like DeepDelta (which is just a wrapper that calls chemprop with
Thinking about it more, I realize that doesn't make sense. Those things would ensure the encoding is antisymmetric, but that would give no guarantees about the output after the encoding goes through the feedforward network.
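(A quick numerical sketch of this point; throwaway code, not Chemprop: even when the fused encoding is exactly antisymmetric, a nonlinear FFN on top of it does not preserve the sign relationship.)

```python
import torch
from torch import nn

torch.manual_seed(0)
ffn = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

x, y = torch.randn(8), torch.randn(8)
z = x - y  # antisymmetric encoding: swapping x and y yields -z

# if the end-to-end model were antisymmetric, ffn(-z) would equal -ffn(z);
# with a nonlinear FFN the two values generally differ
print(ffn(z).item(), -ffn(-z).item())
```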
Yeah exactly, there’s nothing in the DeepDelta architecture that guarantees antisymmetry. They’re just relying on typical data augmentation to approximate an antisymmetric function. FWIW, I don’t see why we would want this in Chemprop. The measured property should be invariant to the order of our components (i.e., symmetric), as opposed to pairwise property differences, which should be antisymmetric.
Not a straightforward problem to solve; I've been thinking about it a lot lately. If you sum the fingerprints before the FFN, you lose some resolution on the data: what if there's a benefit to seeing how dissimilar two molecules are? Now you don't really know. I talked with @cbilodeau2 about how to deal with this a while ago. The solution she landed on is in this linked paper: averaging the FFN output. It's not easily extensible to multicomponent systems, but if you are willing to limit to 2 components, I think it's a very good approach.
There are a variety of equivariant pooling techniques to choose from.
Thanks for the reference. This does seem to be an instance of Janossy pooling; glad to have a formal name for it now. Yes, averaging over more combinations is a fine way to extend it, but I'd be a bit stumped as to how to code that practically into Chemprop other than a for loop with hardcoded options for 2, 3, 4 components.
re: Python implementation

```python
from itertools import permutations
from random import sample

import torch
from torch import Tensor, nn


class Fusion(nn.Module):
    def forward(self, *Xs: Tensor) -> Tensor:
        """
        Parameters
        ----------
        *Xs : tuple[Tensor, ...]
            a tuple of tensors of shape `n x d` containing the aggregated feature
            representations, where `n` is the batch size and `d` is the embedding
            dimensionality

        Returns
        -------
        Tensor
            a tensor of shape `n x *` containing the combined feature representations
        """


class JanossyFusion(Fusion):
    def __init__(self, mlp: nn.Module, k: int) -> None:
        super().__init__()

        self.mlp = mlp
        self.k = k

    def forward(self, *Xs: Tensor) -> Tensor:
        Zs = []
        # NOTE: the two lines below contain the bug discussed in the replies
        for pi in sample(list(permutations(len(Xs)), self.k)):
            pi = torch.randperm(len(Xs))
            X = torch.cat([Xs[i] for i in pi], dim=1)
            Z = self.mlp(X)
            Zs.append(Z)
        Z = torch.stack(Zs, 0).mean(0)

        return Z
```
I used Janossy pooling for bond property prediction, as each bond has two directed bonds. The implementation generally looks good to me. However, I think the code for the for loop is incorrect: the input to the `permutations` function should be a sequence. The loop can be modified in either of two ways (see the sketch after this comment). The second method cannot guarantee that the samples will not be repeated, so the first one looks better.
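(The two variants being compared, reconstructed as hypothetical helper functions; a sketch assuming the standard `itertools` and `random` modules, not the commenter's exact code.)

```python
from itertools import permutations
from random import sample

import torch

def sampled_permutations(n_components: int, k: int) -> list[tuple[int, ...]]:
    """First variant: sample k *distinct* permutations of the component indices."""
    return sample(list(permutations(range(n_components))), k)

def random_permutations(n_components: int, k: int) -> list[torch.Tensor]:
    """Second variant: draw k permutations independently; repeats are possible."""
    return [torch.randperm(n_components) for _ in range(k)]
```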
There was an error in the original code block (the perils of writing code in markdown on your phone!). The corrected `forward`:

```python
def forward(self, *Xs: Tensor) -> Tensor:
    Zs = []
    all_permutations = list(permutations(range(len(Xs))))
    for pi in sample(all_permutations, self.k):
        X = torch.cat([Xs[i] for i in pi], dim=1)
        Z = self.mlp(X)
        Zs.append(Z)
    Z = torch.stack(Zs, 0).mean(0)

    return Z
```

where `self.k` is the number of permutations sampled and averaged over.
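(A usage sketch, assuming the `JanossyFusion` class above with the corrected `forward`; the dimensions are hypothetical, and note that `sample` requires `k` to be at most `len(Xs)!`, the number of distinct permutations.)

```python
import torch
from torch import nn

d, n = 300, 16  # hypothetical embedding dimensionality and batch size
mlp = nn.Sequential(nn.Linear(2 * d, 128), nn.ReLU(), nn.Linear(128, 1))

fusion = JanossyFusion(mlp, k=2)  # k=2 covers both orderings of two components

X1, X2 = torch.randn(n, d), torch.randn(n, d)

# with k equal to the total number of permutations, the average runs over
# the same set of orderings either way, so the output is order-invariant
assert torch.allclose(fusion(X1, X2), fusion(X2, X1))
```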
At the MLPDS meeting, someone brought up that in multicomponent the order of components currently matters because the learned representations are concatenated. Could we add an option to make the architecture order-invariant? The only way I can think of is summing/averaging the learned representations. I started some work on this for my solvent mixtures project.
In any event, there aren't many cases of multicomponent datasets where the order of components doesn't matter. Usually it is solute + solvent, or rxn + solvent. Posting an issue in case others have similar ideas. This issue shouldn't get a milestone and can be closed if there is no discussion on it.