Skip to content

Commit

Permalink
[azure][trace] Check trace Cosmos status and log warnings if not ready (
Browse files Browse the repository at this point in the history
#3200)

# Description

This PR adds a check of the trace Cosmos DB status when creating an Azure
run. If the Cosmos DB is not ready, a warning is logged that hints the user
to enable it with `pf config set`.

**Warning message**

Local to cloud


![image](https://github.com/microsoft/promptflow/assets/38847871/0718e4cc-1d19-4e71-ab69-05b3bfc0b59e)

Submit run to Azure


![image](https://github.com/microsoft/promptflow/assets/38847871/e8159520-8fc5-4d53-8986-a064849d0332)

Because this change adds one new request during `pfazure run create`,
this PR also updates the recordings.

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [x] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.
  • Loading branch information
zhengfeiwang committed May 13, 2024
1 parent 95a73ac commit b01ce40
Show file tree
Hide file tree
Showing 58 changed files with 69,774 additions and 27,766 deletions.
1 change: 1 addition & 0 deletions src/promptflow-azure/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Behaviors not changed: 'pfazure connection' command will scrub secrets.
- New behavior: connection object by `client.connection.get` will have real secrets. `print(connection_obj)` directly will scrub those secrets. `print(connection_obj.api_key)` or `print(connection_obj.secrets)` will print the REAL secrets.
- Workspace listsecrets permission is required to get the secrets. Call `client.connection.get(name, with_secrets=True)` if you want to get without the secrets and listsecrets permission.
- [promptflow-azure] Check workspace/project trace Cosmos DB status and honor it when creating a run in Azure.

## v1.10.0 (2024.04.26)

Expand Down
19 changes: 19 additions & 0 deletions src/promptflow-azure/promptflow/azure/_constants/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,22 @@
COSMOS_DB_SETUP_POLL_INTERVAL_SECOND = 30
COSMOS_DB_SETUP_POLL_PRINT_INTERVAL_SECOND = 30
COSMOS_DB_SETUP_RESOURCE_TYPE = "HOBO"


class CosmosConfiguration:
    """String constants for the trace Cosmos DB ``configuration`` value.

    ``DISABLED`` is the value ``CosmosMetadata.is_disabled`` compares against.
    NOTE(review): the remaining members presumably mirror the service-side
    configuration enum; confirm against the REST contract before relying on
    their exact meaning.
    """

    NONE = "None"
    READ_DISABLED = "ReadDisabled"
    WRITE_DISABLED = "WriteDisabled"
    # the only member treated as "trace disabled" by CosmosMetadata.is_disabled
    DISABLED = "Disabled"
    DIAGNOSTIC_DISABLED = "DiagnosticDisabled"
    DATA_CLEANED = "DataCleaned"
    ACCOUNT_DELETED = "AccountDeleted"


class CosmosStatus:
    """String constants for the trace Cosmos DB ``status`` value.

    ``INITIALIZED`` is the value ``CosmosMetadata.is_ready`` requires.
    NOTE(review): the remaining members presumably mirror the service-side
    status enum; confirm against the REST contract.
    """

    NOT_EXISTS = "NotExists"
    INITIALIZING = "Initializing"
    # the only status accepted as "ready" by CosmosMetadata.is_ready
    INITIALIZED = "Initialized"
    DELETING = "Deleting"
    DELETED = "Deleted"
    NOT_AVAILABLE = "NotAvailable"
27 changes: 27 additions & 0 deletions src/promptflow-azure/promptflow/azure/_entities/_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from dataclasses import dataclass

from promptflow.azure._constants._trace import CosmosConfiguration, CosmosStatus
from promptflow.azure._restclient.flow.models import TraceCosmosMetaDto


@dataclass
class CosmosMetadata:
    """Trace Cosmos DB metadata: the configuration and status strings of a workspace/project."""

    configuration: str
    status: str

    @staticmethod
    def _from_rest_object(obj: TraceCosmosMetaDto) -> "CosmosMetadata":
        """Build a ``CosmosMetadata`` from the trace Cosmos fields of the REST DTO."""
        configuration = obj.trace_cosmos_configuration
        status = obj.trace_cosmos_status
        return CosmosMetadata(configuration=configuration, status=status)

    def is_disabled(self) -> bool:
        """Return True when the configuration equals ``CosmosConfiguration.DISABLED``."""
        disabled = self.configuration == CosmosConfiguration.DISABLED
        return disabled

    def is_ready(self) -> bool:
        """Return True when the Cosmos DB is not disabled and its status is ``INITIALIZED``."""
        # a disabled Cosmos DB is never ready, whatever its status says
        if self.is_disabled():
            return False
        return self.status == CosmosStatus.INITIALIZED
13 changes: 7 additions & 6 deletions src/promptflow-azure/promptflow/azure/_pf_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,19 @@ def __init__(
workspace=workspace,
**kwargs,
)
self._traces = TraceOperations(
operation_scope=self._ml_client._operation_scope,
operation_config=self._ml_client._operation_config,
service_caller=self._service_caller,
**kwargs,
)
self._runs = RunOperations(
operation_scope=self._ml_client._operation_scope,
operation_config=self._ml_client._operation_config,
all_operations=self._ml_client._operation_container,
credential=self._ml_client._credential,
flow_operations=self._flows,
trace_operations=self._traces,
service_caller=self._service_caller,
workspace=workspace,
**kwargs,
Expand All @@ -109,12 +116,6 @@ def __init__(
service_caller=self._service_caller,
**kwargs,
)
self._traces = TraceOperations(
operation_scope=self._ml_client._operation_scope,
operation_config=self._ml_client._operation_config,
service_caller=self._service_caller,
**kwargs,
)

@staticmethod
def _validate_config_information(subscription_id, resource_group_name, workspace_name, kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
COSMOS_DB_SETUP_POLL_PRINT_INTERVAL_SECOND,
COSMOS_DB_SETUP_POLL_TIMEOUT_SECOND,
)
from promptflow.azure._constants._trace import (
COSMOS_DB_SETUP_POLL_INTERVAL_SECOND,
COSMOS_DB_SETUP_POLL_PRINT_INTERVAL_SECOND,
COSMOS_DB_SETUP_POLL_TIMEOUT_SECOND,
)
from promptflow.azure._restclient.flow import AzureMachineLearningDesignerServiceClient
from promptflow.azure._utils.general import get_authorization, get_arm_token, get_aml_token
from promptflow.exceptions import UserErrorException, PromptflowException, SystemErrorException
Expand Down Expand Up @@ -763,6 +768,22 @@ def init_workspace_cosmos(
**kwargs,
)

def get_workspace_cosmos_metadata(
    self,
    subscription_id: str,
    resource_group_name: str,
    workspace_name: str,
    **kwargs,
):
    """Get the trace Cosmos DB metadata of a workspace/project.

    Forwards the workspace triad, the caller's request headers and any extra
    keyword arguments to the trace-sessions REST operation and returns its
    result unchanged.

    NOTE(review): the neighbouring ``setup_workspace_cosmos`` is decorated with
    ``@_request_wrapper()`` while this method is not; confirm that is intentional.
    """
    trace_sessions = self.caller.trace_sessions
    return trace_sessions.get_trace_session_metadata_async(
        subscription_id=subscription_id,
        resource_group_name=resource_group_name,
        workspace_name=workspace_name,
        headers=self._get_headers(),
        **kwargs,
    )

@_request_wrapper()
def setup_workspace_cosmos(self, subscription_id, resource_group_name, workspace_name, body, **kwargs):
"""Setup Cosmos DB for workspace/project."""
Expand Down
59 changes: 52 additions & 7 deletions src/promptflow-azure/promptflow/azure/_utils/_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import logging
import typing

from azure.ai.ml import MLClient
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import AzureCliCredential

from promptflow._constants import AzureWorkspaceKind, CosmosDBContainerName
from promptflow._constants import AzureWorkspaceKind
from promptflow._sdk._constants import AzureMLWorkspaceTriad
from promptflow._sdk._utilities.general_utils import extract_workspace_triad_from_trace_provider
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow.azure import PFClient
from promptflow.azure._constants._trace import COSMOS_DB_SETUP_RESOURCE_TYPE
from promptflow.azure._restclient.flow_service_caller import FlowRequestException
from promptflow.azure._entities._trace import CosmosMetadata
from promptflow.exceptions import ErrorTarget, UserErrorException

_logger = get_cli_sdk_logger()
Expand All @@ -21,6 +25,42 @@ def _create_trace_destination_value_user_error(message: str) -> UserErrorExcepti
return UserErrorException(message=message, target=ErrorTarget.CONTROL_PLANE_SDK)


def resolve_disable_trace(metadata: CosmosMetadata, logger: typing.Optional[logging.Logger] = None) -> bool:
    """Derive the `disable_trace` flag from the trace Cosmos DB metadata.

    :param metadata: Cosmos DB metadata exposing ``is_disabled()`` and ``is_ready()``.
    :param logger: Logger to report on; the module-level logger is used when None.
    :return: True only when the Cosmos DB is disabled. When it is neither disabled
        nor ready, a warning is logged and False is returned.
    """
    active_logger = _logger if logger is None else logger

    if metadata.is_disabled():
        active_logger.debug("the trace cosmos db is disabled.")
        return True

    # ready: nothing to report, keep tracing enabled
    if metadata.is_ready():
        return False

    # not disabled but not ready either: tracing stays enabled, the user is told how to set it up
    active_logger.warning(
        "The trace Cosmos DB for current workspace/project is not ready yet, "
        "your traces might not be logged and stored properly.\n"
        "To enable it, please run `pf config set trace.destination="
        "azureml://subscriptions/<subscription-id>/"
        "resourceGroups/<resource-group-name>/providers/Microsoft.MachineLearningServices/"
        "workspaces/<workspace-or-project-name>`, prompt flow will help to get everything ready.\n"
    )
    return False


def is_trace_cosmos_available(ws_triad: AzureMLWorkspaceTriad, logger: typing.Optional[logging.Logger] = None) -> bool:
    """Tell whether tracing to the workspace's Cosmos DB should stay enabled.

    Builds a ``PFClient`` for the given workspace triad with Azure CLI credentials,
    fetches the trace Cosmos DB metadata and returns the negation of
    ``resolve_disable_trace`` for it.

    :param ws_triad: Subscription id, resource group name and workspace name to query.
    :param logger: Logger handed to ``resolve_disable_trace``; the module-level logger is used when None.
    :return: False when ``resolve_disable_trace`` reports the trace as disabled, True otherwise.
    """
    active_logger = _logger if logger is None else logger
    client = PFClient(
        credential=AzureCliCredential(),
        subscription_id=ws_triad.subscription_id,
        resource_group_name=ws_triad.resource_group_name,
        workspace_name=ws_triad.workspace_name,
    )
    metadata = client._traces._get_cosmos_metadata()
    trace_disabled = resolve_disable_trace(metadata=metadata, logger=active_logger)
    return not trace_disabled


def validate_trace_destination(value: str) -> None:
"""Validate pf.config.trace.destination.
Expand Down Expand Up @@ -61,13 +101,16 @@ def validate_trace_destination(value: str) -> None:
_logger.debug("Resource type is valid.")

# the workspace Cosmos DB is initialized
# try to retrieve the token from PFS; if failed, call PFS init API and start polling
# if not, call PFS setup API and start polling
_logger.debug("Validating workspace Cosmos DB is initialized...")
pf_client = PFClient(ml_client=ml_client)
try:
pf_client._traces._get_cosmos_db_token(container_name=CosmosDBContainerName.SPAN)
_logger.debug("The workspace Cosmos DB is already initialized.")
except FlowRequestException:
cosmos_metadata = pf_client._traces._get_cosmos_metadata()
# raise error if the Cosmos DB is disabled
if cosmos_metadata.is_disabled():
error_message = "The workspace Cosmos DB is disabled, please enable it first."
_logger.error(error_message)
raise _create_trace_destination_value_user_error(error_message)
if not cosmos_metadata.is_ready():
# print here to let users aware this operation as it's kind of time consuming
init_cosmos_msg = (
"The workspace Cosmos DB is not initialized yet, "
Expand All @@ -76,6 +119,8 @@ def validate_trace_destination(value: str) -> None:
print(init_cosmos_msg)
_logger.debug(init_cosmos_msg)
pf_client._traces._setup_cosmos_db(resource_type=COSMOS_DB_SETUP_RESOURCE_TYPE)
else:
_logger.debug("The workspace Cosmos DB is available.")
_logger.debug("The workspace Cosmos DB is initialized.")

_logger.debug("pf.config.trace.destination is valid.")
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow._utils.utils import in_jupyter_notebook
from promptflow.azure._constants._flow import AUTOMATIC_RUNTIME, AUTOMATIC_RUNTIME_NAME, CLOUD_RUNS_PAGE_SIZE
from promptflow.azure._entities._trace import CosmosMetadata
from promptflow.azure._load_functions import load_flow
from promptflow.azure._restclient.flow_service_caller import FlowServiceCaller
from promptflow.azure._utils.general import get_authorization, get_user_alias_from_credential, set_event_loop_policy
from promptflow.azure.operations._flow_operations import FlowOperations
from promptflow.azure.operations._trace_operations import TraceOperations
from promptflow.exceptions import UserErrorException

RUNNING_STATUSES = RunStatus.get_running_statuses()
Expand Down Expand Up @@ -87,6 +89,7 @@ def __init__(
operation_config: OperationConfig,
all_operations: OperationsContainer,
flow_operations: FlowOperations,
trace_operations: TraceOperations,
credential,
service_caller: FlowServiceCaller,
workspace: Workspace,
Expand All @@ -106,6 +109,7 @@ def __init__(
self._identity = workspace.identity
self._credential = credential
self._flow_operations = flow_operations
self._trace_operations = trace_operations
self._orchestrators = OperationOrchestrator(self._all_operations, self._operation_scope, self._operation_config)

@property
Expand Down Expand Up @@ -844,24 +848,33 @@ def _resolve_runtime(self, run, runtime):
raise TypeError(f"runtime should be a string, got {type(runtime)} for {runtime}")
return runtime

def _resolve_dependencies_in_parallel(self, run, runtime, reset=None):
def _get_cosmos_metadata(self) -> CosmosMetadata:
    """Return the trace Cosmos DB metadata by delegating to the trace operations."""
    trace_operations = self._trace_operations
    return trace_operations._get_cosmos_metadata()

def _resolve_dependencies_in_parallel(self, run: Run, runtime, reset=None):
# local import to avoid circular import related to PFClient
from promptflow.azure._utils._tracing import resolve_disable_trace

with ThreadPoolExecutor() as pool:
tasks = [
pool.submit(self._resolve_data_to_asset_id, run=run),
pool.submit(self._resolve_flow_and_session_id, run=run),
pool.submit(self._get_cosmos_metadata),
]
concurrent.futures.wait(tasks, return_when=concurrent.futures.ALL_COMPLETED)
task_results = [task.result() for task in tasks]

run.data = task_results[0]
run.flow, session_id = task_results[1]
cosmos_metadata = task_results[2]

runtime = self._resolve_runtime(run=run, runtime=runtime)
self._resolve_identity(run=run)

rest_obj = run._to_rest_object()
rest_obj.runtime_name = runtime
rest_obj.session_id = session_id
rest_obj.disable_trace = resolve_disable_trace(metadata=cosmos_metadata, logger=logger)

# TODO(2884482): support force reset & force install

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations

from promptflow._sdk._telemetry import ActivityType, WorkspaceTelemetryMixin, monitor_operation
from promptflow.azure._entities._trace import CosmosMetadata
from promptflow.azure._restclient.flow.models import TraceDbSetupRequest
from promptflow.azure._restclient.flow_service_caller import FlowServiceCaller

Expand All @@ -32,6 +33,7 @@ def __init__(

@monitor_operation(activity_name="pfazure.traces._init_cosmos_db", activity_type=ActivityType.INTERNALCALL)
def _init_cosmos_db(self) -> Optional[Dict]:
# this API is deprecated and will be removed in the future
resp = self._service_caller.init_workspace_cosmos(
subscription_id=self._operation_scope.subscription_id,
resource_group_name=self._operation_scope.resource_group_name,
Expand Down Expand Up @@ -63,3 +65,11 @@ def _setup_cosmos_db(self, resource_type: str) -> None:
body=body,
)
return

def _get_cosmos_metadata(self) -> CosmosMetadata:
    """Fetch the trace Cosmos DB metadata for the current operation scope.

    Calls the service caller with the scope's subscription id, resource group
    name and workspace name, then converts the REST object into a ``CosmosMetadata``.
    """
    fetch_metadata = self._service_caller.get_workspace_cosmos_metadata
    scope = self._operation_scope
    response = fetch_metadata(
        subscription_id=scope.subscription_id,
        resource_group_name=scope.resource_group_name,
        workspace_name=scope.workspace_name,
    )
    return CosmosMetadata._from_rest_object(response)
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,8 @@ def test_run_bulk_without_retry(self, remote_client):
from azure.core.rest._requests_basic import RestRequestsTransportResponse
from requests import Response

from promptflow.azure._constants._trace import CosmosConfiguration, CosmosStatus
from promptflow.azure._entities._trace import CosmosMetadata
from promptflow.azure._restclient.flow.models import SubmitBulkRunRequest
from promptflow.azure._restclient.flow_service_caller import FlowRequestException, FlowServiceCaller
from promptflow.azure.operations import RunOperations
Expand All @@ -606,9 +608,14 @@ def collect_submit_call_count(_call_args_list):
mock_run._use_remote_flow = False
mock_run._identity = None

mock_cosmos_metadata = CosmosMetadata(
configuration=CosmosConfiguration.DIAGNOSTIC_DISABLED,
status=CosmosStatus.INITIALIZED,
)

with patch.object(RunOperations, "_resolve_data_to_asset_id"), patch.object(
RunOperations, "_resolve_flow_and_session_id", return_value=("fake_flow_id", "fake_session_id")
):
), patch.object(RunOperations, "_get_cosmos_metadata", return_value=mock_cosmos_metadata):
with patch.object(RequestsTransport, "send") as mock_request, patch.object(
FlowServiceCaller, "_set_headers_with_user_aml_token"
):
Expand All @@ -625,7 +632,7 @@ def collect_submit_call_count(_call_args_list):

with patch.object(RunOperations, "_resolve_data_to_asset_id"), patch.object(
RunOperations, "_resolve_flow_and_session_id", return_value=("fake_flow_id", "fake_session_id")
):
), patch.object(RunOperations, "_get_cosmos_metadata", return_value=mock_cosmos_metadata):
with patch.object(RequestsTransport, "send") as mock_request, patch.object(
FlowServiceCaller, "_set_headers_with_user_aml_token"
):
Expand All @@ -645,7 +652,7 @@ def collect_submit_call_count(_call_args_list):

with patch.object(RunOperations, "_resolve_data_to_asset_id"), patch.object(
RunOperations, "_resolve_flow_and_session_id", return_value=("fake_flow_id", "fake_session_id")
):
), patch.object(RunOperations, "_get_cosmos_metadata", return_value=mock_cosmos_metadata):
with patch.object(RequestsTransport, "send") as mock_request, patch.object(
FlowServiceCaller, "_set_headers_with_user_aml_token"
):
Expand Down
15 changes: 14 additions & 1 deletion src/promptflow-devkit/promptflow/_sdk/_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,19 @@ def start_trace_with_devkit(collection: str, **kwargs: typing.Any) -> None:
)
_logger.warning(warning_msg)

# enable trace by default, only disable it when we get "Disabled" status from service side
# this operation requires Azure dependencies as client talks to PFS
disable_trace_in_cloud = False
if ws_triad is not None and is_azure_ext_installed:
from promptflow.azure._utils._tracing import is_trace_cosmos_available

if not is_trace_cosmos_available(ws_triad=ws_triad, logger=_logger):
disable_trace_in_cloud = True
# if trace is disabled, directly set workspace triad as None
# so that following code will regard as no workspace configured
if disable_trace_in_cloud is True:
ws_triad = None

# invoke prompt flow service
pfs_port = _invoke_pf_svc()
is_pfs_healthy = is_pfs_service_healthy(pfs_port)
Expand Down Expand Up @@ -446,7 +459,7 @@ def start_trace_with_devkit(collection: str, **kwargs: typing.Any) -> None:
# print tracing url(s) when run is specified
_print_tracing_url_from_local(pfs_port=pfs_port, collection=collection, exp=exp, run=run)

if run is not None and run_config._is_cloud_trace_destination(path=flow_path):
if run is not None and ws_triad is not None:
trace_destination = run_config.get_trace_destination(path=flow_path)
print(
f"You can view the traces in azure portal since trace destination is set to: {trace_destination}. "
Expand Down
5 changes: 5 additions & 0 deletions src/promptflow-recording/promptflow/recording/azure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ def sanitize_pfs_response_body(body: str) -> str:
# BulkRuns/{flowRunId}
if "studioPortalEndpoint" in body:
body_dict["studioPortalEndpoint"] = sanitize_azure_workspace_triad(body_dict["studioPortalEndpoint"])
# TraceSessions
if "accountEndpoint" in body:
body_dict["accountEndpoint"] = ""
if "resourceArmId" in body:
body_dict["resourceArmId"] = ""
return json.dumps(body_dict)


Expand Down

0 comments on commit b01ce40

Please sign in to comment.