Enhance Traceability of Generator Outputs in Promptflow Tracing #3120

Merged: 17 commits, May 14, 2024
13 changes: 9 additions & 4 deletions src/promptflow-core/promptflow/executor/flow_executor.py
@@ -65,7 +65,12 @@
from promptflow.tracing._integrations._openai_injector import inject_openai_api
from promptflow.tracing._operation_context import OperationContext
from promptflow.tracing._start_trace import setup_exporter_from_environ
from promptflow.tracing._trace import enrich_span_with_context, enrich_span_with_input, enrich_span_with_trace_type
from promptflow.tracing._trace import (
enrich_span_with_context,
enrich_span_with_input,
enrich_span_with_trace_type,
start_as_current_span,
)
from promptflow.tracing.contracts.trace import TraceType

DEFAULT_TRACING_KEYS = {"run_mode", "root_run_id", "flow_id", "batch_input_source", "execution_target"}
@@ -880,7 +885,7 @@ def _get_node_referenced_flow_inputs(
@contextlib.contextmanager
def _start_flow_span(self, inputs: Mapping[str, Any]):
otel_tracer = otel_trace.get_tracer("promptflow")
with otel_tracer.start_as_current_span(self._flow.name) as span:
with start_as_current_span(otel_tracer, self._flow.name) as span:
# Store otel trace id in context for correlation
OperationContext.get_instance()["otel_trace_id"] = f"0x{format_trace_id(span.get_span_context().trace_id)}"
# initialize span
@@ -909,7 +914,7 @@ async def _exec_inner_with_trace_async(
context: FlowExecutionContext,
stream=False,
):
with self._start_flow_span(inputs) as span, self._record_cancellation_exceptions_to_span(span):
with self._start_flow_span(inputs) as span:
output, nodes_outputs = await self._traverse_nodes_async(inputs, context)
output = await self._stringify_generator_output_async(output) if not stream else output
self._exec_post_process(inputs, output, nodes_outputs, run_info, run_tracker, span, stream)
@@ -923,7 +928,7 @@ def _exec_inner_with_trace(
context: FlowExecutionContext,
stream=False,
):
with self._start_flow_span(inputs) as span, self._record_cancellation_exceptions_to_span(span):
with self._start_flow_span(inputs) as span:
output, nodes_outputs = self._traverse_nodes(inputs, context)
output = self._stringify_generator_output(output) if not stream else output
self._exec_post_process(inputs, output, nodes_outputs, run_info, run_tracker, span, stream)
11 changes: 8 additions & 3 deletions src/promptflow-tracing/promptflow/tracing/_openai_utils.py
@@ -211,9 +211,14 @@ def __init__(self, response, is_chat):

@property
def model(self):
for item in self._response:
if hasattr(item, "model"):
return item.model
"""
This method iterates over each item in the _response list.
If the item has a non-empty 'model' attribute, it returns the model.
If no such item is found, it returns None.
"""
for response_item in self._response:
if hasattr(response_item, "model") and response_item.model:
return response_item.model
return None

@property
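For context, the revised model property above skips response chunks whose model attribute is missing or empty (streamed responses can begin with chunks that carry no model name) and falls back to None. A minimal standalone sketch of the same lookup, using a hypothetical Chunk stand-in rather than the real OpenAI response types:

from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class Chunk:
    """Hypothetical stand-in for a streamed response chunk."""
    model: str = ""


def first_model(chunks: Iterable[object]) -> Optional[str]:
    # Mirror the property's logic: return the first non-empty 'model' attribute, else None.
    for chunk in chunks:
        if hasattr(chunk, "model") and chunk.model:
            return chunk.model
    return None


assert first_model([Chunk(), Chunk(model="gpt-35-turbo")]) == "gpt-35-turbo"
assert first_model([object()]) is None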
169 changes: 108 additions & 61 deletions src/promptflow-tracing/promptflow/tracing/_trace.py
@@ -11,13 +11,13 @@
from collections.abc import AsyncIterator, Iterator
from importlib.metadata import version
from threading import Lock
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Sequence

import opentelemetry.trace as otel_trace
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.trace import Link, Span
from opentelemetry.trace.span import NonRecordingSpan, format_trace_id
from opentelemetry.trace.status import StatusCode
from opentelemetry.trace import Span
from opentelemetry.trace.span import format_trace_id
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util import types

from ._openai_utils import OpenAIMetricsCalculator, OpenAIResponseParser
from ._operation_context import OperationContext
@@ -41,6 +41,73 @@ def _record_cancellation_exceptions_to_span(span: Span):
raise


def handle_span_exception(span, exception):
if isinstance(span, Span) and span.is_recording():
# Record the exception as an event
span.record_exception(exception)
# Records status as error
span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{type(exception).__name__}: {str(exception)}",
)
)


def handle_output(span, inputs, output, trace_type):
if isinstance(output, (Iterator, AsyncIterator)):
# Set __should_end to False so the span ends only after the generator is exhausted, not when the context manager exits.
setattr(span, "__should_end", False)
output = Tracer.pop(output)
if isinstance(output, Iterator):
return traced_generator(span, inputs, output, trace_type)
else:
return traced_async_generator(span, inputs, output, trace_type)
else:
enrich_span_with_trace_type(span, inputs, output, trace_type)
span.set_status(StatusCode.OK)
return Tracer.pop(output)


@contextlib.contextmanager
def start_as_current_span(
tracer: otel_trace.Tracer,
name: str,
context: Optional[otel_trace.Context] = None,
kind: otel_trace.SpanKind = otel_trace.SpanKind.INTERNAL,
attributes: types.Attributes = None,
links: Optional[Sequence[otel_trace.Link]] = (),
start_time: Optional[int] = None,
record_exception: bool = True,
set_status_on_exception: bool = True,
):
span = None
try:
with tracer.start_as_current_span(
name,
context,
kind,
attributes,
links,
start_time,
record_exception,
set_status_on_exception,
end_on_exit=False,
) as span:
setattr(span, "__should_end", True)
yield span

except (KeyboardInterrupt, asyncio.CancelledError) as ex:
# The context manager above does not handle KeyboardInterrupt and asyncio.CancelledError exceptions.
# Therefore, we need to explicitly handle these exceptions here to ensure proper span exception handling.
handle_span_exception(span, ex)
raise

finally:
if span is not None and getattr(span, "__should_end", False):
span.end()


class TokenCollector:
_lock = Lock()

@@ -152,18 +219,10 @@ def enrich_span_with_trace_type(span, inputs, output, trace_type):
SpanEnricherManager.enrich(span, inputs, output, trace_type)
# TODO: Move the following logic to SpanEnricher
enrich_span_with_openai_tokens(span, trace_type)
return trace_iterator_if_needed(span, inputs, output)


def trace_iterator_if_needed(span, inputs, output):
if isinstance(output, (Iterator, AsyncIterator)) and not isinstance(span, NonRecordingSpan):
trace_func = traced_generator if isinstance(output, Iterator) else traced_async_generator
output = trace_func(span, inputs, output)
return output


def enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output):
if original_span.attributes["span_type"] == "LLM" and not IS_LEGACY_OPENAI:
def enrich_span_with_llm_if_needed(span, inputs, generator_output):
if span.is_recording() and span.attributes["span_type"] == "LLM" and not IS_LEGACY_OPENAI:
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.completion import Completion

@@ -173,53 +232,45 @@ def enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output
token_collector.collect_openai_tokens_for_streaming(span, inputs, generator_output, parser.is_chat)


def traced_generator(original_span: ReadableSpan, inputs, generator):
context = original_span.get_span_context()
link = Link(context)
# If start_trace is not called, the name of the original_span will be empty.
# need to get the tracer every time to ensure it is the latest
otel_tracer = otel_trace.get_tracer("promptflow")
with otel_tracer.start_as_current_span(
f"Iterated({original_span.name})",
links=[link],
) as span, _record_cancellation_exceptions_to_span(span):
enrich_span_with_original_attributes(span, original_span.attributes)
# Enrich the new span with input before generator iteration to prevent loss of input information.
# The input is as an event within this span.
enrich_span_with_input(span, inputs)
def traced_generator(span, inputs, generator, trace_type):
try:
generator_proxy = GeneratorProxy(generator)
yield from generator_proxy

generator_output = generator_proxy.items
enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output)
enrich_span_with_openai_tokens(span, TraceType(original_span.attributes["span_type"]))
enrich_span_with_output(span, serialize_attribute(generator_output))

enrich_span_with_llm_if_needed(span, inputs, generator_output)
enrich_span_with_trace_type(span, inputs, generator_output, trace_type)
token_collector.collect_openai_tokens_for_parent_span(span)
span.set_attribute("output_type", "iterated")
span.set_status(StatusCode.OK)
token_collector.collect_openai_tokens_for_parent_span(span)


async def traced_async_generator(original_span: ReadableSpan, inputs, generator):
context = original_span.get_span_context()
link = Link(context)
# If start_trace is not called, the name of the original_span will be empty.
# need to get the tracer every time to ensure it is the latest
otel_tracer = otel_trace.get_tracer("promptflow")
with otel_tracer.start_as_current_span(
f"Iterated({original_span.name})",
links=[link],
) as span, _record_cancellation_exceptions_to_span(span):
enrich_span_with_original_attributes(span, original_span.attributes)
# Enrich the new span with input before generator iteration to prevent loss of input information.
# The input is as an event within this span.
enrich_span_with_input(span, inputs)
except Exception as e:
handle_span_exception(span, e)
raise
finally:
# Always end the span on function exit, as the context manager doesn't handle this.
span.end()


async def traced_async_generator(span, inputs, generator, trace_type):
try:
generator_proxy = AsyncGeneratorProxy(generator)
async for item in generator_proxy:
yield item

generator_output = generator_proxy.items
enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output)
enrich_span_with_openai_tokens(span, TraceType(original_span.attributes["span_type"]))
enrich_span_with_output(span, serialize_attribute(generator_output))

enrich_span_with_llm_if_needed(span, inputs, generator_output)
enrich_span_with_trace_type(span, inputs, generator_output, trace_type)
token_collector.collect_openai_tokens_for_parent_span(span)
span.set_attribute("output_type", "iterated")
span.set_status(StatusCode.OK)
token_collector.collect_openai_tokens_for_parent_span(span)
except Exception as e:
handle_span_exception(span, e)
raise
finally:
# Always end the span on function exit, as the context manager doesn't handle this.
span.end()


def enrich_span_with_original_attributes(span, attributes):
@@ -383,7 +434,7 @@ async def wrapped(*args, **kwargs):
span_name = get_node_name_from_context(used_for_span_name=True) or trace.name
# need to get the tracer every time to ensure it is the latest
otel_tracer = otel_trace.get_tracer("promptflow")
with otel_tracer.start_as_current_span(span_name) as span, _record_cancellation_exceptions_to_span(span):
with start_as_current_span(otel_tracer, span_name) as span:
# Store otel trace id in context for correlation
OperationContext.get_instance()["otel_trace_id"] = f"0x{format_trace_id(span.get_span_context().trace_id)}"
enrich_span_with_trace(span, trace)
@@ -396,9 +447,7 @@ async def wrapped(*args, **kwargs):
Tracer.push(trace)
enrich_span_with_input(span, trace.inputs)
output = await func(*args, **kwargs)
output = enrich_span_with_trace_type(span, trace.inputs, output, trace_type)
span.set_status(StatusCode.OK)
output = Tracer.pop(output)
output = handle_output(span, trace.inputs, output, trace_type)
except Exception as e:
Tracer.pop(None, e)
raise
@@ -449,7 +498,7 @@ def wrapped(*args, **kwargs):
span_name = get_node_name_from_context(used_for_span_name=True) or trace.name
# need to get the tracer every time to ensure it is the latest
otel_tracer = otel_trace.get_tracer("promptflow")
with otel_tracer.start_as_current_span(span_name) as span, _record_cancellation_exceptions_to_span(span):
with start_as_current_span(otel_tracer, span_name) as span:
# Store otel trace id in context for correlation
OperationContext.get_instance()["otel_trace_id"] = f"0x{format_trace_id(span.get_span_context().trace_id)}"
enrich_span_with_trace(span, trace)
@@ -462,9 +511,7 @@ def wrapped(*args, **kwargs):
Tracer.push(trace)
enrich_span_with_input(span, trace.inputs)
output = func(*args, **kwargs)
output = enrich_span_with_trace_type(span, trace.inputs, output, trace_type)
span.set_status(StatusCode.OK)
output = Tracer.pop(output)
output = handle_output(span, trace.inputs, output, trace_type)
except Exception as e:
Tracer.pop(None, e)
raise
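Taken together, the new start_as_current_span wrapper and handle_output change when a span ends for generator outputs: end_on_exit is disabled, and if the traced function returns an iterator the span's __should_end flag is cleared so that traced_generator (or traced_async_generator) ends the span only after the iterator is exhausted, enriching it with the collected output and the output_type="iterated" attribute first. Below is a simplified, self-contained sketch of that delayed-span-end pattern using only the public OpenTelemetry API; it is not the promptflow implementation, and the helper names are illustrative:

import contextlib
from opentelemetry import trace
from opentelemetry.trace.status import StatusCode

tracer = trace.get_tracer("example")


@contextlib.contextmanager
def start_span_without_auto_end(name):
    # end_on_exit=False keeps the span open after the context manager exits,
    # so a generator returned by the traced function can close it later.
    with tracer.start_as_current_span(name, end_on_exit=False) as span:
        setattr(span, "__should_end", True)
        try:
            yield span
        finally:
            if getattr(span, "__should_end", False):
                span.end()


def _iterate_then_end(span, generator):
    # Consume the wrapped generator, then enrich and end the span.
    items = []
    try:
        for item in generator:
            items.append(item)
            yield item
        span.set_attribute("output_type", "iterated")
        span.set_status(StatusCode.OK)
    finally:
        span.end()  # the span ends when the consumer finishes (or abandons) iteration


def traced(func):
    def wrapped(*args, **kwargs):
        with start_span_without_auto_end(func.__name__) as span:
            output = func(*args, **kwargs)
            if hasattr(output, "__next__"):  # generator/iterator output
                setattr(span, "__should_end", False)  # delay span end until exhaustion
                return _iterate_then_end(span, output)
            span.set_status(StatusCode.OK)
            return output
    return wrapped

The key design choice is that the span outlives the function call that created it, so the consumer of the stream, not the producer, determines when the span ends.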
59 changes: 26 additions & 33 deletions src/promptflow-tracing/tests/e2etests/test_tracing.py
@@ -124,15 +124,15 @@ def assert_otel_trace(self, func, inputs, expected_span_length):
},
2,
),
(prompt_tpl_completion, {"prompt_tpl": "Hello {{name}}", "name": "world", "stream": True}, 3),
(prompt_tpl_completion, {"prompt_tpl": "Hello {{name}}", "name": "world", "stream": True}, 2),
(
prompt_tpl_chat,
{
"prompt_tpl": "system:\nYou are a helpful assistant.\n\n\nuser:\n{{question}}",
"question": "What is ChatGPT?",
"stream": True,
},
3,
2,
),
],
)
@@ -169,12 +169,12 @@ def assert_otel_traces_with_prompt(self, dev_connections, func, inputs, expected
[
(openai_chat, {"prompt": "Hello"}, 2),
(openai_completion, {"prompt": "Hello"}, 2),
(openai_chat, {"prompt": "Hello", "stream": True}, 3),
(openai_completion, {"prompt": "Hello", "stream": True}, 3),
(openai_chat, {"prompt": "Hello", "stream": True}, 2),
(openai_completion, {"prompt": "Hello", "stream": True}, 2),
(openai_chat_async, {"prompt": "Hello"}, 2),
(openai_completion_async, {"prompt": "Hello"}, 2),
(openai_chat_async, {"prompt": "Hello", "stream": True}, 3),
(openai_completion_async, {"prompt": "Hello", "stream": True}, 3),
(openai_chat_async, {"prompt": "Hello", "stream": True}, 2),
(openai_completion_async, {"prompt": "Hello", "stream": True}, 2),
],
)
def test_otel_trace_with_llm(self, dev_connections, func, inputs, expected_span_length):
@@ -315,35 +315,15 @@ def validate_span_events(self, span):
assert FUNCTION_OUTPUT_EVENT in events, f"Expected '{FUNCTION_OUTPUT_EVENT}' in events"

if span.attributes[SPAN_TYPE_ATTRIBUTE] == TraceType.LLM:
self.validate_llm_event(span, events)
self.validate_llm_event(events)
elif span.attributes[SPAN_TYPE_ATTRIBUTE] == TraceType.EMBEDDING:
self.validate_embedding_event(events)

if PROMPT_TEMPLATE_EVENT in events:
self.validate_prompt_template_event(events)

def validate_llm_event(self, span, span_events):
is_stream = span_events[FUNCTION_INPUTS_EVENT].get("stream", False)
is_iterated_span = span.name.startswith(ITERATED_SPAN_PREFIX)

# iterate span should have LLM_GENERATED_MESSAGE_EVENT and links
if is_iterated_span:
assert (
LLM_GENERATED_MESSAGE_EVENT in span_events
), f"Expected '{LLM_GENERATED_MESSAGE_EVENT}' in iterated span events"
assert span.links, "Expected links in iterated span"

# non-stream span should have LLM_GENERATED_MESSAGE_EVENT
if not is_stream:
assert (
LLM_GENERATED_MESSAGE_EVENT in span_events
), f"Expected '{LLM_GENERATED_MESSAGE_EVENT}' in non-stream span events"

# original span in streaming mode should not have LLM_GENERATED_MESSAGE_EVENT
if is_stream and not is_iterated_span:
assert (
LLM_GENERATED_MESSAGE_EVENT not in span_events
), f"Unexpected '{LLM_GENERATED_MESSAGE_EVENT}' in original span in streaming mode"
def validate_llm_event(self, span_events):
assert LLM_GENERATED_MESSAGE_EVENT in span_events, f"Expected '{LLM_GENERATED_MESSAGE_EVENT}' in span events"

def validate_embedding_event(self, span_events):
assert EMBEDDING_EVENT in span_events, f"Expected '{EMBEDDING_EVENT}' in span events"
@@ -446,10 +426,23 @@ def validate_openai_tokens(self, span_list, is_stream=False):
assert span.attributes[token_name] == expected_tokens[span_id][token_name]

def _is_llm_span_with_tokens(self, span, is_stream):
# For streaming mode, there are two spans for openai api call, one is the original span, and the other
# is the iterated span, which name is "Iterated(<original_trace_name>)", we should check the iterated span
# in streaming mode.
"""
This function checks if a given span is a LLM span with tokens.

If in stream mode, the function checks if the span has attributes indicating it's an iterated span.
In non-stream mode, it simply checks if the span's function attribute is in the list of LLM function names.

Args:
span: The span to check.
is_stream: A boolean indicating whether the span is in stream mode.

Returns:
A boolean indicating whether the span is a LLM span with tokens.
"""
if is_stream:
return span.attributes.get("function", "") in LLM_FUNCTION_NAMES and span.name.startswith("Iterated(")
return (
span.attributes.get("function", "") in LLM_FUNCTION_NAMES
and span.attributes.get("output_type", "") == "iterated"
)
else:
return span.attributes.get("function", "") in LLM_FUNCTION_NAMES
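The test updates reflect that change: a streaming call now yields two spans instead of three (the separate "Iterated(...)" span is gone), and the span that carries LLM token metrics is identified by its output_type="iterated" attribute rather than by its name. A hedged sketch of how such an assertion could be made with OpenTelemetry's in-memory exporter (not the actual promptflow test harness):

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
trace.set_tracer_provider(provider)

# ... invoke a traced streaming call here and fully consume its output ...

spans = exporter.get_finished_spans()
assert len(spans) == 2  # streaming no longer produces a separate "Iterated(...)" span
llm_spans = [s for s in spans if s.attributes.get("output_type", "") == "iterated"]
assert llm_spans, "expected the streaming span to be marked with output_type='iterated'"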