Enhance Traceability of Generator Outputs in Promptflow Tracing (#3120)
# Description

This PR improves how the Promptflow tracing framework captures traces for
functions that return generators. Previously, when a function returned a
generator, the associated span ended immediately, and the recorded output
was merely a string representation of the generator object. A second span
was then started when another function consumed the generator.

With the modifications proposed in this PR, the span persists beyond the
point where a function returns a generator, concluding only after the
generator has been fully consumed. Consequently, the output of the span
encapsulating the generator or iterator object is now a comprehensive
list of the iterated objects.

These changes significantly improve the transparency and intelligibility
of the tracing process for generator outputs.
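The delayed-span-end idea can be sketched with a toy tracer, independent of OpenTelemetry. All names below (`FakeSpan`, `traced`, `SPANS`) are illustrative only, not the Promptflow API:

```python
SPANS = []  # spans collected by the toy tracer, for illustration only


class FakeSpan:
    """Stand-in for an OpenTelemetry span."""

    def __init__(self, name):
        self.name = name
        self.output = None
        self.ended = False
        SPANS.append(self)

    def end(self):
        self.ended = True


def traced(func):
    """Keep the span open until a returned generator is fully consumed."""

    def wrapped(*args, **kwargs):
        span = FakeSpan(func.__name__)
        result = func(*args, **kwargs)
        if hasattr(result, "__next__"):  # the function returned a generator/iterator
            def consume():
                items = []
                try:
                    for item in result:
                        items.append(item)
                        yield item
                finally:
                    # End the span only after consumption, recording the
                    # iterated items as the span output.
                    span.output = items
                    span.end()

            return consume()
        # Non-generator output: end the span immediately, as before.
        span.output = result
        span.end()
        return result

    return wrapped


@traced
def stream_words():
    yield from ["hello", " ", "world"]
```

Consuming `stream_words()` leaves a single span whose `output` is the full list of yielded chunks, rather than a string like `<generator object ...>`.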

## Example with OpenAI Call

Consider the following code for a node:

```python
@trace
def chat(connection: AzureOpenAIConnection, question: str, stream: bool = False):
    connection_dict = normalize_connection_config(connection)
    client = AzureOpenAI(**connection_dict)

    messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": question}]
    response = client.chat.completions.create(model="gpt-35-turbo", messages=messages, stream=stream)

    if stream:
        def generator():
            for chunk in response:
                if chunk.choices:
                    yield chunk.choices[0].delta.content or ""
        return "".join(generator())
    return response.choices[0].message.content or ""

@tool
def my_python_tool(connection: AzureOpenAIConnection, question: str, stream: bool) -> str:
    return chat(connection, question, stream)
```

With the changes in this PR, specifying `stream=True` makes the
`client.chat.completions.create` call return an iterable stream of chunks
rather than a complete response.

### Original Implementation


![image](https://github.com/microsoft/promptflow/assets/51689021/d799b605-fd93-4e87-9a45-5e07ddb94392)

- In the original implementation, two spans were generated for
`openai_chat`. The first span ended when the call finished, and the
output was a string representing the generator object. The second span
started when the generator was consumed and ended when the generator was
fully consumed.

### New Implementation


![image](https://github.com/microsoft/promptflow/assets/51689021/9bf40499-62f2-4a7f-9b11-a72fb4b1bb12)

- In the new implementation, only one span is generated for
`openai_chat`. The span ends when the generator is fully consumed.
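A reduced sketch (with an illustrative `Span` class, not the actual Promptflow or OpenTelemetry types) shows why the single span still ends even when iteration fails or is abandoned, mirroring the `handle_span_exception` and `finally: span.end()` pattern in this PR:

```python
class Span:
    """Toy span; the real code uses an OpenTelemetry span."""

    def __init__(self):
        self.status = None
        self.ended = False

    def set_status(self, status):
        self.status = status

    def end(self):
        self.ended = True


def traced_generator(span, gen):
    """Yield items through, then close the span exactly once."""
    try:
        items = []
        for item in gen:
            items.append(item)
            yield item
        span.set_status("OK")
    except BaseException as exc:  # includes GeneratorExit / CancelledError
        # Mirrors handle_span_exception: record the error, then re-raise.
        span.set_status(f"ERROR: {type(exc).__name__}")
        raise
    finally:
        # Mirrors the PR's finally block: the span always ends on exit.
        span.end()
```

Whether the consumer exhausts the generator, raises, or simply abandons it (which triggers `GeneratorExit` inside the wrapper), the `finally` block guarantees the span is ended.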

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.
liucheng-ms committed May 14, 2024
1 parent c1f9884 commit cd07d6d
Showing 7 changed files with 255 additions and 116 deletions.
13 changes: 9 additions & 4 deletions src/promptflow-core/promptflow/executor/flow_executor.py
@@ -65,7 +65,12 @@
from promptflow.tracing._integrations._openai_injector import inject_openai_api
from promptflow.tracing._operation_context import OperationContext
from promptflow.tracing._start_trace import setup_exporter_from_environ
from promptflow.tracing._trace import enrich_span_with_context, enrich_span_with_input, enrich_span_with_trace_type
from promptflow.tracing._trace import (
enrich_span_with_context,
enrich_span_with_input,
enrich_span_with_trace_type,
start_as_current_span,
)
from promptflow.tracing.contracts.trace import TraceType

DEFAULT_TRACING_KEYS = {"run_mode", "root_run_id", "flow_id", "batch_input_source", "execution_target"}
@@ -880,7 +885,7 @@ def _get_node_referenced_flow_inputs(
@contextlib.contextmanager
def _start_flow_span(self, inputs: Mapping[str, Any]):
otel_tracer = otel_trace.get_tracer("promptflow")
with otel_tracer.start_as_current_span(self._flow.name) as span:
with start_as_current_span(otel_tracer, self._flow.name) as span:
# Store otel trace id in context for correlation
OperationContext.get_instance()["otel_trace_id"] = f"0x{format_trace_id(span.get_span_context().trace_id)}"
# initialize span
@@ -909,7 +914,7 @@ async def _exec_inner_with_trace_async(
context: FlowExecutionContext,
stream=False,
):
with self._start_flow_span(inputs) as span, self._record_cancellation_exceptions_to_span(span):
with self._start_flow_span(inputs) as span:
output, nodes_outputs = await self._traverse_nodes_async(inputs, context)
output = await self._stringify_generator_output_async(output) if not stream else output
self._exec_post_process(inputs, output, nodes_outputs, run_info, run_tracker, span, stream)
@@ -923,7 +928,7 @@ def _exec_inner_with_trace(
context: FlowExecutionContext,
stream=False,
):
with self._start_flow_span(inputs) as span, self._record_cancellation_exceptions_to_span(span):
with self._start_flow_span(inputs) as span:
output, nodes_outputs = self._traverse_nodes(inputs, context)
output = self._stringify_generator_output(output) if not stream else output
self._exec_post_process(inputs, output, nodes_outputs, run_info, run_tracker, span, stream)
11 changes: 8 additions & 3 deletions src/promptflow-tracing/promptflow/tracing/_openai_utils.py
@@ -211,9 +211,14 @@ def __init__(self, response, is_chat):

@property
def model(self):
for item in self._response:
if hasattr(item, "model"):
return item.model
"""
This method iterates over each item in the _response list.
If the item has a non-empty 'model' attribute, it returns the model.
If no such item is found, it returns None.
"""
for response_item in self._response:
if hasattr(response_item, "model") and response_item.model:
return response_item.model
return None

@property
169 changes: 108 additions & 61 deletions src/promptflow-tracing/promptflow/tracing/_trace.py
@@ -11,13 +11,13 @@
from collections.abc import AsyncIterator, Iterator
from importlib.metadata import version
from threading import Lock
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Sequence

import opentelemetry.trace as otel_trace
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.trace import Link, Span
from opentelemetry.trace.span import NonRecordingSpan, format_trace_id
from opentelemetry.trace.status import StatusCode
from opentelemetry.trace import Span
from opentelemetry.trace.span import format_trace_id
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util import types

from ._openai_utils import OpenAIMetricsCalculator, OpenAIResponseParser
from ._operation_context import OperationContext
@@ -41,6 +41,73 @@ def _record_cancellation_exceptions_to_span(span: Span):
raise


def handle_span_exception(span, exception):
if isinstance(span, Span) and span.is_recording():
# Record the exception as an event
span.record_exception(exception)
# Records status as error
span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{type(exception).__name__}: {str(exception)}",
)
)


def handle_output(span, inputs, output, trace_type):
if isinstance(output, (Iterator, AsyncIterator)):
# Set should_end to False to delay span end until generator exhaustion, preventing premature span end.
setattr(span, "__should_end", False)
output = Tracer.pop(output)
if isinstance(output, Iterator):
return traced_generator(span, inputs, output, trace_type)
else:
return traced_async_generator(span, inputs, output, trace_type)
else:
enrich_span_with_trace_type(span, inputs, output, trace_type)
span.set_status(StatusCode.OK)
return Tracer.pop(output)


@contextlib.contextmanager
def start_as_current_span(
tracer: otel_trace.Tracer,
name: str,
context: Optional[otel_trace.Context] = None,
kind: otel_trace.SpanKind = otel_trace.SpanKind.INTERNAL,
attributes: types.Attributes = None,
links: Optional[Sequence[otel_trace.Link]] = (),
start_time: Optional[int] = None,
record_exception: bool = True,
set_status_on_exception: bool = True,
):
span = None
try:
with tracer.start_as_current_span(
name,
context,
kind,
attributes,
links,
start_time,
record_exception,
set_status_on_exception,
end_on_exit=False,
) as span:
setattr(span, "__should_end", True)
yield span

except (KeyboardInterrupt, asyncio.CancelledError) as ex:
# The context manager above does not handle KeyboardInterrupt and asyncio.CancelledError exceptions.
# Therefore, we need to explicitly handle these exceptions here to ensure proper span exception handling.
handle_span_exception(span, ex)
raise

finally:
if span is not None and getattr(span, "__should_end", False):
span.end()


class TokenCollector:
_lock = Lock()

@@ -152,18 +219,10 @@ def enrich_span_with_trace_type(span, inputs, output, trace_type):
SpanEnricherManager.enrich(span, inputs, output, trace_type)
# TODO: Move the following logic to SpanEnricher
enrich_span_with_openai_tokens(span, trace_type)
return trace_iterator_if_needed(span, inputs, output)


def trace_iterator_if_needed(span, inputs, output):
if isinstance(output, (Iterator, AsyncIterator)) and not isinstance(span, NonRecordingSpan):
trace_func = traced_generator if isinstance(output, Iterator) else traced_async_generator
output = trace_func(span, inputs, output)
return output


def enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output):
if original_span.attributes["span_type"] == "LLM" and not IS_LEGACY_OPENAI:
def enrich_span_with_llm_if_needed(span, inputs, generator_output):
if span.is_recording() and span.attributes["span_type"] == "LLM" and not IS_LEGACY_OPENAI:
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.completion import Completion

@@ -173,53 +232,45 @@ def enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output
token_collector.collect_openai_tokens_for_streaming(span, inputs, generator_output, parser.is_chat)


def traced_generator(original_span: ReadableSpan, inputs, generator):
context = original_span.get_span_context()
link = Link(context)
# If start_trace is not called, the name of the original_span will be empty.
# need to get everytime to ensure tracer is latest
otel_tracer = otel_trace.get_tracer("promptflow")
with otel_tracer.start_as_current_span(
f"Iterated({original_span.name})",
links=[link],
) as span, _record_cancellation_exceptions_to_span(span):
enrich_span_with_original_attributes(span, original_span.attributes)
# Enrich the new span with input before generator iteration to prevent loss of input information.
# The input is as an event within this span.
enrich_span_with_input(span, inputs)
def traced_generator(span, inputs, generator, trace_type):
try:
generator_proxy = GeneratorProxy(generator)
yield from generator_proxy

generator_output = generator_proxy.items
enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output)
enrich_span_with_openai_tokens(span, TraceType(original_span.attributes["span_type"]))
enrich_span_with_output(span, serialize_attribute(generator_output))

enrich_span_with_llm_if_needed(span, inputs, generator_output)
enrich_span_with_trace_type(span, inputs, generator_output, trace_type)
token_collector.collect_openai_tokens_for_parent_span(span)
span.set_attribute("output_type", "iterated")
span.set_status(StatusCode.OK)
token_collector.collect_openai_tokens_for_parent_span(span)


async def traced_async_generator(original_span: ReadableSpan, inputs, generator):
context = original_span.get_span_context()
link = Link(context)
# If start_trace is not called, the name of the original_span will be empty.
# need to get everytime to ensure tracer is latest
otel_tracer = otel_trace.get_tracer("promptflow")
with otel_tracer.start_as_current_span(
f"Iterated({original_span.name})",
links=[link],
) as span, _record_cancellation_exceptions_to_span(span):
enrich_span_with_original_attributes(span, original_span.attributes)
# Enrich the new span with input before generator iteration to prevent loss of input information.
# The input is as an event within this span.
enrich_span_with_input(span, inputs)
except Exception as e:
handle_span_exception(span, e)
raise
finally:
# Always end the span on function exit, as the context manager doesn't handle this.
span.end()


async def traced_async_generator(span, inputs, generator, trace_type):
try:
generator_proxy = AsyncGeneratorProxy(generator)
async for item in generator_proxy:
yield item

generator_output = generator_proxy.items
enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output)
enrich_span_with_openai_tokens(span, TraceType(original_span.attributes["span_type"]))
enrich_span_with_output(span, serialize_attribute(generator_output))

enrich_span_with_llm_if_needed(span, inputs, generator_output)
enrich_span_with_trace_type(span, inputs, generator_output, trace_type)
token_collector.collect_openai_tokens_for_parent_span(span)
span.set_attribute("output_type", "iterated")
span.set_status(StatusCode.OK)
token_collector.collect_openai_tokens_for_parent_span(span)
except Exception as e:
handle_span_exception(span, e)
raise
finally:
# Always end the span on function exit, as the context manager doesn't handle this.
span.end()


def enrich_span_with_original_attributes(span, attributes):
@@ -383,7 +434,7 @@ async def wrapped(*args, **kwargs):
span_name = get_node_name_from_context(used_for_span_name=True) or trace.name
# need to get everytime to ensure tracer is latest
otel_tracer = otel_trace.get_tracer("promptflow")
with otel_tracer.start_as_current_span(span_name) as span, _record_cancellation_exceptions_to_span(span):
with start_as_current_span(otel_tracer, span_name) as span:
# Store otel trace id in context for correlation
OperationContext.get_instance()["otel_trace_id"] = f"0x{format_trace_id(span.get_span_context().trace_id)}"
enrich_span_with_trace(span, trace)
@@ -396,9 +447,7 @@ async def wrapped(*args, **kwargs):
Tracer.push(trace)
enrich_span_with_input(span, trace.inputs)
output = await func(*args, **kwargs)
output = enrich_span_with_trace_type(span, trace.inputs, output, trace_type)
span.set_status(StatusCode.OK)
output = Tracer.pop(output)
output = handle_output(span, trace.inputs, output, trace_type)
except Exception as e:
Tracer.pop(None, e)
raise
@@ -449,7 +498,7 @@ def wrapped(*args, **kwargs):
span_name = get_node_name_from_context(used_for_span_name=True) or trace.name
# need to get everytime to ensure tracer is latest
otel_tracer = otel_trace.get_tracer("promptflow")
with otel_tracer.start_as_current_span(span_name) as span, _record_cancellation_exceptions_to_span(span):
with start_as_current_span(otel_tracer, span_name) as span:
# Store otel trace id in context for correlation
OperationContext.get_instance()["otel_trace_id"] = f"0x{format_trace_id(span.get_span_context().trace_id)}"
enrich_span_with_trace(span, trace)
@@ -462,9 +511,7 @@ def wrapped(*args, **kwargs):
Tracer.push(trace)
enrich_span_with_input(span, trace.inputs)
output = func(*args, **kwargs)
output = enrich_span_with_trace_type(span, trace.inputs, output, trace_type)
span.set_status(StatusCode.OK)
output = Tracer.pop(output)
output = handle_output(span, trace.inputs, output, trace_type)
except Exception as e:
Tracer.pop(None, e)
raise
59 changes: 26 additions & 33 deletions src/promptflow-tracing/tests/e2etests/test_tracing.py
@@ -124,15 +124,15 @@ def assert_otel_trace(self, func, inputs, expected_span_length):
},
2,
),
(prompt_tpl_completion, {"prompt_tpl": "Hello {{name}}", "name": "world", "stream": True}, 3),
(prompt_tpl_completion, {"prompt_tpl": "Hello {{name}}", "name": "world", "stream": True}, 2),
(
prompt_tpl_chat,
{
"prompt_tpl": "system:\nYou are a helpful assistant.\n\n\nuser:\n{{question}}",
"question": "What is ChatGPT?",
"stream": True,
},
3,
2,
),
],
)
@@ -169,12 +169,12 @@ def assert_otel_traces_with_prompt(self, dev_connections, func, inputs, expected
[
(openai_chat, {"prompt": "Hello"}, 2),
(openai_completion, {"prompt": "Hello"}, 2),
(openai_chat, {"prompt": "Hello", "stream": True}, 3),
(openai_completion, {"prompt": "Hello", "stream": True}, 3),
(openai_chat, {"prompt": "Hello", "stream": True}, 2),
(openai_completion, {"prompt": "Hello", "stream": True}, 2),
(openai_chat_async, {"prompt": "Hello"}, 2),
(openai_completion_async, {"prompt": "Hello"}, 2),
(openai_chat_async, {"prompt": "Hello", "stream": True}, 3),
(openai_completion_async, {"prompt": "Hello", "stream": True}, 3),
(openai_chat_async, {"prompt": "Hello", "stream": True}, 2),
(openai_completion_async, {"prompt": "Hello", "stream": True}, 2),
],
)
def test_otel_trace_with_llm(self, dev_connections, func, inputs, expected_span_length):
@@ -315,35 +315,15 @@ def validate_span_events(self, span):
assert FUNCTION_OUTPUT_EVENT in events, f"Expected '{FUNCTION_OUTPUT_EVENT}' in events"

if span.attributes[SPAN_TYPE_ATTRIBUTE] == TraceType.LLM:
self.validate_llm_event(span, events)
self.validate_llm_event(events)
elif span.attributes[SPAN_TYPE_ATTRIBUTE] == TraceType.EMBEDDING:
self.validate_embedding_event(events)

if PROMPT_TEMPLATE_EVENT in events:
self.validate_prompt_template_event(events)

def validate_llm_event(self, span, span_events):
is_stream = span_events[FUNCTION_INPUTS_EVENT].get("stream", False)
is_iterated_span = span.name.startswith(ITERATED_SPAN_PREFIX)

# iterate span should have LLM_GENERATED_MESSAGE_EVENT and links
if is_iterated_span:
assert (
LLM_GENERATED_MESSAGE_EVENT in span_events
), f"Expected '{LLM_GENERATED_MESSAGE_EVENT}' in iterated span events"
assert span.links, "Expected links in iterated span"

# non-stream span should have LLM_GENERATED_MESSAGE_EVENT
if not is_stream:
assert (
LLM_GENERATED_MESSAGE_EVENT in span_events
), f"Expected '{LLM_GENERATED_MESSAGE_EVENT}' in non-stream span events"

# original span in streaming mode should not have LLM_GENERATED_MESSAGE_EVENT
if is_stream and not is_iterated_span:
assert (
LLM_GENERATED_MESSAGE_EVENT not in span_events
), f"Unexpected '{LLM_GENERATED_MESSAGE_EVENT}' in original span in streaming mode"
def validate_llm_event(self, span_events):
assert LLM_GENERATED_MESSAGE_EVENT in span_events, f"Expected '{LLM_GENERATED_MESSAGE_EVENT}' in span events"

def validate_embedding_event(self, span_events):
assert EMBEDDING_EVENT in span_events, f"Expected '{EMBEDDING_EVENT}' in span events"
@@ -446,10 +426,23 @@ def validate_openai_tokens(self, span_list, is_stream=False):
assert span.attributes[token_name] == expected_tokens[span_id][token_name]

def _is_llm_span_with_tokens(self, span, is_stream):
# For streaming mode, there are two spans for openai api call, one is the original span, and the other
# is the iterated span, which name is "Iterated(<original_trace_name>)", we should check the iterated span
# in streaming mode.
"""
This function checks if a given span is a LLM span with tokens.
If in stream mode, the function checks if the span has attributes indicating it's an iterated span.
In non-stream mode, it simply checks if the span's function attribute is in the list of LLM function names.
Args:
span: The span to check.
is_stream: A boolean indicating whether the span is in stream mode.
Returns:
A boolean indicating whether the span is a LLM span with tokens.
"""
if is_stream:
return span.attributes.get("function", "") in LLM_FUNCTION_NAMES and span.name.startswith("Iterated(")
return (
span.attributes.get("function", "") in LLM_FUNCTION_NAMES
and span.attributes.get("output_type", "") == "iterated"
)
else:
return span.attributes.get("function", "") in LLM_FUNCTION_NAMES
