Create and clean up kafka topics manually to reduce memory burden for Adala server #107

Merged: 15 commits, May 21, 2024
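At a glance: instead of relying on broker auto-creation, the parent task now creates the per-job input/output topics up front (with retention bounded by KAFKA_RETENTION_MS) and deletes them once both child jobs finish. Below is a minimal sketch of that lifecycle using the kafka-python admin client this PR adds as a dependency; the broker address, job id, and retention value are illustrative placeholders, not values taken from any deployment.

from kafka.admin import KafkaAdminClient, NewTopic

# Illustrative placeholders only
bootstrap_servers = "localhost:9093"
job_id = "example-job-id"
retention_ms = 180000  # 3 minutes

admin = KafkaAdminClient(bootstrap_servers=bootstrap_servers)

# 1. Create the per-job topics before the streaming child tasks start
topics = [
    NewTopic(
        name=f"adala-{direction}-{job_id}",
        num_partitions=1,
        replication_factor=1,
        topic_configs={"retention.ms": str(retention_ms)},
    )
    for direction in ("input", "output")
]
admin.create_topics(new_topics=topics)

# 2. ...run the input and output streaming jobs against those topics...

# 3. Delete both topics once the jobs have completed
admin.delete_topics(topics=[t.name for t in topics])
admin.close()
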
4 changes: 3 additions & 1 deletion docker-compose.yml
@@ -20,6 +20,7 @@ services:
- ALLOW_PLAINTEXT_LISTENER=yes
- KAFKA_CFG_NODE_ID=1
- KAFKA_KRAFT_CLUSTER_ID=MkU3OEVBNTcwNTJENDM2Qk
- KAFKA_CFG_AUTO_CREATE_TOPICS_ENABLE=false
app:
build:
context: .
@@ -30,8 +31,9 @@ services:
redis:
condition: service_healthy
environment:
- KAFKA_BOOTSTRAP_SERVERS=kafka:9093
- REDIS_URL=redis://redis:6379/0
- KAFKA_BOOTSTRAP_SERVERS=kafka:9093 # TODO pull from .env
- KAFKA_RETENTION_MS=180000 # TODO pull from .env
command:
["poetry", "run", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
worker:
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ celery = {version = "^5.3.6", extras = ["redis"]}
uvicorn = "*"
pydantic-settings = "^2.2.1"
label-studio-sdk = "^0.0.32"
kafka-python = "^2.0.2"

[tool.poetry.dev-dependencies]
pytest = "^7.4.3"
3 changes: 3 additions & 0 deletions server/.env.example
@@ -1 +1,4 @@
KAFKA_BOOTSTRAP_SERVERS="localhost:9093"

# This value is only for local dev. In our deployments it is not set here but elsewhere: https://github.com/HumanSignal/infra/pull/67
KAFKA_RETENTION_MS=180000 # 3 minutes
4 changes: 2 additions & 2 deletions server/app.py
@@ -25,7 +25,7 @@
process_streaming_output,
streaming_parent_task,
)
from utils import get_input_topic, Settings
from utils import get_input_topic_name, Settings
from server.handlers.result_handlers import ResultHandler


@@ -230,7 +230,7 @@ async def submit_batch(batch: BatchData):
Response: Generic response indicating status of request
"""

topic = get_input_topic(batch.job_id)
topic = get_input_topic_name(batch.job_id)
producer = AIOKafkaProducer(
bootstrap_servers=settings.kafka_bootstrap_servers,
value_serializer=lambda v: json.dumps(v).encode("utf-8"),
56 changes: 33 additions & 23 deletions server/tasks/process_file.py
@@ -11,7 +11,13 @@
from aiokafka.errors import UnknownTopicOrPartitionError
from celery import Celery, states
from celery.exceptions import Ignore
from server.utils import get_input_topic, get_output_topic, Settings
from server.utils import (
get_input_topic_name,
get_output_topic_name,
ensure_topic,
delete_topic,
Settings,
)
from server.handlers.result_handlers import ResultHandler


@@ -54,21 +60,29 @@ def streaming_parent_task(
# Parent job ID is used for input/output topic names
parent_job_id = self.request.id

# Override kafka_bootstrap_servers with value from settings
# create kafka topics
input_topic_name = get_input_topic_name(parent_job_id)
ensure_topic(input_topic_name)
output_topic_name = get_output_topic_name(parent_job_id)
ensure_topic(output_topic_name)

# Override default agent kafka settings
settings = Settings()
agent.environment.kafka_bootstrap_servers = settings.kafka_bootstrap_servers
agent.environment.kafka_input_topic = input_topic_name
agent.environment.kafka_output_topic = output_topic_name

inference_task = process_file_streaming
logger.info(f"Submitting task {inference_task.name} with agent {agent}")
input_result = inference_task.delay(agent=agent, parent_job_id=parent_job_id)
input_result = inference_task.delay(agent=agent)
input_job_id = input_result.id
logger.info(f"Task {inference_task.name} submitted with job_id {input_job_id}")

result_handler_task = process_streaming_output
logger.info(f"Submitting task {result_handler_task.name}")
output_result = result_handler_task.delay(
input_job_id=input_job_id,
parent_job_id=parent_job_id,
output_topic_name=output_topic_name,
result_handler=result_handler,
batch_size=batch_size,
)
@@ -95,6 +109,10 @@ def streaming_parent_task(
):
time.sleep(1)

# clean up kafka topics
delete_topic(input_topic_name)
delete_topic(output_topic_name)

logger.info("Both input and output jobs complete")

# Update parent task status to SUCCESS and pass metadata again
@@ -109,45 +127,40 @@ def streaming_parent_task(
raise Ignore()


@app.task(
name="process_file_streaming", track_started=True, bind=True, serializer="pickle"
)
def process_file_streaming(self, agent: Agent, parent_job_id: str):
# Set input and output topics using parent job ID
agent.environment.kafka_input_topic = get_input_topic(parent_job_id)
agent.environment.kafka_output_topic = get_output_topic(parent_job_id)
@app.task(name="process_file_streaming", track_started=True, serializer="pickle")
def process_file_streaming(agent: Agent):
# The agent's kafka_bootstrap_servers and Kafka topics are set by the parent task

# Run the agent
asyncio.run(agent.arun())


async def async_process_streaming_output(
input_job_id: str,
parent_job_id: str,
output_topic_name,
result_handler: ResultHandler,
batch_size: int,
):
logger.info(f"Polling for results {parent_job_id=}")
logger.info(f"Polling for results {output_topic_name=}")

topic = get_output_topic(parent_job_id)
settings = Settings()

# Retry to workaround race condition of topic creation
retries = 5
while retries > 0:
try:
consumer = AIOKafkaConsumer(
topic,
output_topic_name,
bootstrap_servers=settings.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
auto_offset_reset="earliest",
)
await consumer.start()
logger.info(f"consumer started {parent_job_id=}")
logger.info(f"consumer started {output_topic_name=}")
break
except UnknownTopicOrPartitionError as e:
logger.error(msg=e)
logger.info(f"Retrying to create consumer with topic {topic}")
logger.info(f"Retrying to create consumer with topic {output_topic_name}")

await consumer.stop()
retries -= 1
@@ -183,20 +196,17 @@ async def async_process_streaming_output(
await consumer.stop()


@app.task(
name="process_streaming_output", track_started=True, bind=True, serializer="pickle"
)
@app.task(name="process_streaming_output", track_started=True, serializer="pickle")
def process_streaming_output(
self,
input_job_id: str,
parent_job_id: str,
output_topic_name: str,
result_handler: ResultHandler,
batch_size: int,
):
try:
asyncio.run(
async_process_streaming_output(
input_job_id, parent_job_id, result_handler, batch_size
input_job_id, output_topic_name, result_handler, batch_size
)
)
except Exception as e:
49 changes: 45 additions & 4 deletions server/utils.py
@@ -1,6 +1,8 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List, Union
from pathlib import Path
from kafka.admin import KafkaAdminClient, NewTopic
from kafka.errors import TopicAlreadyExistsError


class Settings(BaseSettings):
@@ -10,16 +12,55 @@ class Settings(BaseSettings):
"""

kafka_bootstrap_servers: Union[str, List[str]]
kafka_retention_ms: int

model_config = SettingsConfigDict(
# have to use an absolute path here so celery workers can find it
env_file=(Path(__file__).parent / ".env"),
)


def get_input_topic(job_id: str):
return f"adala-input-{job_id}"
def get_input_topic_name(job_id: str):
topic_name = f"adala-input-{job_id}"

return topic_name

def get_output_topic(job_id: str):
return f"adala-output-{job_id}"

def get_output_topic_name(job_id: str):
topic_name = f"adala-output-{job_id}"

return topic_name


def ensure_topic(topic_name: str):
settings = Settings()
bootstrap_servers = settings.kafka_bootstrap_servers
retention_ms = settings.kafka_retention_ms

admin_client = KafkaAdminClient(
bootstrap_servers=bootstrap_servers, client_id="topic_creator"
)

topic = NewTopic(
name=topic_name,
num_partitions=1,
replication_factor=1,
topic_configs={"retention.ms": str(retention_ms)},
)

try:
admin_client.create_topics(new_topics=[topic])
except TopicAlreadyExistsError:
# We shouldn't hit this case when KAFKA_CFG_AUTO_CREATE_TOPICS_ENABLE=false unless there is a legitimate name collision, so this should raise after testing
pass


def delete_topic(topic_name: str):
settings = Settings()
bootstrap_servers = settings.kafka_bootstrap_servers

admin_client = KafkaAdminClient(
bootstrap_servers=bootstrap_servers, client_id="topic_deleter"
)

admin_client.delete_topics(topics=[topic_name])
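
A hypothetical usage sketch for these helpers (assumptions: a broker reachable via the values in server/.env, the server package importable, and list_topics() used only to illustrate the effect):

from kafka.admin import KafkaAdminClient

from server.utils import (
    get_input_topic_name,
    get_output_topic_name,
    ensure_topic,
    delete_topic,
    Settings,
)

job_id = "example-job-id"  # placeholder
input_topic = get_input_topic_name(job_id)    # "adala-input-example-job-id"
output_topic = get_output_topic_name(job_id)  # "adala-output-example-job-id"

# ensure_topic is idempotent: a repeat call hits TopicAlreadyExistsError and is ignored
ensure_topic(input_topic)
ensure_topic(input_topic)
ensure_topic(output_topic)

admin = KafkaAdminClient(bootstrap_servers=Settings().kafka_bootstrap_servers)
assert {input_topic, output_topic} <= set(admin.list_topics())

# Clean up once the job is done
delete_topic(input_topic)
delete_topic(output_topic)
admin.close()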