Skip to content

Commit

Permalink
Update python sdk to v0.3
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubno committed Aug 14, 2023
1 parent 880b6ca commit a198f1e
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 43 deletions.
1 change: 1 addition & 0 deletions sdk/python/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ generate:
mv agent_protocol/main.py agent_protocol/server.py
rm -rf agent_protocol/routers
rm agent_protocol/dependencies.py
black .
5 changes: 3 additions & 2 deletions sdk/python/agent_protocol/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from .agent import Agent, StepHandler, TaskHandler, base_router as router
from .models import Artifact, StepRequestBody, TaskRequestBody
from .models import Artifact, Status, StepRequestBody, TaskRequestBody
from .db import Step, Task, TaskDB


__all__ = [
"Agent",
"Artifact",
"Status",
"Step",
"StepHandler",
"StepRequestBody",
"Task",
"TaskDB",
"StepRequestBody",
"TaskHandler",
"TaskRequestBody",
"router",
Expand Down
23 changes: 14 additions & 9 deletions sdk/python/agent_protocol/agent.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
import asyncio
import os
from uuid import uuid4

import aiofiles
from fastapi import APIRouter, UploadFile, Form, File
from fastapi.responses import FileResponse
from hypercorn.asyncio import serve
from hypercorn.config import Config
from typing import Awaitable, Callable, List, Optional, Annotated
from typing import Callable, List, Optional, Annotated, Coroutine, Any

from .db import InMemoryTaskDB, TaskDB
from .db import InMemoryTaskDB, Task, TaskDB, Step
from .server import app
from .models import (
TaskRequestBody,
Step,
StepRequestBody,
Artifact,
Task,
Status,
)


StepHandler = Callable[[Step], Awaitable[Step]]
TaskHandler = Callable[[Task], Awaitable[None]]
StepHandler = Callable[[Step], Coroutine[Any, Any, Step]]
TaskHandler = Callable[[Task], Coroutine[Any, Any, None]]


_task_handler: Optional[TaskHandler]
Expand Down Expand Up @@ -89,12 +88,17 @@ async def execute_agent_task_step(
"""
Execute a step in the specified agent task.
"""
if not _step_handler:
raise Exception("Step handler not defined")

task = await Agent.db.get_task(task_id)
step = next(filter(lambda x: x.status == Status.created, task.steps), None)

if not step:
raise Exception("No steps to execute")

step.status = Status.running

step.input = body.input if body else None
step.additional_input = body.additional_input if body else None

Expand All @@ -109,7 +113,7 @@ async def execute_agent_task_step(
response_model=Step,
tags=["agent"],
)
async def get_agent_task_step(task_id: str, step_id: str = ...) -> Step:
async def get_agent_task_step(task_id: str, step_id: str) -> Step:
"""
Get details about a specified task step.
"""
Expand Down Expand Up @@ -142,14 +146,15 @@ async def upload_agent_task_artifacts(
"""
Upload an artifact for the specified task.
"""
file_name = file.filename or str(uuid4())
await Agent.db.get_task(task_id)
artifact = await Agent.db.create_artifact(task_id, file.filename, relative_path)
artifact = await Agent.db.create_artifact(task_id, file_name, relative_path)

path = Agent.get_artifact_folder(task_id, artifact)
if not os.path.exists(path):
os.makedirs(path)

async with aiofiles.open(os.path.join(path, file.filename), "wb") as f:
async with aiofiles.open(os.path.join(path, file_name), "wb") as f:
while content := await file.read(1024 * 1024): # async read chunk ~1MiB
await f.write(content)

Expand Down
47 changes: 34 additions & 13 deletions sdk/python/agent_protocol/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,32 @@ class Task(APITask):
steps: List[Step] = []


class NotFoundException(Exception):
"""
Exception raised when a resource is not found.
"""

def __init__(self, item_name: str, item_id: str):
self.item_name = item_name
self.item_id = item_id
super().__init__(f"{item_name} with {item_id} not found.")


class TaskDB(ABC):
async def create_task(
self,
input: Optional[str],
additional_input: Optional[str] = None,
artifacts: List[Artifact] = None,
steps: List[Step] = None,
additional_input: Any = None,
artifacts: Optional[List[Artifact]] = None,
steps: Optional[List[Step]] = None,
) -> Task:
raise NotImplementedError

async def create_step(
self,
task_id: str,
name: Optional[str] = None,
input: Optional[str] = None,
is_last: bool = False,
additional_properties: Optional[Dict[str, str]] = None,
) -> Step:
Expand All @@ -52,7 +64,9 @@ async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
async def list_tasks(self) -> List[Task]:
raise NotImplementedError

async def list_steps(self, task_id: str) -> List[Step]:
async def list_steps(
self, task_id: str, status: Optional[Status] = None
) -> List[Step]:
raise NotImplementedError


Expand All @@ -62,9 +76,9 @@ class InMemoryTaskDB(TaskDB):
async def create_task(
self,
input: Optional[str],
additional_input: Optional[str] = None,
artifacts: List[Artifact] = None,
steps: List[Step] = None,
additional_input: Any = None,
artifacts: Optional[List[Artifact]] = None,
steps: Optional[List[Step]] = None,
) -> Task:
if not steps:
steps = []
Expand All @@ -85,14 +99,16 @@ async def create_step(
self,
task_id: str,
name: Optional[str] = None,
input: Optional[str] = None,
is_last=False,
additional_properties: Dict[str, Any] = None,
additional_properties: Optional[Dict[str, Any]] = None,
) -> Step:
step_id = str(uuid.uuid4())
step = Step(
task_id=task_id,
step_id=step_id,
name=name,
input=input,
status=Status.created,
is_last=is_last,
additional_properties=additional_properties,
Expand All @@ -104,14 +120,14 @@ async def create_step(
async def get_task(self, task_id: str) -> Task:
task = self._tasks.get(task_id, None)
if not task:
raise Exception(f"Task with id {task_id} not found")
raise NotFoundException("Task", task_id)
return task

async def get_step(self, task_id: str, step_id: str) -> Step:
task = await self.get_task(task_id)
step = next(filter(lambda s: s.task_id == task_id, task.steps), None)
if not step:
raise Exception(f"Step with id {step_id} not found")
raise NotFoundException("Step", step_id)
return step

async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
Expand All @@ -120,7 +136,7 @@ async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None
)
if not artifact:
raise Exception(f"Artifact with id {artifact_id} not found")
raise NotFoundException("Artifact", artifact_id)
return artifact

async def create_artifact(
Expand All @@ -146,6 +162,11 @@ async def create_artifact(
async def list_tasks(self) -> List[Task]:
return [task for task in self._tasks.values()]

async def list_steps(self, task_id: str) -> List[Step]:
async def list_steps(
self, task_id: str, status: Optional[Status] = None
) -> List[Step]:
task = await self.get_task(task_id)
return [step for step in task.steps]
steps = task.steps
if status:
steps = list(filter(lambda s: s.status == status, steps))
return steps
13 changes: 13 additions & 0 deletions sdk/python/agent_protocol/middlewares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from fastapi import Request
from fastapi.responses import PlainTextResponse

from agent_protocol.db import NotFoundException


async def not_found_exception_handler(
request: Request, exc: NotFoundException
) -> PlainTextResponse:
return PlainTextResponse(
str(exc),
status_code=404,
)
78 changes: 62 additions & 16 deletions sdk/python/agent_protocol/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by fastapi-codegen:
# filename: ../../openapi.yml
# timestamp: 2023-08-07T12:14:43+00:00
# timestamp: 2023-08-11T14:24:22+00:00

from __future__ import annotations

Expand All @@ -12,65 +12,111 @@

class TaskInput(BaseModel):
__root__: Any = Field(
..., description="Input parameters for the task. Any value is allowed."
...,
description="Input parameters for the task. Any value is allowed.",
example='{\n"debug": false,\n"mode": "benchmarks"\n}',
)


class Artifact(BaseModel):
artifact_id: str = Field(..., description="ID of the artifact.")
file_name: str = Field(..., description="Filename of the artifact.")
artifact_id: str = Field(
...,
description="ID of the artifact.",
example="b225e278-8b4c-4f99-a696-8facf19f0e56",
)
file_name: str = Field(
..., description="Filename of the artifact.", example="main.py"
)
relative_path: Optional[str] = Field(
None, description="Relative path of the artifact in the agent's workspace."
None,
description="Relative path of the artifact in the agent's workspace.",
example="python/code/",
)


class ArtifactUpload(BaseModel):
file: bytes = Field(..., description="File to upload.")
relative_path: Optional[str] = Field(
None, description="Relative path of the artifact in the agent's workspace."
None,
description="Relative path of the artifact in the agent's workspace.",
example="python/code",
)


class StepInput(BaseModel):
__root__: Any = Field(
..., description="Input parameters for the task step. Any value is allowed."
...,
description="Input parameters for the task step. Any value is allowed.",
example='{\n"file_to_refactor": "models.py"\n}',
)


class StepOutput(BaseModel):
__root__: Any = Field(
..., description="Output that the task step has produced. Any value is allowed."
...,
description="Output that the task step has produced. Any value is allowed.",
example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}',
)


class TaskRequestBody(BaseModel):
input: Optional[str] = Field(None, description="Input prompt for the task.")
input: Optional[str] = Field(
None,
description="Input prompt for the task.",
example="Write the words you receive to the file 'output.txt'.",
)
additional_input: Optional[TaskInput] = None


class Task(TaskRequestBody):
task_id: str = Field(..., description="The ID of the task.")
task_id: str = Field(
...,
description="The ID of the task.",
example="50da533e-3904-4401-8a07-c49adf88b5eb",
)
artifacts: List[Artifact] = Field(
[], description="A list of artifacts that the task has produced."
[],
description="A list of artifacts that the task has produced.",
example=[
"7a49f31c-f9c6-4346-a22c-e32bc5af4d8e",
"ab7b4091-2560-4692-a4fe-d831ea3ca7d6",
],
)


class StepRequestBody(BaseModel):
input: Optional[str] = Field(None, description="Input prompt for the step.")
input: Optional[str] = Field(
None, description="Input prompt for the step.", example="Washington"
)
additional_input: Optional[StepInput] = None


class Status(Enum):
created = "created"
running = "running"
completed = "completed"


class Step(StepRequestBody):
task_id: str = Field(..., description="The ID of the task this step belongs to.")
step_id: str = Field(..., description="The ID of the task step.")
name: Optional[str] = Field(None, description="The name of the task step.")
task_id: str = Field(
...,
description="The ID of the task this step belongs to.",
example="50da533e-3904-4401-8a07-c49adf88b5eb",
)
step_id: str = Field(
...,
description="The ID of the task step.",
example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
)
name: Optional[str] = Field(
None, description="The name of the task step.", example="Write to file"
)
status: Status = Field(..., description="The status of the task step.")
output: Optional[str] = Field(None, description="Output of the task step.")
output: Optional[str] = Field(
None,
description="Output of the task step.",
example="I am going to use the write_to_file command and write Washington to a file called output.txt <write_to_file('output.txt', 'Washington')",
)
additional_output: Optional[StepOutput] = None
artifacts: List[Artifact] = Field(
[], description="A list of artifacts that the step has produced."
Expand Down
7 changes: 6 additions & 1 deletion sdk/python/agent_protocol/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@

from fastapi import FastAPI

from agent_protocol.db import NotFoundException
from agent_protocol.middlewares import not_found_exception_handler

app = FastAPI(
title="Agent Communication Protocol",
description="Specification of the API protocol for communication with an agent.",
version="v0.2",
version="v0.3",
)

app.add_exception_handler(NotFoundException, not_found_exception_handler)
2 changes: 1 addition & 1 deletion sdk/python/examples/smol_developer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def task_handler(task: Task) -> None:
await Agent.db.create_step(task.task_id, StepTypes.PLAN)


async def step_handler(step: Step):
async def step_handler(step: Step) -> Step:
task = await Agent.db.get_task(step.task_id)
if step.name == StepTypes.PLAN:
return await _generate_shared_deps(step)
Expand Down
Loading

0 comments on commit a198f1e

Please sign in to comment.