Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

agent 2 sweep ini outline #7480

Open
wants to merge 28 commits into
base: nick/launch-agent-2-poc
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
890c263
ini outline
KyleGoyette Apr 25, 2024
c88fe58
wip
KyleGoyette Apr 25, 2024
95e6ae6
finish base impl
KyleGoyette Apr 25, 2024
1a49bde
Merge branch 'nick/launch-agent-2-poc' into agent2/sweeps-parity
KyleGoyette Apr 25, 2024
62c4a03
format and comment
KyleGoyette Apr 25, 2024
bebf3be
wip
KyleGoyette Apr 25, 2024
ab02a44
fix all controllers
KyleGoyette Apr 25, 2024
0dac70c
Merge branch 'agent2/sweeps-parity' of https://github.com/wandb/wandb…
KyleGoyette Apr 25, 2024
276d6a4
add call to register
KyleGoyette Apr 25, 2024
c3794e3
wip
KyleGoyette Apr 25, 2024
7d03163
remove future
KyleGoyette Apr 25, 2024
81865c7
Merge branch 'nick/launch-agent-2-poc' into agent2/sweeps-parity
KyleGoyette Apr 30, 2024
0c55327
wip
KyleGoyette Apr 30, 2024
7da3c07
rename file
KyleGoyette Apr 30, 2024
4f8a89f
Merge branch 'main' into nick/launch-agent-2-poc
KyleGoyette May 1, 2024
a6fc74f
Merge branch 'nick/launch-agent-2-poc' into agent2/sweeps-parity
KyleGoyette May 1, 2024
f4dc731
fix tests
KyleGoyette May 1, 2024
79765b6
fix
KyleGoyette May 1, 2024
a38c834
fix more
KyleGoyette May 1, 2024
8d7f2d3
fix end to end test
KyleGoyette May 2, 2024
a710f90
rename files
KyleGoyette May 2, 2024
40631b1
fix mypy
KyleGoyette May 2, 2024
6b5aab6
Merge branch 'nick/launch-agent-2-poc' into agent2/sweeps-parity
KyleGoyette May 2, 2024
df4143b
fix agent tests
KyleGoyette May 2, 2024
2d2b25e
fix controller tests for different versions
KyleGoyette May 2, 2024
8659dd6
wip
KyleGoyette May 2, 2024
85e7151
wip
KyleGoyette May 2, 2024
fae52f8
ruff ordering
KyleGoyette May 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
71 changes: 65 additions & 6 deletions wandb/sdk/launch/agent2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from wandb.sdk.launch.agent.agent import HIDDEN_AGENT_RUN_TYPE
from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
from wandb.sdk.launch.agent.run_queue_item_file_saver import RunQueueItemFileSaver
from wandb.sdk.launch.agent2.controllers.local_process import SchedulerManager
from wandb.sdk.launch.builder.build import construct_agent_configs
from wandb.sdk.launch.environment.local_environment import LocalEnvironment
from wandb.sdk.launch.registry.local_registry import LocalRegistry
from wandb.sdk.launch.utils import PROJECT_SYNCHRONOUS, event_loop_thread_exec

from .controller import LaunchController, LegacyResources
Expand Down Expand Up @@ -68,6 +71,7 @@
self._last_state = None
self._wandb_version: str = "wandb@" + wandb.__version__
self._task: Optional[asyncio.Task[Any]] = None
self._receive_scheduler_job_queue = asyncio.Queue()

Check warning on line 74 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L74

Added line #L74 was not covered by tests

self._logger = logging.getLogger("wandb.launch.agent2")
handler = logging.StreamHandler(sys.stdout)
Expand Down Expand Up @@ -123,7 +127,67 @@
# Start the main agent state poll loop
self.start_poll_loop(event_loop)

def file_saver_factory(job_id):
    """Create a RunQueueItemFileSaver for ``job_id`` using ``self._wandb_run``."""
    saver = RunQueueItemFileSaver(self._wandb_run, job_id)
    return saver

Check warning on line 131 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L130-L131

Added lines #L130 - L131 were not covered by tests

def job_tracker_factory(job_id, q):
    """Create a JobAndRunStatusTracker for ``job_id`` on queue ``q``.

    The tracker is given a file saver built by ``file_saver_factory`` for the
    same ``job_id``.
    """
    saver = file_saver_factory(job_id)
    return JobAndRunStatusTracker(job_id, q, saver)

Check warning on line 134 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L133-L134

Added lines #L133 - L134 were not covered by tests

try:
# create sweep scheduler local process controller
# TODO: move into util function
controller_impl = self.get_controller_for_jobset("local-process")
_, build_config, registry_config = construct_agent_configs(

Check warning on line 140 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L139-L140

Added lines #L139 - L140 were not covered by tests
dict(self._config)
)
environment = LocalEnvironment()
registry = LocalRegistry()
runner = loader.runner_from_config(

Check warning on line 145 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L143-L145

Added lines #L143 - L145 were not covered by tests
"local-process",
self._api, # todo factor out (?)
{},
environment,
registry,
)
legacy_resources = LegacyResources(

Check warning on line 152 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L152

Added line #L152 was not covered by tests
self._api,
builder,
registry,
runner,
environment,
job_tracker_factory,
)
controller_logger = self._logger.getChild(

Check warning on line 160 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L160

Added line #L160 was not covered by tests
"controller.sweep-scheduler-local-process"
)
scheduler_controller = controller_impl(

Check warning on line 163 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L163

Added line #L163 was not covered by tests
{
"agent_id": self._id,
"jobset_spec": JobSetSpec(
name="_wandb_sweep_scheduler_local_process",
entity_name=self._config["entity"],
project_name="_wandb_sweep-scheduler_local_process",
),
"jobset_metadata": None,
},
JobSet(self._api, {}, self._id, controller_logger),
controller_logger,
self._shutdown_controllers_event,
legacy_resources,
self._receive_scheduler_job_queue, # TODO: not necessary for sweep scheduler
)
manager_logger = self._logger.getChild("scheduler_manager")
scheduler_manager = SchedulerManager(

Check warning on line 180 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L179-L180

Added lines #L179 - L180 were not covered by tests
scheduler_controller,
self._config["max_schedulers"],
self._receive_scheduler_job_queue,
manager_logger,
)
controller_task: asyncio.Task = asyncio.create_task(

Check warning on line 186 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L186

Added line #L186 was not covered by tests

)
self._launch_controller_tasks.add(controller_task)

Check warning on line 189 in wandb/sdk/launch/agent2/agent.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/agent.py#L189

Added line #L189 was not covered by tests

# Start job set and controller loops
for q in self._config["queues"]:
# Start a JobSet for each queue
Expand Down Expand Up @@ -169,12 +233,6 @@
registry,
)

def file_saver_factory(job_id):
return RunQueueItemFileSaver(self._wandb_run, job_id)

def job_tracker_factory(job_id, q=q):
return JobAndRunStatusTracker(job_id, q, file_saver_factory(job_id))

legacy_resources = LegacyResources(
self._api,
builder,
Expand All @@ -197,6 +255,7 @@
controller_logger,
self._shutdown_controllers_event,
legacy_resources,
self._receive_scheduler_job_queue,
)
)
self._launch_controller_tasks.add(controller_task)
Expand Down
18 changes: 15 additions & 3 deletions wandb/sdk/launch/agent2/controllers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from wandb.sdk.launch.errors import LaunchError
from wandb.sdk.launch.runner.abstract import AbstractRun, Status

from ...agent.agent import RUN_INFO_GRACE_PERIOD
from ...agent.agent import RUN_INFO_GRACE_PERIOD, _is_scheduler_job
from ...queue_driver.abstract import AbstractQueueDriver
from ...utils import event_loop_thread_exec
from ..controller import LaunchControllerConfig, LegacyResources
from ..jobset import Job, JobSet
from ..jobset import Job, JobSet, JobWithQueue

WANDB_JOBSET_DISCOVERABILITY_LABEL = "_wandb-jobset"

Expand All @@ -41,13 +41,15 @@
jobset: JobSet,
logger: logging.Logger,
legacy: LegacyResources,
scheduler_queue: asyncio.Queue[Tuple[JobWithQueue, asyncio.Future]],
max_concurrency: int,
):
self.config = config
self.jobset = jobset
self.logger = logger
self.legacy = legacy
self.max_concurrency = max_concurrency
self._scheduler_queue = scheduler_queue

Check warning on line 52 in wandb/sdk/launch/agent2/controllers/base.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/base.py#L52

Added line #L52 was not covered by tests

self.id = config["jobset_spec"].name
self.active_runs: Dict[str, RunWithTracker] = {}
Expand Down Expand Up @@ -126,7 +128,7 @@
try:
project = LaunchProject.from_spec(job.run_spec, self.legacy.api)
run_id = project.run_id
job_tracker = self.legacy.job_tracker_factory(run_id)
job_tracker = self.legacy.job_tracker_factory(run_id, project.queue_name)

Check warning on line 131 in wandb/sdk/launch/agent2/controllers/base.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/base.py#L131

Added line #L131 was not covered by tests
job_tracker.update_run_info(project)
except Exception as e:
self.logger.error(
Expand All @@ -140,6 +142,16 @@
project.run_queue_item_id = job.id
project.fetch_and_validate_project()

if (

Check warning on line 145 in wandb/sdk/launch/agent2/controllers/base.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/base.py#L145

Added line #L145 was not covered by tests
_is_scheduler_job(job.run_spec)
and job.run_spec.get("resource") == "local-process"
):
future = asyncio.futures.Future()
await self._scheduler_queue.put((job, future))
res = await future.result()
if res == False:
return None

Check warning on line 153 in wandb/sdk/launch/agent2/controllers/base.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/base.py#L149-L153

Added lines #L149 - L153 were not covered by tests

ack_result = await self.jobset.ack_job(job.id, run_id)
self.logger.info(f"Acked item: {json.dumps(ack_result, indent=2)}")
if not ack_result:
Expand Down
96 changes: 82 additions & 14 deletions wandb/sdk/launch/agent2/controllers/local_process.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import asyncio
import json
import logging
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple, Union

from ..._project_spec import LaunchProject
from ...queue_driver import passthrough
from ..controller import LaunchControllerConfig, LegacyResources
from ..jobset import Job, JobSet
from ..jobset import Job, JobSet, JobWithQueue
from .base import BaseManager, RunWithTracker


Expand All @@ -16,6 +16,7 @@
logger: logging.Logger,
shutdown_event: asyncio.Event,
legacy: LegacyResources,
agent_queue: asyncio.Queue,
) -> Any:
# disable job set loop because we are going to use the passthrough queue driver
# to drive the launch controller here
Expand All @@ -39,7 +40,9 @@
f"Starting local process controller with max concurrency {max_concurrency}"
)

mgr = LocalProcessManager(config, jobset, logger, legacy, max_concurrency)
mgr = LocalProcessManager(

Check warning on line 43 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L43

Added line #L43 was not covered by tests
config, jobset, logger, legacy, agent_queue, max_concurrency
)

while not shutdown_event.is_set():
await mgr.reconcile()
Expand All @@ -62,6 +65,7 @@
jobset: JobSet,
logger: logging.Logger,
legacy: LegacyResources,
agent_queue: asyncio.Queue,
max_concurrency: int,
):
self.queue_driver: passthrough.PassthroughQueueDriver = (
Expand All @@ -73,7 +77,7 @@
agent_id=config["agent_id"],
)
)
super().__init__(config, jobset, logger, legacy, max_concurrency)
super().__init__(config, jobset, logger, legacy, agent_queue, max_concurrency)

Check warning on line 80 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L80

Added line #L80 was not covered by tests

async def reconcile(self) -> None:
num_runs_needed = self.max_concurrency - len(self.active_runs)
Expand All @@ -90,35 +94,99 @@
async def launch_item(self, item: Job) -> Optional[str]:
self.logger.info(f"Launching item: {item}")

project = LaunchProject.from_spec(item.run_spec, self.legacy.api)
project.queue_name = self.config["jobset_spec"].name
project.queue_entity = self.config["jobset_spec"].entity_name
project.run_queue_item_id = item.id
project = self._populate_project(item)

Check warning on line 97 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L97

Added line #L97 was not covered by tests
project.fetch_and_validate_project()
run_id = await self._launch_job(item, project)
self.logger.info(f"Launched item got run_id: {run_id}")
return run_id

Check warning on line 101 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L99-L101

Added lines #L99 - L101 were not covered by tests

async def launch_scheduler_item(self, item: JobWithQueue) -> Optional[str]:
self.logger.info(f"Launching item: {item}")

Check warning on line 104 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L104

Added line #L104 was not covered by tests

project = self._populate_project(item)
project.fetch_and_validate_project()

Check warning on line 107 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L106-L107

Added lines #L106 - L107 were not covered by tests

run_id = await self._launch_job(item.job, project)
self.logger.info(f"Launched item got run_id: {run_id}")
return run_id

Check warning on line 111 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L109-L111

Added lines #L109 - L111 were not covered by tests

def _populate_project(self, job: Union[Job, JobWithQueue]) -> LaunchProject:
project = None
if isinstance(job, JobWithQueue):
project = LaunchProject.from_spec(job.job.run_spec, self.legacy.api)
queue_name = job.queue
queue_entity = job.entity
job_id = job.job.id

Check warning on line 119 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L114-L119

Added lines #L114 - L119 were not covered by tests
else:
project = LaunchProject.from_spec(job.run_spec, self.legacy.api)
queue_name = self.config["jobset_spec"].name
queue_entity = self.config["jobset_spec"].entity_name
job_id = job.id
project.queue_name = queue_name
project.queue_entity = queue_entity
project.run_queue_item_id = job_id
return project

Check warning on line 128 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L121-L128

Added lines #L121 - L128 were not covered by tests

def _get_job(self, item: Union[Job, JobWithQueue]) -> Job:
KyleGoyette marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(item, JobWithQueue):
return item.job
return item

Check warning on line 133 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L131-L133

Added lines #L131 - L133 were not covered by tests

async def _launch_job(self, job: Job, project: LaunchProject) -> Optional[str]:
KyleGoyette marked this conversation as resolved.
Show resolved Hide resolved
run_id = project.run_id
job_tracker = self.legacy.job_tracker_factory(run_id)
job_tracker = self.legacy.job_tracker_factory(run_id, project.queue_name)

Check warning on line 137 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L137

Added line #L137 was not covered by tests
job_tracker.update_run_info(project)

ack_result = await self.queue_driver.ack_run_queue_item(item.id, run_id)
# note since we ack on rqi id the queue driver will handle acking the run queue item
# even if its not for the specified queue
ack_result = await self.queue_driver.ack_run_queue_item(job.id, run_id)

Check warning on line 142 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L142

Added line #L142 was not covered by tests
if ack_result is None:
self.logger.error(f"Failed to ack item {item.id}")
self.logger.error(f"Failed to ack item {job.id}")

Check warning on line 144 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L144

Added line #L144 was not covered by tests
return None
self.logger.info(f"Acked item: {json.dumps(ack_result, indent=2)}")
run = await self.legacy.runner.run(project, "") # image is unused
if not run:
job_tracker.failed_to_start = True
self.logger.error(f"Failed to start run for item {item.id}")
self.logger.error(f"Failed to start run for item {job.id}")

Check warning on line 150 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L150

Added line #L150 was not covered by tests
raise NotImplementedError("TODO: handle this case")

self.active_runs[item.id] = RunWithTracker(run, job_tracker)
self.active_runs[job.id] = RunWithTracker(run, job_tracker)

Check warning on line 153 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L153

Added line #L153 was not covered by tests

run_id = project.run_id
self.logger.info(f"Launched item got run_id: {run_id}")
return run_id

async def find_orphaned_jobs(self) -> List[Any]:
    """Orphaned-job discovery is not implemented here.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def label_job(self, project: LaunchProject) -> None:
    """Do nothing; this implementation applies no label to ``project``."""
    return None


class SchedulerManager:
def __init__(
self,
controller: LocalProcessManager,
max_jobs: int,
scheduler_jobs_queue: asyncio.Queue[Tuple[JobWithQueue, asyncio.Future]],
logger: logging.Logger,
):
self._controller = controller
self._scheduler_jobs_queue = scheduler_jobs_queue
self._logger = logger
self._max_jobs = max_jobs

Check warning on line 176 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L173-L176

Added lines #L173 - L176 were not covered by tests

async def poll(self):
    """Consume scheduler job requests and launch them while capacity allows.

    Each queue entry is a ``(job, future)`` pair. The future is resolved with
    ``True`` when the job is accepted and handed to the controller's
    ``launch_scheduler_item``, or with ``False`` when the controller already
    has ``max_jobs`` (or more) active runs. A ``None`` entry ends the loop
    after a short pause.

    ``task_done()`` is called for every entry taken from the queue, so
    ``Queue.join()`` callers are released on the rejected and sentinel paths
    as well as the accepted one.
    """
    # The event loop keeps only weak references to tasks. Holding them here
    # stops a launch task from being garbage collected mid-flight while this
    # loop is running.
    launch_tasks = set()
    while True:
        res = await self._scheduler_jobs_queue.get()
        try:
            if res is None:
                # The coroutine must be awaited; a bare asyncio.sleep(5)
                # call never actually sleeps.
                await asyncio.sleep(5)  # TODO: const this
                break
            job, future = res
            if len(self._controller.active_runs) >= self._max_jobs:
                self._logger.info(f"Scheduler job queue is full, skipping job: {job}")
                future.set_result(False)
                continue
            future.set_result(True)
            task = asyncio.create_task(self._controller.launch_scheduler_item(job))
            launch_tasks.add(task)
            task.add_done_callback(launch_tasks.discard)
            self._logger.info(f"Launched scheduler job: {job}")
        finally:
            self._scheduler_jobs_queue.task_done()

Check warning on line 192 in wandb/sdk/launch/agent2/controllers/local_process.py

View check run for this annotation

Codecov / codecov/patch

wandb/sdk/launch/agent2/controllers/local_process.py#L179-L192

Added lines #L179 - L192 were not covered by tests
7 changes: 7 additions & 0 deletions wandb/sdk/launch/agent2/jobset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ class Job:
claimed_by: str


@dataclass
class JobWithQueue:
    """A ``Job`` bundled with the queue name and entity it is associated with."""

    job: Job  # the wrapped job
    queue: str  # queue name for the job
    entity: str  # entity name for the queue


def run_queue_item_to_job(run_queue_item: Dict[str, Any]) -> Job:
return Job(
id=run_queue_item["id"],
Expand Down