Skip to content

Commit

Permalink
pluggable azure credentials provider
Browse files Browse the repository at this point in the history
  • Loading branch information
oavdeev committed Mar 5, 2024
1 parent 2574095 commit df4be4a
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 43 deletions.
1 change: 1 addition & 0 deletions metaflow/extension_support/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def resolve_plugins(category):
"metadata_provider": lambda x: x.TYPE,
"datastore": lambda x: x.TYPE,
"secrets_provider": lambda x: x.TYPE,
"azure_client_provider": lambda x: x.name,
"sidecar": None,
"logging_sidecar": None,
"monitor_sidecar": None,
Expand Down
6 changes: 6 additions & 0 deletions metaflow/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@
),
]

AZURE_CLIENT_PROVIDERS_DESC = [
("azure-default", ".azure.azure_credential.AzureDefaultClientProvider")
]


process_plugins(globals())


Expand All @@ -143,6 +148,7 @@ def get_plugin_cli():

AWS_CLIENT_PROVIDERS = resolve_plugins("aws_client_provider")
SECRETS_PROVIDERS = resolve_plugins("secrets_provider")
AZURE_CLIENT_PROVIDERS = resolve_plugins("azure_client_provider")

from .cards.card_modules import MF_EXTERNAL_CARDS

Expand Down
1 change: 1 addition & 0 deletions metaflow/plugins/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .azure_credential import create_cacheable_azure_credential as create_azure_credential
53 changes: 53 additions & 0 deletions metaflow/plugins/azure/azure_credential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
class AzureDefaultClientProvider(object):
name = "azure-default"

@staticmethod
def create_cacheable_azure_credential(*args, **kwargs):
"""azure.identity.DefaultAzureCredential is not readily cacheable in a dictionary
because it does not have a content based hash and equality implementations.
We implement a subclass CacheableDefaultAzureCredential to add them.
We need this because credentials will be part of the cache key in _ClientCache.
"""
from azure.identity import DefaultAzureCredential

class CacheableDefaultAzureCredential(DefaultAzureCredential):
def __init__(self, *args, **kwargs):
super(CacheableDefaultAzureCredential, self).__init__(*args, **kwargs)
# Just hashing all the kwargs works because they are all individually
# hashable as of 7/15/2022.
#
# What if Azure adds unhashable things to kwargs?
# - We will have CI to catch this (it will always install the latest Azure SDKs)
# - In Metaflow usage today we never specify any kwargs anyway. (see last line
# of the outer function.
self._hash_code = hash((args, tuple(sorted(kwargs.items()))))

def __hash__(self):
return self._hash_code

def __eq__(self, other):
return hash(self) == hash(other)

return CacheableDefaultAzureCredential(*args, **kwargs)


cached_provider_class = None


def create_cacheable_azure_credential():
global cached_provider_class
if cached_provider_class is None:
from metaflow.metaflow_config import DEFAULT_AZURE_CLIENT_PROVIDER
from metaflow.plugins import AZURE_CLIENT_PROVIDERS

for p in AZURE_CLIENT_PROVIDERS:
if p.name == DEFAULT_AZURE_CLIENT_PROVIDER:
cached_provider_class = p
break
else:
raise ValueError(
"Cannot find Azure Client provider %s" % DEFAULT_AZURE_CLIENT_PROVIDER
)
return cached_provider_class.create_cacheable_azure_credential()
37 changes: 2 additions & 35 deletions metaflow/plugins/azure/azure_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
MetaflowAzurePackageError,
)
from metaflow.exception import MetaflowInternalError, MetaflowException
from metaflow.plugins.azure.azure_credential import create_cacheable_azure_credential


def _check_and_init_azure_deps():
Expand Down Expand Up @@ -138,38 +139,6 @@ def _inner_func(*args, **kwargs):
return _inner_func


@check_azure_deps
def create_cacheable_default_azure_credentials(*args, **kwargs):
"""azure.identity.DefaultAzureCredential is not readily cacheable in a dictionary
because it does not have a content based hash and equality implementations.
We implement a subclass CacheableDefaultAzureCredential to add them.
We need this because credentials will be part of the cache key in _ClientCache.
"""
from azure.identity import DefaultAzureCredential

class CacheableDefaultAzureCredential(DefaultAzureCredential):
def __init__(self, *args, **kwargs):
super(CacheableDefaultAzureCredential, self).__init__(*args, **kwargs)
# Just hashing all the kwargs works because they are all individually
# hashable as of 7/15/2022.
#
# What if Azure adds unhashable things to kwargs?
# - We will have CI to catch this (it will always install the latest Azure SDKs)
# - In Metaflow usage today we never specify any kwargs anyway. (see last line
# of the outer function.
self._hash_code = hash((args, tuple(sorted(kwargs.items()))))

def __hash__(self):
return self._hash_code

def __eq__(self, other):
return hash(self) == hash(other)

return CacheableDefaultAzureCredential(*args, **kwargs)


@check_azure_deps
def create_static_token_credential(token_):
from azure.core.credentials import TokenCredential
Expand Down Expand Up @@ -200,9 +169,7 @@ def __init__(self, token):
def get_token(self, *_scopes, **_kwargs):

if (self._cached_token.expires_on - time.time()) < 300:
from azure.identity import DefaultAzureCredential

self._credential = DefaultAzureCredential()
self._credential = create_cacheable_azure_credential()
if self._credential:
return self._credential.get_token(*_scopes, **_kwargs)
return self._cached_token
Expand Down
6 changes: 4 additions & 2 deletions metaflow/plugins/azure/blob_service_client_factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from metaflow.exception import MetaflowException
from metaflow.metaflow_config import AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
from metaflow.plugins.azure.azure_utils import (
create_cacheable_default_azure_credentials,
check_azure_deps,
)
from metaflow.plugins.azure.azure_credential import (
create_cacheable_azure_credential,
)

import os
import threading
Expand Down Expand Up @@ -125,7 +127,7 @@ def get_azure_blob_service_client(
blob_service_endpoint = AZURE_STORAGE_BLOB_SERVICE_ENDPOINT

if not credential:
credential = create_cacheable_default_azure_credentials()
credential = create_cacheable_azure_credential()
credential_is_cacheable = True

if not credential_is_cacheable:
Expand Down
12 changes: 6 additions & 6 deletions metaflow/plugins/datastores/azure_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
handle_executor_exceptions,
)

from metaflow.plugins.azure.azure_credential import create_cacheable_azure_credential

AZURE_STORAGE_DOWNLOAD_MAX_CONCURRENCY = 4
AZURE_STORAGE_UPLOAD_MAX_CONCURRENCY = 16

Expand Down Expand Up @@ -266,12 +268,10 @@ def _get_default_token(self):
if not self._default_scope_token or (
self._default_scope_token.expires_on - time.time() < 300
):
from azure.identity import DefaultAzureCredential

with DefaultAzureCredential() as credential:
self._default_scope_token = credential.get_token(
AZURE_STORAGE_DEFAULT_SCOPE
)
credential = create_cacheable_azure_credential()
self._default_scope_token = credential.get_token(
AZURE_STORAGE_DEFAULT_SCOPE
)
return self._default_scope_token

@property
Expand Down

0 comments on commit df4be4a

Please sign in to comment.