[@parallel on Kubernetes] support for Jobsets
Implementation originates from [#1744]

This commit adds support for @parallel when flows are run with `--with kubernetes`.
Support for Argo Workflows will follow in a separate commit.

A user can run a flow with the following:

    @step
    def start(self):
        self.next(self.parallel_step, num_parallel=3)

    @kubernetes(cpu=1, memory=512)
    @parallel
    @step
    def parallel_step(self):
        ...
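
Such a flow can then be launched with, for example (the flow file name here is illustrative):

    python parallel_flow.py run --with kubernetes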

Testing Done:
- Ran a flow with @parallel on Kubernetes. Verified that it works correctly.
- Ran a flow without @parallel on Kubernetes. Verified that it works as expected.
- Verified that a JobSet-based @parallel step gets scaled down if the user kills it with Ctrl-C.

Changes to the original implementation:
- Pass down the ports to the JobSets.
- Ensured that `ubf_context` is set correctly.
- Ensured that `split-index` is set correctly based on the type of task (control vs. worker).
- Fixed a bug in the RANK setting. The earlier implementation set `parallelism` on the created `replicatedJobs`.
    - In this implementation, we create a separate copy of the job for each replicated worker.
    - Retrieving the rank from a Kubernetes `V1EnvVar.valueFrom` reference to `metadata.annotations['batch.kubernetes.io/job-completion-index']` therefore won't work, since `job-completion-index` relies on `parallelism` being set on the `job_spec`.
    - Instead, we now statically set `RANK` based on the index in the iterator defining the jobs (see the sketch below).
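
As a rough illustration of the static rank assignment (a hypothetical helper written for this note, not the actual `KubernetesJobSet` code):

    # Build one env list per replicated job; RANK is fixed at build time
    # instead of being derived from job-completion-index at runtime.
    def build_worker_envs(num_parallel, master_addr):
        return [
            [
                {"name": "RANK", "value": str(rank)},
                {"name": "WORLD_SIZE", "value": str(num_parallel)},
                {"name": "MASTER_ADDR", "value": master_addr},
            ]
            for rank in range(num_parallel)
        ]
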
valayDave committed Apr 18, 2024
1 parent 5908c4e commit 302ae6a
Showing 6 changed files with 612 additions and 210 deletions.
5 changes: 5 additions & 0 deletions metaflow/plugins/argo/argo_workflows.py
@@ -836,6 +836,11 @@ def _compile_workflow_template(self):
# Visit every node and yield the uber DAGTemplate(s).
def _dag_templates(self):
def _visit(node, exit_node=None, templates=None, dag_tasks=None):
if node.parallel_foreach:
raise ArgoWorkflowsException(
"Deploying flows with @parallel decorator(s) "
"as Argo Workflows is not supported currently."
)
# Every for-each node results in a separate subDAG and an equivalent
# DAGTemplate rooted at the child of the for-each node. Each DAGTemplate
# has a unique name - the top-level DAGTemplate is named as the name of
92 changes: 88 additions & 4 deletions metaflow/plugins/kubernetes/kubernetes.py
@@ -140,9 +140,88 @@ def _command(
return shlex.split('bash -c "%s"' % cmd_str)

def launch_job(self, **kwargs):
self._job = self.create_job(**kwargs).execute()
if (
"num_parallel" in kwargs
and kwargs["num_parallel"]
and int(kwargs["num_parallel"]) > 0
):
job = self.create_job_object(**kwargs)
spec = job.create_job_spec()
# `kwargs["step_cli"]` is setting `ubf_context` as control to ALL pods.
# This will be modified by the KubernetesJobSet object
self._job = self.create_jobset(
spec,
kwargs["flow_name"],
kwargs["run_id"],
kwargs["step_name"],
kwargs["task_id"],
kwargs["attempt"],
kwargs["code_package_url"],
kwargs["step_cli"],
kwargs["namespace"],
kwargs["use_tmpfs"],
kwargs.get("tmpfs_dir", None),
kwargs.get("tmpfs_size", None),
kwargs.get("tmpfs_path", None),
kwargs.get("env", None),
kwargs.get("labels", None),
kwargs["num_parallel"],
kwargs.get("port", None),
).execute()
else:
kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8])
self._job = self.create_k8sjob(self.create_job_object(**kwargs)).execute()

def create_jobset(
self,
job_spec,
flow_name,
run_id,
step_name,
task_id,
attempt,
code_package_url,
step_cli,
namespace=None,
use_tmpfs=None,
tmpfs_tempdir=None,
tmpfs_size=None,
tmpfs_path=None,
env=None,
labels=None,
num_parallel=None,
port=None,
):
if env is None:
env = {}

def create_job(
js = KubernetesClient().jobset(
name="js-%s" % task_id.split("-")[-1],
run_id=run_id,
task_id=task_id,
step_name=step_name,
namespace=namespace,
command=self._command(
flow_name=flow_name,
run_id=run_id,
step_name=step_name,
task_id=task_id,
attempt=attempt,
code_package_url=code_package_url,
step_cmds=[step_cli],
),
labels=self._get_labels(labels),
use_tmpfs=use_tmpfs,
tmpfs_tempdir=tmpfs_tempdir,
tmpfs_size=tmpfs_size,
tmpfs_path=tmpfs_path,
num_parallel=num_parallel,
job_spec=job_spec,
port=port,
)
return js

def create_job_object(
self,
flow_name,
run_id,
@@ -176,14 +255,15 @@ def create_job(
labels=None,
shared_memory=None,
port=None,
name_pattern=None,
num_parallel=None,
):
if env is None:
env = {}

job = (
KubernetesClient()
.job(
generate_name="t-{uid}-".format(uid=str(uuid4())[:8]),
generate_name=name_pattern,
namespace=namespace,
service_account=service_account,
secrets=secrets,
@@ -217,6 +297,7 @@ def create_job(
persistent_volume_claims=persistent_volume_claims,
shared_memory=shared_memory,
port=port,
num_parallel=num_parallel,
)
.environment_variable("METAFLOW_CODE_SHA", code_package_sha)
.environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -332,6 +413,9 @@ def create_job(
.label("app.kubernetes.io/part-of", "metaflow")
)

return job

def create_k8sjob(self, job):
return job.create()

def wait(self, stdout_location, stderr_location, echo=None):
12 changes: 12 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes_cli.py
@@ -7,6 +7,7 @@
from metaflow._vendor import click
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
from metaflow.metadata.util import sync_local_metadata_from_datastore
from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
from metaflow.mflog import TASK_LOG_SOURCE
import metaflow.tracing as tracing
@@ -109,6 +110,15 @@ def kubernetes():
)
@click.option("--shared-memory", default=None, help="Size of shared memory in MiB")
@click.option("--port", default=None, help="Port number to expose from the container")
@click.option(
"--ubf-context", default=None, type=click.Choice([None, UBF_CONTROL, UBF_TASK])
)
@click.option(
"--num-parallel",
default=None,
type=int,
help="Number of parallel nodes to run as a multi-node job.",
)
@click.pass_context
def step(
ctx,
@@ -136,6 +146,7 @@ def step(
tolerations=None,
shared_memory=None,
port=None,
num_parallel=None,
**kwargs
):
def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -251,6 +262,7 @@ def _sync_metadata():
tolerations=tolerations,
shared_memory=shared_memory,
port=port,
num_parallel=num_parallel,
)
except Exception as e:
traceback.print_exc(chain=False)
5 changes: 4 additions & 1 deletion metaflow/plugins/kubernetes/kubernetes_client.py
@@ -4,7 +4,7 @@

from metaflow.exception import MetaflowException

from .kubernetes_job import KubernetesJob
from .kubernetes_job import KubernetesJob, KubernetesJobSet


CLIENT_REFRESH_INTERVAL_SECONDS = 300
@@ -61,5 +61,8 @@ def get(self):

return self._client

def jobset(self, **kwargs):
return KubernetesJobSet(self, **kwargs)

def job(self, **kwargs):
return KubernetesJob(self, **kwargs)
34 changes: 28 additions & 6 deletions metaflow/plugins/kubernetes/kubernetes_decorator.py
@@ -29,6 +29,7 @@

from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata
from .kubernetes import KubernetesException, parse_kube_keyvalue_list
from metaflow.unbounded_foreach import UBF_CONTROL

try:
unicode
@@ -227,12 +228,6 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge
"Kubernetes. Please use one or the other.".format(step=step)
)

for deco in decos:
if getattr(deco, "IS_PARALLEL", False):
raise KubernetesException(
"@kubernetes does not support parallel execution currently."
)

# Set run time limit for the Kubernetes job.
self.run_time_limit = get_run_time_limit_for_task(decos)
if self.run_time_limit < 60:
@@ -441,6 +436,25 @@ def task_pre_step(
self._save_logs_sidecar = Sidecar("save_logs_periodically")
self._save_logs_sidecar.start()

num_parallel = None
if hasattr(flow, "_parallel_ubf_iter"):
num_parallel = flow._parallel_ubf_iter.num_parallel
if num_parallel and num_parallel >= 1 and ubf_context == UBF_CONTROL:
control_task_id = current.task_id
top_task_id = control_task_id.replace("control-", "")
mapper_task_ids = [control_task_id] + [
"worker-%s-%d" % (top_task_id, node_idx)
for node_idx in range(1, num_parallel)
]
flow._control_mapper_tasks = [
"%s/%s/%s" % (run_id, step_name, mapper_task_id)
for mapper_task_id in mapper_task_ids
]
flow._control_task_is_mapper_zero = True

if num_parallel and num_parallel > 1:
_setup_multinode_environment()

def task_finished(
self, step_name, flow, graph, is_task_ok, retry_count, max_retries
):
@@ -474,3 +488,11 @@ def _save_package_once(cls, flow_datastore, package):
cls.package_url, cls.package_sha = flow_datastore.save_data(
[package.blob], len_hint=1
)[0]


def _setup_multinode_environment():
import socket

os.environ["MF_PARALLEL_MAIN_IP"] = socket.gethostbyname(os.environ["MASTER_ADDR"])
os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["WORLD_SIZE"]
os.environ["MF_PARALLEL_NODE_INDEX"] = os.environ["RANK"]
