Skip to content

Commit

Permalink
core: Tap output of sync iterators for astream_events (#21842)
Browse files Browse the repository at this point in the history
Thank you for contributing to LangChain!

- [ ] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- [ ] **PR message**: ***Delete this entire checklist*** and replace
with
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced and you'd like a
mention, we'll gladly shout you out!


- [ ] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [ ] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
  • Loading branch information
nfcampos committed May 17, 2024
1 parent 9a39f92 commit b1e7b40
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 30 deletions.
14 changes: 14 additions & 0 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,9 @@ def _transform_stream_with_config(
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
# Mixin that is used by both astream log and astream events implementation
from langchain_core.tracers._streaming import _StreamingCallbackHandler

# tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one
Expand All @@ -1742,6 +1745,17 @@ def _transform_stream_with_config(
context = copy_context()
context.run(var_child_runnable_config.set, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(
(
cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers
# instance check OK here, it's a mixin
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
),
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator)
try:
while True:
chunk: Output = context.run(next, iterator) # type: ignore
Expand Down
6 changes: 5 additions & 1 deletion libs/core/langchain_core/tracers/_streaming.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Internal tracers used for stream_log and astream events implementations."""
import abc
from typing import AsyncIterator, TypeVar
from typing import AsyncIterator, Iterator, TypeVar
from uuid import UUID

T = TypeVar("T")
Expand All @@ -22,6 +22,10 @@ def tap_output_aiter(
) -> AsyncIterator[T]:
"""Used for internal astream_log and astream events implementations."""

@abc.abstractmethod
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
    """Used for internal astream_log and astream events implementations."""


__all__ = [
"_StreamingCallbackHandler",
Expand Down
46 changes: 33 additions & 13 deletions libs/core/langchain_core/tracers/event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Expand Down Expand Up @@ -102,10 +103,10 @@ def __init__(
self.send_stream = memory_stream.get_send_stream()
self.receive_stream = memory_stream.get_receive_stream()

async def _send(self, event: StreamEvent, event_type: str) -> None:
def _send(self, event: StreamEvent, event_type: str) -> None:
"""Send an event to the stream."""
if self.root_event_filter.include_event(event, event_type):
await self.send_stream.send(event)
self.send_stream.send_nowait(event)

def __aiter__(self) -> AsyncIterator[Any]:
"""Iterate over the receive stream."""
Expand All @@ -119,7 +120,26 @@ async def tap_output_aiter(
run_info = self.run_map.get(run_id)
if run_info is None:
raise AssertionError(f"Run ID {run_id} not found in run map.")
await self._send(
self._send(
{
"event": f"on_{run_info['run_type']}_stream",
"data": {"chunk": chunk},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
},
run_info["run_type"],
)
yield chunk

def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
"""Tap the output aiter."""
for chunk in output:
run_info = self.run_map.get(run_id)
if run_info is None:
raise AssertionError(f"Run ID {run_id} not found in run map.")
self._send(
{
"event": f"on_{run_info['run_type']}_stream",
"data": {"chunk": chunk},
Expand Down Expand Up @@ -155,7 +175,7 @@ async def on_chat_model_start(
"inputs": {"messages": messages},
}

await self._send(
self._send(
{
"event": "on_chat_model_start",
"data": {
Expand Down Expand Up @@ -192,7 +212,7 @@ async def on_llm_start(
"inputs": {"prompts": prompts},
}

await self._send(
self._send(
{
"event": "on_llm_start",
"data": {
Expand Down Expand Up @@ -241,7 +261,7 @@ async def on_llm_new_token(
else:
raise ValueError(f"Unexpected run type: {run_info['run_type']}")

await self._send(
self._send(
{
"event": event,
"data": {
Expand Down Expand Up @@ -295,7 +315,7 @@ async def on_llm_end(
else:
raise ValueError(f"Unexpected run type: {run_info['run_type']}")

await self._send(
self._send(
{
"event": event,
"data": {"output": output, "input": inputs_},
Expand Down Expand Up @@ -340,7 +360,7 @@ async def on_chain_start(

self.run_map[run_id] = run_info

await self._send(
self._send(
{
"event": f"on_{run_type_}_start",
"data": data,
Expand Down Expand Up @@ -373,7 +393,7 @@ async def on_chain_end(
"input": inputs,
}

await self._send(
self._send(
{
"event": event,
"data": data,
Expand Down Expand Up @@ -408,7 +428,7 @@ async def on_tool_start(
"inputs": inputs,
}

await self._send(
self._send(
{
"event": "on_tool_start",
"data": {
Expand All @@ -432,7 +452,7 @@ async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None
)
inputs = run_info["inputs"]

await self._send(
self._send(
{
"event": "on_tool_end",
"data": {
Expand Down Expand Up @@ -470,7 +490,7 @@ async def on_retriever_start(
"inputs": {"query": query},
}

await self._send(
self._send(
{
"event": "on_retriever_start",
"data": {
Expand All @@ -492,7 +512,7 @@ async def on_retriever_end(
"""Run when Retriever ends running."""
run_info = self.run_map.pop(run_id)

await self._send(
self._send(
{
"event": "on_retriever_end",
"data": {
Expand Down
20 changes: 20 additions & 0 deletions libs/core/langchain_core/tracers/log_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal,
Optional,
Expand Down Expand Up @@ -252,6 +253,25 @@ async def tap_output_aiter(

yield chunk

def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
    """Tap an output iterator to stream its values to the log.

    Sync counterpart of ``tap_output_aiter``: mirrors each chunk produced by
    ``output`` into the run's ``streamed_output`` log entry, then yields the
    chunk unchanged to the caller.

    Args:
        run_id: ID of the run whose streamed output is being tapped.
        output: The iterator whose chunks should be mirrored into the log.

    Yields:
        The chunks of ``output``, unchanged. Stops early (without yielding
        the current chunk) if the log stream has been closed.
    """
    for chunk in output:
        # root run is handled in .astream_log()
        if run_id != self.root_id:
            # if we can't find the run silently ignore
            # eg. because this run wasn't included in the log
            if key := self._key_map_by_run_id.get(run_id):
                if not self.send(
                    {
                        "op": "add",
                        "path": f"/logs/{key}/streamed_output/-",
                        "value": chunk,
                    }
                ):
                    # send() returning False means the stream is closed;
                    # stop consuming the wrapped iterator entirely.
                    break

        yield chunk

def include_run(self, run: Run) -> bool:
if run.id == self.root_id:
return False
Expand Down
22 changes: 6 additions & 16 deletions libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,27 +1650,22 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
]


@pytest.mark.xfail(
reason="This test is failing due to missing functionality."
"Need to implement logic in _transform_stream_with_config that mimics the async "
"variant that uses tap_output_iter"
)
async def test_sync_in_async_stream_lambdas() -> None:
"""Test invoking nested runnable lambda."""

def add_one_(x: int) -> int:
def add_one(x: int) -> int:
return x + 1

add_one = RunnableLambda(add_one_)
add_one_ = RunnableLambda(add_one)

async def add_one_proxy_(x: int, config: RunnableConfig) -> int:
streaming = add_one.stream(x, config)
async def add_one_proxy(x: int, config: RunnableConfig) -> int:
streaming = add_one_.stream(x, config)
results = [result for result in streaming]
return results[0]

add_one_proxy = RunnableLambda(add_one_proxy_) # type: ignore
add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore

events = await _collect_events(add_one_proxy.astream_events(1, version="v2"))
events = await _collect_events(add_one_proxy_.astream_events(1, version="v2"))
assert events == EXPECTED_EVENTS


Expand All @@ -1694,11 +1689,6 @@ async def add_one_proxy(x: int, config: RunnableConfig) -> int:
assert events == EXPECTED_EVENTS


@pytest.mark.xfail(
reason="This test is failing due to missing functionality."
"Need to implement logic in _transform_stream_with_config that mimics the async "
"variant that uses tap_output_iter"
)
async def test_sync_in_sync_lambdas() -> None:
"""Test invoking nested runnable lambda."""

Expand Down

0 comments on commit b1e7b40

Please sign in to comment.