
Correlation Coefficient Calculation #36

Open
adamcatto opened this issue Jan 22, 2024 · 3 comments
adamcatto commented Jan 22, 2024

I ran the test_pretrained.py script to calculate the correlation coefficient on a validation sample and got 0.5963, as expected. However, when I inspected the target and predictions, each had shape (896, 5313), i.e. the batch dimension was missing. Since pearson_corr_coef computes similarity over dim=1, the reported 0.5963 is actually correlation over the 5313 cell-line tracks, rather than over the 896 track positions per cell line. When you unsqueeze the batch dimension, the correlation is computed over track positions instead, and yields 0.4721.

Since this is the way Enformer reports correlation, does it make sense to update the README and test_pretrained.py with this procedure? Also, were the reported correlation coefficients of 0.625 and 0.65 on the train/test sets calculated on samples with a missing batch dimension? If so, a recalculation would be necessary. Am I missing something?

Here is the modified test_pretrained.py script I have used:

import torch
from enformer_pytorch import Enformer

enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough').cuda()
enformer.eval()

data = torch.load('./data/test-sample.pt')
seq, target = data['sequence'].cuda(), data['target'].cuda()
print(seq.shape) # torch.Size([131072, 4])
print(target.shape) # torch.Size([896, 5313])
seq = seq.unsqueeze(0)        # add the missing batch dimension
target = target.unsqueeze(0)  # so pearson_corr_coef's dim=1 is the position axis

# Note: you will find prediction shape is also `torch.Size([896, 5313])`.

with torch.no_grad():
    corr_coef = enformer(
        seq,
        target = target,
        return_corr_coef = True,
        head = 'human'
    )

print(corr_coef) # tensor([0.4721], device='cuda:0')
assert corr_coef > 0.1
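To make the axis question concrete, here is a minimal, self-contained sketch (not part of the repo): pearson_corr_coef is copied from enformer-pytorch, the shapes 896 and 5313 mirror the Enformer human head, and the random data is made up for illustration.

```python
import torch
import torch.nn.functional as F

# copied from enformer_pytorch for the demo
def pearson_corr_coef(x, y, dim = 1, reduce_dims = (-1,)):
    x_centered = x - x.mean(dim = dim, keepdim = True)
    y_centered = y - y.mean(dim = dim, keepdim = True)
    return F.cosine_similarity(x_centered, y_centered, dim = dim).mean(dim = reduce_dims)

torch.manual_seed(0)
pred = torch.randn(896, 5313)                 # (positions, tracks), no batch dim
target = pred + 0.5 * torch.randn(896, 5313)  # noisy copy of pred

# without a batch dim, dim=1 is the track axis:
# one correlation per position, averaged over 896 positions
corr_no_batch = pearson_corr_coef(pred, target)

# with a batch dim, dim=1 is the position axis:
# one correlation per track, averaged over 5313 tracks
corr_batched = pearson_corr_coef(pred.unsqueeze(0), target.unsqueeze(0))

print(corr_no_batch.shape, corr_batched.shape)  # torch.Size([]) torch.Size([1])
```

On this synthetic data both numbers land near the same value, since the noise is i.i.d. across every axis; on real Enformer outputs the per-position and per-track profiles differ, which is why the issue sees 0.5963 versus 0.4721.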
jstjohn (Contributor) commented Feb 29, 2024

Hi Adam! Forgive my slow reply. Please dig into the notebook https://github.com/lucidrains/enformer-pytorch/blob/main/evaluate_enformer_pytorch_correlation.ipynb to see how I got to those numbers. I did not use the function you're using for computing correlation.

jstjohn (Contributor) commented Feb 29, 2024

Now to your question: how is correlation calculated in forward? I didn't write that part. If you look at the code in forward, one-hot sequences passed without a batch dimension get one added:

    no_batch = x.ndim == 2

    if no_batch:
        x = rearrange(x, '... -> () ...')

Now the other part is more interesting. I don't see a batch dimension being added to the target (I'm looking at the code on my iPhone). You'd have to look at how the correlation function is called in Enformer's forward. Maybe I missed the line where that happens, or maybe it works without it?
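As a quick standalone check of what that snippet does (using unsqueeze(0), which is equivalent to einops' rearrange(x, '... -> () ...'); the shape is the Enformer input size from the script above):

```python
import torch

x = torch.randn(131072, 4)  # one-hot sequence without a batch dimension

no_batch = x.ndim == 2
if no_batch:
    # equivalent to rearrange(x, '... -> () ...') from einops:
    # prepend a singleton batch dimension
    x = x.unsqueeze(0)

print(x.shape)  # torch.Size([1, 131072, 4])
```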

jstjohn (Contributor) commented Feb 29, 2024

Here's the code for the correlation function:

def pearson_corr_coef(x, y, dim = 1, reduce_dims = (-1,)):
    x_centered = x - x.mean(dim = dim, keepdim = True)
    y_centered = y - y.mean(dim = dim, keepdim = True)
    return F.cosine_similarity(x_centered, y_centered, dim = dim).mean(dim = reduce_dims)

So with no batch dimension on the target, dim=1 would point to its last (track) dimension, I think, but to the first non-batch (position) dimension of a batched prediction?
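One way to check that guess with a standalone sketch (made-up data, shapes matching the Enformer human head): F.cosine_similarity broadcasts its two inputs to a common shape, aligning from the trailing dimensions, before dim is applied, so dim=1 ends up naming the same axis in both tensors.

```python
import torch
import torch.nn.functional as F

pred = torch.randn(1, 896, 5313)  # batched prediction
target = torch.randn(896, 5313)   # unbatched target

# broadcasting treats target as (1, 896, 5313), so dim=1 is the
# 896-position axis for BOTH tensors, not a different axis in each
sim = F.cosine_similarity(pred, target, dim = 1)
print(sim.shape)  # torch.Size([1, 5313])
```

So a batched prediction against an unbatched target would still correlate over positions; the 0.5963 in the issue arises when both tensors are unbatched, making dim=1 the track axis in each.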

Seems worth digging into! Keep in mind this affects the sanity-check return value, but not how the correlation was verified. Again, see the notebook I posted, which calculates this independently.
