Skip to content

Commit

Permalink
Add timeout task to async scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
Heyi committed May 10, 2024
1 parent 873d29f commit 3b78c1f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
30 changes: 25 additions & 5 deletions src/promptflow-core/promptflow/executor/_async_nodes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import traceback
from asyncio import Task
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

from promptflow._core.flow_execution_context import FlowExecutionContext
from promptflow._core.tools_manager import ToolsManager
Expand All @@ -20,7 +20,7 @@
from promptflow._utils.utils import extract_user_frame_summaries, set_context, try_get_long_running_logging_interval
from promptflow.contracts.flow import Node
from promptflow.executor._dag_manager import DAGManager
from promptflow.executor._errors import NoNodeExecutedError
from promptflow.executor._errors import LineExecutionTimeoutError, NoNodeExecutedError

PF_ASYNC_NODE_SCHEDULER_EXECUTE_TASK_NAME = "_pf_async_nodes_scheduler.execute"
DEFAULT_TASK_LOGGING_INTERVAL = 60
Expand All @@ -44,6 +44,7 @@ async def execute(
nodes: List[Node],
inputs: Dict[str, Any],
context: FlowExecutionContext,
timeout_seconds: Optional[int] = None,
) -> Tuple[dict, dict]:
# Semaphore should be created in the loop, otherwise it will not work.
loop = asyncio.get_running_loop()
Expand Down Expand Up @@ -75,7 +76,7 @@ async def execute(
# Then the event loop will wait for all tasks to be completed before raising the cancellation error.
# See reference: https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Executor
try:
outputs = await self._execute_with_thread_pool(executor, nodes, inputs, context)
outputs = await self._execute_with_thread_pool(executor, nodes, inputs, context, timeout_seconds)
except asyncio.CancelledError:
await self.cancel()
raise
Expand All @@ -88,12 +89,16 @@ async def _execute_with_thread_pool(
nodes: List[Node],
inputs: Dict[str, Any],
context: FlowExecutionContext,
timeout_seconds: Optional[int] = None,
) -> Tuple[dict, dict]:
flow_logger.info(f"Start to run {len(nodes)} nodes with the current event loop.")
dag_manager = DAGManager(nodes, inputs)
timeout_task = asyncio.create_task(asyncio.sleep(timeout_seconds)) if timeout_seconds else None
task2nodes = self._execute_nodes(dag_manager, context, executor)
while not dag_manager.completed():
task2nodes = await self._wait_and_complete_nodes(task2nodes, dag_manager)
task2nodes = await self._wait_and_complete_nodes(
task2nodes, dag_manager, context, timeout_task, timeout_seconds
)
submitted_tasks2nodes = self._execute_nodes(dag_manager, context, executor)
task2nodes.update(submitted_tasks2nodes)
# Set the event to notify the monitor thread to exit
Expand All @@ -103,16 +108,31 @@ async def _execute_with_thread_pool(
dag_manager.completed_nodes_outputs[node] = None
return dag_manager.completed_nodes_outputs, dag_manager.bypassed_nodes

async def _wait_and_complete_nodes(self, task2nodes: Dict[Task, Node], dag_manager: DAGManager) -> Dict[Task, Node]:
async def _wait_and_complete_nodes(
    self,
    task2nodes: Dict[Task, Node],
    dag_manager: DAGManager,
    context: FlowExecutionContext,
    timeout_task: Optional[Task] = None,
    line_timeout_sec: Optional[int] = None,
) -> Dict[Task, Node]:
    """Wait until at least one in-flight node task finishes and record its output.

    The finished tasks' results are handed to ``dag_manager.complete_nodes`` keyed
    by node name, and those tasks are removed from ``task2nodes``.

    :param task2nodes: Mapping of in-flight tasks to the nodes they execute.
        It is mutated in place and also returned.
    :param dag_manager: Receives the outputs of the finished nodes.
    :param context: Only ``context._line_number`` is read, to build the timeout error.
    :param timeout_task: Optional task awaited alongside the node tasks; when it
        is done after the wait, the line is treated as timed out.
    :param line_timeout_sec: Timeout value reported in the timeout error.
    :return: ``task2nodes`` without the tasks that completed during this call.
    :raises NoNodeExecutedError: If ``task2nodes`` is empty.
    :raises LineExecutionTimeoutError: If ``timeout_task`` has completed; every
        task that is still pending is cancelled first.
    """
    if not task2nodes:
        raise NoNodeExecutedError("No nodes are ready for execution, but the flow is not completed.")
    tasks = list(task2nodes)
    now = time.time()
    for task in tasks:
        # Keep the first recorded timestamp: a task that stays pending across several
        # calls must not have its start time reset whenever a sibling task finishes.
        # NOTE(review): assumes readers of _task_start_time want the time the task was
        # first awaited here - confirm against the long-running-task monitor.
        self._task_start_time.setdefault(task, now)
    if timeout_task is not None:
        tasks.append(timeout_task)
    done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    # The timeout task is not a node, so it must never reach the DAG manager.
    finished = [task for task in done if task is not timeout_task]
    dag_manager.complete_nodes({task2nodes[task].name: task.result() for task in finished})
    for task in finished:
        del task2nodes[task]
    if timeout_task is not None and timeout_task.done():
        # Results that arrived together with the timeout were recorded above;
        # everything still running is cancelled before the error is raised.
        for task in tasks:
            if not task.done():
                task.cancel()
        raise LineExecutionTimeoutError(context._line_number, line_timeout_sec)
    return task2nodes

def _execute_nodes(
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow-core/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,7 +1194,7 @@ async def _traverse_nodes_async(self, inputs, context: FlowExecutionContext) ->
batch_nodes = [node for node in self._flow.nodes if not node.aggregation]
flow_logger.info("Start executing nodes in async mode.")
scheduler = AsyncNodesScheduler(self._tools_manager, self._node_concurrency)
nodes_outputs, bypassed_nodes = await scheduler.execute(batch_nodes, inputs, context)
nodes_outputs, bypassed_nodes = await scheduler.execute(batch_nodes, inputs, context, self._line_timeout_sec)
outputs = self._extract_outputs(nodes_outputs, bypassed_nodes, inputs)
return outputs, nodes_outputs

Expand Down

0 comments on commit 3b78c1f

Please sign in to comment.