Interpolation error on computing average precision #1103

Open
1 of 2 tasks
illian01 opened this issue Apr 9, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@illian01

illian01 commented Apr 9, 2024

Search before asking

  • I have searched the Supervision issues and found no similar bug report.

Bug

I have the precision and recall curves below.

precision = np.array([1.0, 1.0, 1.0, 1.0, 0.99875156, 0.998002, 0.99666944, 0.9971449, 0.99625234, 0.99555802, 0.994003, 0.9918219, 0.99083715, 0.99000384, 0.98786148, 0.98600467, 0.98344267, 0.98265216, 0.98083866, 0.97947908, 0.97525619, 0.97048322, 0.96682572, 0.96305151, 0.95875859, 0.955009, 0.95154778, 0.94815775, 0.94215319, 0.93880365, 0.93284453, 0.92710853, 0.9206374, 0.91410392, 0.90604323, 0.89815741, 0.89070962, 0.88096203, 0.87304302, 0.86565825, 0.85939258, 0.84989635, 0.84192358, 0.833624, 0.82422452, 0.81490945, 0.80480383, 0.79629827, 0.78731382, 0.77798184, 0.76882312, 0.76002353, 0.75040861, 0.74181681, 0.7328951, 0.72293428, 0.71431122, 0.70528901, 0.69700888, 0.68824676, 0.67994334, 0.67183018, 0.66446254, 0.65629712, 0.64869932, 0.64125836, 0.63434588, 0.62696814, 0.61995442, 0.61256431, 0.60617099, 0.59918316, 0.59245886, 0.58591877, 0.57982569, 0.57362842, 0.56759424, 0.56145705, 0.55528492, 0.54958547, 0.5440285, 0.53811493, 0.5325285, 0.52695621, 0.52151658, 0.51626375, 0.51136562, 0.50652261, 0.50133515, 0.49632043, 0.4914727, 0.48689632, 0.48220205, 0.47777001, 0.47305994, 0.46897532, 0.46487162, 0.46038864, 0.45635427, 0.45230039, 0.44817759, 0.44438394, 0.44027254, 0.43658075, 0.43271958, 0.42912242, 0.42531013, 0.42184945, 0.41822138, 0.41461401, 0.41120858, 0.40781947, 0.40458015, 0.40130968, 0.39827201, 0.39511326, 0.39226757, 0.38925687, 0.38646668, 0.38355531, 0.38073414, 0.37771166, 0.3748207, 0.37197675, 0.36909802, 0.36642534, 0.36387445, 0.36136373, 0.35873599, 0.3562265, 0.3536787, 0.3511698, 0.34869891, 0.34626518, 0.34390508, 0.3414318, 0.33906842, 0.33681253, 0.33469802, 0.33250602, 0.33023821, 0.32810893, 0.32615049, 0.32397469, 0.32203743, 0.31998897, 0.31824253, 0.31621373, 0.31421236, 0.31217073, 0.31032299, 0.30846661, 0.30663465, 0.30479396, 0.30304211, 0.30131286, 0.29947758, 0.29772937, 0.29587671, 0.29420458, 0.29245961, 0.29076737, 0.28900343, 0.2874145, 0.28566202, 0.2841126, 0.28243125, 0.28085985, 0.27930716, 
0.27783202, 0.27625658, 0.27469957, 0.27324787, 0.27181295, 0.27033706, 0.26882089, 0.26735036, 0.26595294, 0.26459931, 0.2632608, 0.26190939, 0.26071103, 0.25941595, 0.25816235, 0.25675933, 0.25550661, 0.25413295, 0.25280073, 0.25150927, 0.25023148, 0.24904608, 0.24774221, 0.246504, 0.24530453, 0.24419474, 0.24301941, 0.24185608, 0.24062841, 0.2394889, 0.23833572, 0.23724407, 0.23611353, 0.23499418, 0.23386123, 0.23278841, 0.23167728, 0.23064974, 0.22958383, 0.2285522, 0.2274826, 0.2264708, 0.2253975, 0.22440508, 0.22342199, 0.22244807, 0.22145997, 0.22050415, 0.21951107, 0.21857297, 0.21757494, 0.21663144, 0.2157191, 0.2148375, 0.21396381, 0.21309792, 0.21219528, 0.21132276, 0.21041387, 0.20951295, 0.20866357, 0.20777809, 0.20687864])
recall = np.array([9.088e-05, 0.01826609, 0.03644129, 0.0546165, 0.07270084, 0.09078517, 0.10877863, 0.12695383, 0.14494729, 0.16294075, 0.18075245, 0.19838241, 0.21619411, 0.23400582, 0.25145402, 0.26890222, 0.28607779, 0.30370774, 0.32097419, 0.33833152, 0.35459833, 0.37050164, 0.38667757, 0.40267176, 0.41830244, 0.43402399, 0.44974555, 0.46537623, 0.47955289, 0.49491094, 0.5087241, 0.52244638, 0.53553253, 0.54834606, 0.55997819, 0.57142857, 0.58287895, 0.59251181, 0.60305344, 0.61368593, 0.62486369, 0.63340603, 0.64276627, 0.65158124, 0.65921483, 0.66657579, 0.67293711, 0.68029807, 0.68693202, 0.69292984, 0.69874591, 0.70456198, 0.70928753, 0.71464922, 0.71937477, 0.72273719, 0.72709924, 0.73073428, 0.7348237, 0.73809524, 0.74154853, 0.74491094, 0.74881861, 0.75154489, 0.75463468, 0.75763359, 0.760996, 0.76354053, 0.76626681, 0.76826609, 0.77126499, 0.77326427, 0.77535442, 0.77744457, 0.77989822, 0.78198837, 0.78407852, 0.78580516, 0.78725918, 0.78916758, 0.79107597, 0.79225736, 0.79371138, 0.79498364, 0.79625591, 0.79761905, 0.79934569, 0.80098146, 0.80189022, 0.80288986, 0.80398037, 0.80534351, 0.80634315, 0.80761541, 0.80825154, 0.80979644, 0.81115958, 0.81170483, 0.81288622, 0.81388586, 0.81461287, 0.81579426, 0.81624864, 0.81733915, 0.81797528, 0.81897492, 0.8194293, 0.82042893, 0.82097419, 0.82142857, 0.82215558, 0.82279171, 0.8236096, 0.82424573, 0.82524537, 0.8258815, 0.82706289, 0.82778989, 0.82888041, 0.82960742, 0.8304253, 0.83069793, 0.83115231, 0.83160669, 0.83187932, 0.83251545, 0.83333333, 0.83415122, 0.8346056, 0.83524173, 0.83569611, 0.83615049, 0.83660487, 0.83705925, 0.83760451, 0.83778626, 0.83814976, 0.83869502, 0.8395129, 0.84005816, 0.84033079, 0.84087605, 0.84178481, 0.84205743, 0.84287532, 0.8433297, 0.84451109, 0.84487459, 0.8452381, 0.84541985, 0.84605598, 0.84660124, 0.84714649, 0.84760087, 0.848237, 0.84887314, 0.84914577, 0.84960015, 0.84969102, 0.85023628, 0.85050891, 0.85087241, 0.85096329, 0.85150854, 0.85150854, 0.8520538, 
0.85214467, 0.85250818, 0.85287168, 0.85341694, 0.85359869, 0.85378044, 0.85423482, 0.8546892, 0.85496183, 0.85505271, 0.85523446, 0.85559796, 0.85605234, 0.85650672, 0.85687023, 0.85768811, 0.85814249, 0.85868775, 0.85868775, 0.85914213, 0.85914213, 0.85923301, 0.85941476, 0.85959651, 0.86005089, 0.86005089, 0.86023264, 0.86050527, 0.86105053, 0.86132316, 0.86159578, 0.86159578, 0.86186841, 0.86205016, 0.86241367, 0.86259542, 0.86277717, 0.86286805, 0.86314068, 0.86323155, 0.86359506, 0.86377681, 0.86404944, 0.86414031, 0.86441294, 0.86441294, 0.86468557, 0.8649582, 0.86523083, 0.86541258, 0.86568521, 0.86577608, 0.86604871, 0.86604871, 0.86623046, 0.86650309, 0.86686659, 0.8672301, 0.8675936, 0.86777535, 0.86804798, 0.86813886, 0.86822973, 0.86850236, 0.86859324, 0.86859324])

From these I can plot the PR curve:

plt.clf()
plt.plot(recall, precision)
plt.show()

[Image: plot of the original PR curve]

However, the compute_average_precision method always pads the curves: it prepends the point (recall 0.0, precision 1.0) and appends the point (recall 1.0, precision 0.0).

def compute_average_precision(recall: np.ndarray, precision: np.ndarray) -> float:
    """
    Compute the average precision using 101-point interpolation (COCO), given
    the recall and precision curves.

    Args:
        recall (np.ndarray): The recall curve.
        precision (np.ndarray): The precision curve.

    Returns:
        float: Average precision.
    """
    extended_recall = np.concatenate(([0.0], recall, [1.0]))
    extended_precision = np.concatenate(([1.0], precision, [0.0]))
    max_accumulated_precision = np.flip(
        np.maximum.accumulate(np.flip(extended_precision))
    )
    interpolated_recall_levels = np.linspace(0, 1, 101)
    interpolated_precision = np.interp(
        interpolated_recall_levels, extended_recall, max_accumulated_precision
    )
    average_precision = np.trapz(interpolated_precision, interpolated_recall_levels)
    return average_precision

I think these extensions are dangerous and make the PR curve inaccurate. Below is the curve of interpolated_precision against interpolated_recall_levels.

plt.clf()
plt.plot(interpolated_recall_levels, interpolated_precision)
plt.show()

[Image: plot of the interpolated PR curve]

The bend at the bottom right of the plot looks unnatural and affects the accuracy of the average precision calculation.
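To make the effect concrete, here is a minimal sketch on a small synthetic curve (the values are made up for illustration, not taken from the data above). It applies the same padding as compute_average_precision and shows that np.interp draws a straight ramp from the last measured point down to the appended (1.0, 0.0) point. The area is computed by hand rather than with np.trapz only so that the sketch does not depend on the NumPy version.

```python
import numpy as np

# Synthetic curve (made-up values): recall stops at 0.8, where precision is 0.5.
recall = np.array([0.2, 0.4, 0.6, 0.8])
precision = np.array([1.0, 0.9, 0.7, 0.5])

# Same padding and monotone envelope as compute_average_precision.
extended_recall = np.concatenate(([0.0], recall, [1.0]))
extended_precision = np.concatenate(([1.0], precision, [0.0]))
envelope = np.flip(np.maximum.accumulate(np.flip(extended_precision)))

levels = np.linspace(0, 1, 101)
interpolated = np.interp(levels, extended_recall, envelope)

# Nothing was measured past recall 0.8, yet the interpolation reports a
# precision that falls linearly from 0.5 to 0.0 over that range.
print(interpolated[90])   # recall level 0.9 -> roughly 0.25
print(interpolated[100])  # recall level 1.0 -> 0.0

# Area under the ramp only (recall levels 0.8 to 1.0), by the trapezoidal rule.
x, y = levels[80:], interpolated[80:]
ramp_area = np.sum((y[1:] + y[:-1]) / 2 * np.diff(x))
print(ramp_area)  # roughly 0.05, contributed purely by the padding
```

In this toy case the padding alone adds about 0.05 to the result, so the size of the distortion depends on how far the last measured recall is from 1.0 and how high the precision still is there.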

Environment

No response

Minimal Reproducible Example

import numpy as np
import matplotlib.pyplot as plt

precision = np.array([1.0, 1.0, 1.0, 1.0, 0.99875156, 0.998002, 0.99666944, 0.9971449, 0.99625234, 0.99555802, 0.994003, 0.9918219, 0.99083715, 0.99000384, 0.98786148, 0.98600467, 0.98344267, 0.98265216, 0.98083866, 0.97947908, 0.97525619, 0.97048322, 0.96682572, 0.96305151, 0.95875859, 0.955009, 0.95154778, 0.94815775, 0.94215319, 0.93880365, 0.93284453, 0.92710853, 0.9206374, 0.91410392, 0.90604323, 0.89815741, 0.89070962, 0.88096203, 0.87304302, 0.86565825, 0.85939258, 0.84989635, 0.84192358, 0.833624, 0.82422452, 0.81490945, 0.80480383, 0.79629827, 0.78731382, 0.77798184, 0.76882312, 0.76002353, 0.75040861, 0.74181681, 0.7328951, 0.72293428, 0.71431122, 0.70528901, 0.69700888, 0.68824676, 0.67994334, 0.67183018, 0.66446254, 0.65629712, 0.64869932, 0.64125836, 0.63434588, 0.62696814, 0.61995442, 0.61256431, 0.60617099, 0.59918316, 0.59245886, 0.58591877, 0.57982569, 0.57362842, 0.56759424, 0.56145705, 0.55528492, 0.54958547, 0.5440285, 0.53811493, 0.5325285, 0.52695621, 0.52151658, 0.51626375, 0.51136562, 0.50652261, 0.50133515, 0.49632043, 0.4914727, 0.48689632, 0.48220205, 0.47777001, 0.47305994, 0.46897532, 0.46487162, 0.46038864, 0.45635427, 0.45230039, 0.44817759, 0.44438394, 0.44027254, 0.43658075, 0.43271958, 0.42912242, 0.42531013, 0.42184945, 0.41822138, 0.41461401, 0.41120858, 0.40781947, 0.40458015, 0.40130968, 0.39827201, 0.39511326, 0.39226757, 0.38925687, 0.38646668, 0.38355531, 0.38073414, 0.37771166, 0.3748207, 0.37197675, 0.36909802, 0.36642534, 0.36387445, 0.36136373, 0.35873599, 0.3562265, 0.3536787, 0.3511698, 0.34869891, 0.34626518, 0.34390508, 0.3414318, 0.33906842, 0.33681253, 0.33469802, 0.33250602, 0.33023821, 0.32810893, 0.32615049, 0.32397469, 0.32203743, 0.31998897, 0.31824253, 0.31621373, 0.31421236, 0.31217073, 0.31032299, 0.30846661, 0.30663465, 0.30479396, 0.30304211, 0.30131286, 0.29947758, 0.29772937, 0.29587671, 0.29420458, 0.29245961, 0.29076737, 0.28900343, 0.2874145, 0.28566202, 0.2841126, 0.28243125, 0.28085985, 0.27930716, 
0.27783202, 0.27625658, 0.27469957, 0.27324787, 0.27181295, 0.27033706, 0.26882089, 0.26735036, 0.26595294, 0.26459931, 0.2632608, 0.26190939, 0.26071103, 0.25941595, 0.25816235, 0.25675933, 0.25550661, 0.25413295, 0.25280073, 0.25150927, 0.25023148, 0.24904608, 0.24774221, 0.246504, 0.24530453, 0.24419474, 0.24301941, 0.24185608, 0.24062841, 0.2394889, 0.23833572, 0.23724407, 0.23611353, 0.23499418, 0.23386123, 0.23278841, 0.23167728, 0.23064974, 0.22958383, 0.2285522, 0.2274826, 0.2264708, 0.2253975, 0.22440508, 0.22342199, 0.22244807, 0.22145997, 0.22050415, 0.21951107, 0.21857297, 0.21757494, 0.21663144, 0.2157191, 0.2148375, 0.21396381, 0.21309792, 0.21219528, 0.21132276, 0.21041387, 0.20951295, 0.20866357, 0.20777809, 0.20687864])
recall = np.array([9.088e-05, 0.01826609, 0.03644129, 0.0546165, 0.07270084, 0.09078517, 0.10877863, 0.12695383, 0.14494729, 0.16294075, 0.18075245, 0.19838241, 0.21619411, 0.23400582, 0.25145402, 0.26890222, 0.28607779, 0.30370774, 0.32097419, 0.33833152, 0.35459833, 0.37050164, 0.38667757, 0.40267176, 0.41830244, 0.43402399, 0.44974555, 0.46537623, 0.47955289, 0.49491094, 0.5087241, 0.52244638, 0.53553253, 0.54834606, 0.55997819, 0.57142857, 0.58287895, 0.59251181, 0.60305344, 0.61368593, 0.62486369, 0.63340603, 0.64276627, 0.65158124, 0.65921483, 0.66657579, 0.67293711, 0.68029807, 0.68693202, 0.69292984, 0.69874591, 0.70456198, 0.70928753, 0.71464922, 0.71937477, 0.72273719, 0.72709924, 0.73073428, 0.7348237, 0.73809524, 0.74154853, 0.74491094, 0.74881861, 0.75154489, 0.75463468, 0.75763359, 0.760996, 0.76354053, 0.76626681, 0.76826609, 0.77126499, 0.77326427, 0.77535442, 0.77744457, 0.77989822, 0.78198837, 0.78407852, 0.78580516, 0.78725918, 0.78916758, 0.79107597, 0.79225736, 0.79371138, 0.79498364, 0.79625591, 0.79761905, 0.79934569, 0.80098146, 0.80189022, 0.80288986, 0.80398037, 0.80534351, 0.80634315, 0.80761541, 0.80825154, 0.80979644, 0.81115958, 0.81170483, 0.81288622, 0.81388586, 0.81461287, 0.81579426, 0.81624864, 0.81733915, 0.81797528, 0.81897492, 0.8194293, 0.82042893, 0.82097419, 0.82142857, 0.82215558, 0.82279171, 0.8236096, 0.82424573, 0.82524537, 0.8258815, 0.82706289, 0.82778989, 0.82888041, 0.82960742, 0.8304253, 0.83069793, 0.83115231, 0.83160669, 0.83187932, 0.83251545, 0.83333333, 0.83415122, 0.8346056, 0.83524173, 0.83569611, 0.83615049, 0.83660487, 0.83705925, 0.83760451, 0.83778626, 0.83814976, 0.83869502, 0.8395129, 0.84005816, 0.84033079, 0.84087605, 0.84178481, 0.84205743, 0.84287532, 0.8433297, 0.84451109, 0.84487459, 0.8452381, 0.84541985, 0.84605598, 0.84660124, 0.84714649, 0.84760087, 0.848237, 0.84887314, 0.84914577, 0.84960015, 0.84969102, 0.85023628, 0.85050891, 0.85087241, 0.85096329, 0.85150854, 0.85150854, 0.8520538, 
0.85214467, 0.85250818, 0.85287168, 0.85341694, 0.85359869, 0.85378044, 0.85423482, 0.8546892, 0.85496183, 0.85505271, 0.85523446, 0.85559796, 0.85605234, 0.85650672, 0.85687023, 0.85768811, 0.85814249, 0.85868775, 0.85868775, 0.85914213, 0.85914213, 0.85923301, 0.85941476, 0.85959651, 0.86005089, 0.86005089, 0.86023264, 0.86050527, 0.86105053, 0.86132316, 0.86159578, 0.86159578, 0.86186841, 0.86205016, 0.86241367, 0.86259542, 0.86277717, 0.86286805, 0.86314068, 0.86323155, 0.86359506, 0.86377681, 0.86404944, 0.86414031, 0.86441294, 0.86441294, 0.86468557, 0.8649582, 0.86523083, 0.86541258, 0.86568521, 0.86577608, 0.86604871, 0.86604871, 0.86623046, 0.86650309, 0.86686659, 0.8672301, 0.8675936, 0.86777535, 0.86804798, 0.86813886, 0.86822973, 0.86850236, 0.86859324, 0.86859324])

plt.clf()
plt.plot(recall, precision)
plt.savefig('original_pr_curve.png')

def compute_average_precision(recall: np.ndarray, precision: np.ndarray) -> float:
    """
    Compute the average precision using 101-point interpolation (COCO), given
        the recall and precision curves.

    Args:
        recall (np.ndarray): The recall curve.
        precision (np.ndarray): The precision curve.

    Returns:
        float: Average precision.
    """
    extended_recall = np.concatenate(([0.0], recall, [1.0]))
    extended_precision = np.concatenate(([1.0], precision, [0.0]))
    max_accumulated_precision = np.flip(
        np.maximum.accumulate(np.flip(extended_precision))
    )
    interpolated_recall_levels = np.linspace(0, 1, 101)
    interpolated_precision = np.interp(
        interpolated_recall_levels, extended_recall, max_accumulated_precision
    )
    average_precision = np.trapz(interpolated_precision, interpolated_recall_levels)
    # Return the interpolated curve instead of the average precision so it can be plotted.
    # return average_precision
    return interpolated_precision, interpolated_recall_levels

interpolated_precision, interpolated_recall_levels = compute_average_precision(recall, precision)

plt.clf()
plt.plot(interpolated_recall_levels, interpolated_precision)
plt.savefig('interpolated_pr_curve.png')

Additional

No response

Are you willing to submit a PR?

  • Yes I'd like to help by submitting a PR!
illian01 added the bug label Apr 9, 2024
@LinasKo
Collaborator

LinasKo commented Apr 10, 2024

Hey @illian01 👋

Thank you for reporting the issue!
You're right - it definitely looks wrong. We'll need to rethink how we do it a bit.

If anyone in the community stumbles upon this - would you like to help us out? 🙂

@Griffin-Sullivan
Contributor

Hey, I was just playing around with this and had a couple of questions:

  1. Why do we extend the recall and precision to cover 0.0 and 1.0? Is the assumption that you will always have values very close to these endpoints, and is that the issue here?
  2. Could we not just use the minimum and maximum recall for np.linspace(0, 1, 101)?
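
For discussion, here is a rough sketch of one possible direction, not a confirmed fix. It keeps the left padding, drops the appended (1.0, 0.0) point, and lets np.interp return 0.0 for recall levels beyond the largest measured recall, so no ramp is drawn. The name interpolate_without_ramp is hypothetical and the data is synthetic. Restricting np.linspace to the measured recall range, as in question 2, would instead change the range being integrated over, so the two options would not give the same number.

```python
import numpy as np


def interpolate_without_ramp(recall, precision, num_levels=101):
    """Hypothetical helper: sample the precision envelope at evenly spaced
    recall levels, using 0.0 beyond the largest measured recall."""
    # Pad only the left end so the curve starts at recall 0.
    extended_recall = np.concatenate(([0.0], recall))
    extended_precision = np.concatenate(([1.0], precision))
    # Make precision non-increasing from left to right.
    envelope = np.flip(np.maximum.accumulate(np.flip(extended_precision)))
    levels = np.linspace(0, 1, num_levels)
    # right=0.0: levels past the last measured recall get precision 0.0
    # instead of a linear ramp towards an artificial (1.0, 0.0) point.
    interpolated = np.interp(levels, extended_recall, envelope, right=0.0)
    return levels, interpolated


# Synthetic curve (made-up values): recall stops at 0.8.
recall = np.array([0.2, 0.4, 0.6, 0.8])
precision = np.array([1.0, 0.9, 0.7, 0.5])
levels, interpolated = interpolate_without_ramp(recall, precision)

print(interpolated[70])  # recall level 0.7 -> roughly 0.6, interpolated from measured points
print(interpolated[90])  # recall level 0.9 -> 0.0, nothing was measured there
```

Whether treating unmeasured recall levels as precision 0.0 is the behaviour the library wants is a design decision for the maintainers; the sketch only shows that the ramp comes from the appended point rather than from the interpolation itself.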
