Support context manager generator output
ToughAnyway committed May 9, 2024
1 parent aac3bfb commit bbd13d3
Showing 1 changed file with 89 additions and 51 deletions.
140 changes: 89 additions & 51 deletions src/promptflow-tracing/promptflow/tracing/_trace.py
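The change replaces the generator-function tracers (traced_generator / traced_async_generator) with wrapper classes, so that a traced output which is also a context manager (for example, a streaming response that must be closed after iteration) keeps working in a with statement: the new wrappers forward __enter__/__exit__ to the wrapped object when it supports them. A minimal sketch of the kind of output this targets, using a hypothetical FakeStream class that is not part of the diff:

from collections.abc import Generator


class FakeStream(Generator):
    """Hypothetical streaming result: generator protocol plus close-on-exit."""

    def __init__(self, chunks):
        self._chunks = iter(chunks)
        self.closed = False

    def send(self, value):
        # Each send/next yields the next chunk; StopIteration ends the stream.
        return next(self._chunks)

    def throw(self, typ, val=None, tb=None):
        # Delegate to the ABC's default throw behavior.
        return super().throw(typ, val, tb)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.closed = True  # e.g. release the underlying HTTP connection
        return False


with FakeStream(["a", "b"]) as stream:
    for chunk in stream:
        print(chunk)
assert stream.closed  # __exit__ ran

Wrapping such an object in a plain generator function hides its __enter__/__exit__, which is exactly what the class-based wrappers in the diff below avoid.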
@@ -8,10 +8,9 @@
 import inspect
 import json
 import logging
-from collections.abc import AsyncIterator, Iterator
 from importlib.metadata import version
 from threading import Lock
-from typing import Callable, Dict, List, Optional
+from typing import AsyncGenerator, Callable, ContextManager, Dict, Generator, List, Optional

 import opentelemetry.trace as otel_trace
 from opentelemetry.sdk.trace import ReadableSpan
@@ -162,8 +161,8 @@ def enrich_span_with_trace_type(span, inputs, output, trace_type):


 def trace_iterator_if_needed(span, inputs, output):
-    if isinstance(output, (Iterator, AsyncIterator)) and not isinstance(span, NonRecordingSpan):
-        trace_func = traced_generator if isinstance(output, Iterator) else traced_async_generator
+    if isinstance(output, (Generator, AsyncGenerator)) and not isinstance(span, NonRecordingSpan):
+        trace_func = TracedGenerator if isinstance(output, Generator) else TracedAsyncGenerator
         output = trace_func(span, inputs, output)
     return output
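Note that Generator and AsyncGenerator in the new check are the typing aliases, and isinstance against the bare aliases delegates to the corresponding collections.abc ABCs, so the dispatch works on real generator objects at runtime. It is also narrower than the previous Iterator/AsyncIterator check: plain iterators no longer get wrapped. A quick illustrative sanity check:

from typing import AsyncGenerator, Generator


def sync_gen():
    yield 1


async def async_gen():
    yield 1


assert isinstance(sync_gen(), Generator)
assert isinstance(async_gen(), AsyncGenerator)
assert not isinstance(iter([1]), Generator)  # a plain iterator is not a generator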

@@ -179,53 +178,92 @@ def enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output):
         token_collector.collect_openai_tokens_for_streaming(span, inputs, generator_output, parser.is_chat)


-def traced_generator(original_span: ReadableSpan, inputs, generator):
-    context = original_span.get_span_context()
-    link = Link(context)
-    # If start_trace is not called, the name of the original_span will be empty.
-    # Need to get the tracer every time to ensure it is the latest.
-    otel_tracer = otel_trace.get_tracer("promptflow")
-    with otel_tracer.start_as_current_span(
-        f"Iterated({original_span.name})",
-        links=[link],
-    ) as span, _record_cancellation_exceptions_to_span(span):
-        enrich_span_with_original_attributes(span, original_span.attributes)
-        # Enrich the new span with input before generator iteration to prevent loss of input information.
-        # The input is recorded as an event within this span.
-        enrich_span_with_input(span, inputs)
-        generator_proxy = GeneratorProxy(generator)
-        yield from generator_proxy
-        generator_output = generator_proxy.items
-        enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output)
-        enrich_span_with_openai_tokens(span, TraceType(original_span.attributes["span_type"]))
-        enrich_span_with_output(span, serialize_attribute(generator_output))
-        span.set_status(StatusCode.OK)
-    token_collector.collect_openai_tokens_for_parent_span(span)
-
-
-async def traced_async_generator(original_span: ReadableSpan, inputs, generator):
-    context = original_span.get_span_context()
-    link = Link(context)
-    # If start_trace is not called, the name of the original_span will be empty.
-    # Need to get the tracer every time to ensure it is the latest.
-    otel_tracer = otel_trace.get_tracer("promptflow")
-    with otel_tracer.start_as_current_span(
-        f"Iterated({original_span.name})",
-        links=[link],
-    ) as span, _record_cancellation_exceptions_to_span(span):
-        enrich_span_with_original_attributes(span, original_span.attributes)
-        # Enrich the new span with input before generator iteration to prevent loss of input information.
-        # The input is recorded as an event within this span.
-        enrich_span_with_input(span, inputs)
-        generator_proxy = AsyncGeneratorProxy(generator)
-        async for item in generator_proxy:
-            yield item
-        generator_output = generator_proxy.items
-        enrich_span_with_llm_if_needed(span, original_span, inputs, generator_output)
-        enrich_span_with_openai_tokens(span, TraceType(original_span.attributes["span_type"]))
-        enrich_span_with_output(span, serialize_attribute(generator_output))
-        span.set_status(StatusCode.OK)
-    token_collector.collect_openai_tokens_for_parent_span(span)
+class TracedGenerator:
+    def __init__(self, original_span: ReadableSpan, inputs, generator: Generator):
+        self.original_span = original_span
+        self.inputs = inputs
+        self.generator = generator
+        self.context = original_span.get_span_context()
+        self.link = Link(self.context)
+        self.otel_tracer = otel_trace.get_tracer("promptflow")
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            with self.otel_tracer.start_as_current_span(
+                f"Iterated({self.original_span.name})",
+                links=[self.link],
+            ) as span, _record_cancellation_exceptions_to_span(span):
+                enrich_span_with_original_attributes(span, self.original_span.attributes)
+                # Enrich the new span with input before generator iteration to prevent loss of input information.
+                # The input is recorded as an event within this span.
+                enrich_span_with_input(span, self.inputs)
+                generator_proxy = GeneratorProxy(self.generator)
+                item = next(generator_proxy)
+                generator_output = generator_proxy.items
+                enrich_span_with_llm_if_needed(span, self.original_span, self.inputs, generator_output)
+                enrich_span_with_openai_tokens(span, TraceType(self.original_span.attributes["span_type"]))
+                enrich_span_with_output(span, serialize_attribute(generator_output))
+                span.set_status(StatusCode.OK)
+
+            return item
+        except StopIteration:
+            token_collector.collect_openai_tokens_for_parent_span(span)
+            raise
+
+    def __enter__(self):
+        if isinstance(self.generator, ContextManager):
+            self.generator.__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if isinstance(self.generator, ContextManager):
+            self.generator.__exit__(exc_type, exc_val, exc_tb)
+
+
+class TracedAsyncGenerator:
+    def __init__(self, original_span: ReadableSpan, inputs, generator: AsyncGenerator):
+        self.original_span = original_span
+        self.inputs = inputs
+        self.generator = generator
+        self.context = original_span.get_span_context()
+        self.link = Link(self.context)
+        self.otel_tracer = otel_trace.get_tracer("promptflow")
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            with self.otel_tracer.start_as_current_span(
+                f"Iterated({self.original_span.name})",
+                links=[self.link],
+            ) as span, _record_cancellation_exceptions_to_span(span):
+                enrich_span_with_original_attributes(span, self.original_span.attributes)
+                # Enrich the new span with input before generator iteration to prevent loss of input information.
+                # The input is recorded as an event within this span.
+                enrich_span_with_input(span, self.inputs)
+                generator_proxy = AsyncGeneratorProxy(self.generator)
+                item = await generator_proxy.__anext__()
+                generator_output = generator_proxy.items
+                enrich_span_with_llm_if_needed(span, self.original_span, self.inputs, generator_output)
+                enrich_span_with_openai_tokens(span, TraceType(self.original_span.attributes["span_type"]))
+                enrich_span_with_output(span, serialize_attribute(generator_output))
+                span.set_status(StatusCode.OK)
+
+            return item
+        except StopAsyncIteration:
+            token_collector.collect_openai_tokens_for_parent_span(span)
+            raise
+
+    def __enter__(self):
+        if isinstance(self.generator, ContextManager):
+            self.generator.__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if isinstance(self.generator, ContextManager):
+            self.generator.__exit__(exc_type, exc_val, exc_tb)


 def enrich_span_with_original_attributes(span, attributes):
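Compared with the removed generator functions, which kept a single Iterated(...) span open across the whole iteration, the class-based wrappers open a fresh span on every __next__/__anext__ call, each linked back to the original span, and they expose the wrapped object's context-manager protocol. A rough consumption sketch, reusing the hypothetical FakeStream from above; it assumes an SDK TracerProvider is configured (otherwise the span is non-recording and the output is returned unwrapped) and that the original span carries the span_type attribute the enrichment helpers read:

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

from promptflow.tracing._trace import trace_iterator_if_needed

trace.set_tracer_provider(TracerProvider())  # assumption: no provider configured yet
tracer = trace.get_tracer("promptflow")

with tracer.start_as_current_span("my_tool") as span:
    span.set_attribute("span_type", "Function")  # expected by the enrichment helpers
    stream = FakeStream(["hello ", "world"])
    traced = trace_iterator_if_needed(span, {"prompt": "hi"}, stream)

# Consume lazily; each chunk is recorded under its own "Iterated(my_tool)" span.
with traced:  # forwarded to FakeStream.__enter__/__exit__
    for chunk in traced:
        print(chunk, end="")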
