Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch from entrypoints to importlib.metadata #792

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 2 additions & 3 deletions papermill/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from functools import wraps

import dateutil
import entrypoints

from .clientwrap import PapermillNotebookClient
from .exceptions import PapermillException
from .iorw import write_ipynb
from .log import logger
from .utils import merge_kwargs, nb_kernel_name, nb_language, remove_args
from .utils import get_entrypoints_group, merge_kwargs, nb_kernel_name, nb_language, remove_args


class PapermillEngines:
Expand All @@ -33,7 +32,7 @@ def register_entry_points(self):

Load handlers provided by other packages
"""
for entrypoint in entrypoints.get_group_all("papermill.engine"):
for entrypoint in get_entrypoints_group("papermill.engine"):
self.register(entrypoint.name, entrypoint.load())

def get_engine(self, name=None):
Expand Down
5 changes: 2 additions & 3 deletions papermill/iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import warnings
from contextlib import contextmanager

import entrypoints
import nbformat
import requests
import yaml
Expand All @@ -18,7 +17,7 @@
missing_environment_variable_generator,
)
from .log import logger
from .utils import chdir
from .utils import chdir, get_entrypoints_group
from .version import version as __version__

try:
Expand Down Expand Up @@ -116,7 +115,7 @@ def register(self, scheme, handler):

def register_entry_points(self):
# Load handlers provided by other packages
for entrypoint in entrypoints.get_group_all("papermill.io"):
for entrypoint in get_entrypoints_group("papermill.io"):
self.register(entrypoint.name, entrypoint.load())

def get_handler(self, path, extensions=None):
Expand Down
3 changes: 3 additions & 0 deletions papermill/tests/fixtures/foo-0.0.1.dist-info/METADATA
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Metadata-Version: 2.3
Name: foo
Version: 0.0.1
2 changes: 2 additions & 0 deletions papermill/tests/fixtures/foo-0.0.1.dist-info/entry_points.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[papermill.tests.fake]
foo = bar
5 changes: 3 additions & 2 deletions papermill/tests/test_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ def test_registering_entry_points(self):
fake_entrypoint = Mock(load=Mock())
fake_entrypoint.name = "fake-engine"

with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all:
entry_points = {"papermill.engine": [fake_entrypoint]}
with patch("papermill.utils.entry_points", return_value=entry_points) as mock_entry_points:
self.papermill_engines.register_entry_points()
mock_get_group_all.assert_called_once_with("papermill.engine")
mock_entry_points.assert_called_once()
self.assertEqual(self.papermill_engines.get_engine("fake-engine"), fake_entrypoint.load.return_value)
5 changes: 3 additions & 2 deletions papermill/tests/test_iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ def test_entrypoint_register(self):
fake_entrypoint = Mock(load=Mock())
fake_entrypoint.name = "fake-from-entry-point://"

with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all:
entry_points = {"papermill.io": [fake_entrypoint]}
with patch("papermill.utils.entry_points", return_value=entry_points) as mock_entry_points:
self.papermill_io.register_entry_points()
mock_get_group_all.assert_called_once_with("papermill.io")
mock_entry_points.assert_called_once()
fake_ = self.papermill_io.get_handler("fake-from-entry-point://")
assert fake_ == fake_entrypoint.load.return_value

Expand Down
13 changes: 13 additions & 0 deletions papermill/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
Expand All @@ -10,6 +11,7 @@
from ..utils import (
any_tagged_cell,
chdir,
get_entrypoints_group,
merge_kwargs,
remove_args,
retry,
Expand Down Expand Up @@ -58,3 +60,14 @@ def test_chdir():
assert Path.cwd() == Path(temp_dir)

assert Path.cwd() == old_cwd


def test_get_entrypoints_group():
# We don't need to mock anything here, there is just enough metadata
# present to give us one entry point.
sys.path.insert(0, Path(__file__).parent / "fixtures")
# We need to cast to a list here, 3.8/3.9 and 3.10+ return different
# types.
eps = list(get_entrypoints_group("papermill.tests.fake"))
sys.path.pop()
assert eps[0].name == "foo"
18 changes: 18 additions & 0 deletions papermill/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from contextlib import contextmanager
from functools import wraps
from importlib.metadata import entry_points

from .exceptions import PapermillParameterOverwriteWarning

Expand Down Expand Up @@ -190,3 +191,20 @@ def chdir(path):
yield
finally:
os.chdir(old_dir)


def get_entrypoints_group(group):
"""Return a given group of entrypoints.

Since the importlib.metadata entry points API is very simple in 3.8 and
more complete in 3.10+, we need to support both. This function can be
removed when 3.10 is the minimum supported version, and replaced
with entry_points(group=group).
"""
eps = entry_points()
if hasattr(eps, "select"):
# New and shiny Python 3.10+ API
return eps.select(group=group)
else:
# Python 3.8 and 3.9
return eps.get(group, [])
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ nbformat >= 5.2.0
nbclient >= 0.2.0
tqdm >= 4.32.2
requests
entrypoints
tenacity >= 5.0.2
aiohttp >=3.9.0; python_version=="3.12"
ansicolors