Add capability to add custom tags to batch
Limess committed Nov 8, 2023
1 parent 6badc1d commit d1199b4
Showing 4 changed files with 53 additions and 31 deletions.
2 changes: 2 additions & 0 deletions metaflow/metaflow_config.py
@@ -244,6 +244,8 @@
# in all Metaflow deployments. Hopefully, some day we can flip the
# default to True.
BATCH_EMIT_TAGS = from_conf("BATCH_EMIT_TAGS", False)
# Default tags to add to AWS Batch jobs. These are in addition to the defaults set when BATCH_EMIT_TAGS is true.
BATCH_DEFAULT_TAGS = from_conf("BATCH_DEFAULT_TAGS", {})

###
# AWS Step Functions configuration
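For context, BATCH_DEFAULT_TAGS is read through from_conf, so it should be settable like any other Metaflow option. A minimal sketch, assuming the usual METAFLOW_-prefixed environment variable convention and that from_conf JSON-decodes dict-valued options (both assumptions, not something this diff shows; the tag names are placeholders):

import json
import os

# Assumed configuration route: set before metaflow is first imported.
os.environ["METAFLOW_BATCH_DEFAULT_TAGS"] = json.dumps(
    {"team": "data-eng", "cost-center": "1234"}  # placeholder tags
)

# Every AWS Batch job launched by this process would then carry these
# tags, in addition to the system tags emitted when BATCH_EMIT_TAGS is true.
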
44 changes: 30 additions & 14 deletions metaflow/plugins/aws/batch/batch.py
@@ -7,31 +7,30 @@
import time

from metaflow import util
from metaflow.plugins.datatools.s3.s3tail import S3Tail
from metaflow.plugins.aws.aws_utils import sanitize_batch_tag
from metaflow.exception import MetaflowException
from metaflow.metaflow_config import (
    SERVICE_INTERNAL_URL,
    DATATOOLS_S3ROOT,
    DATASTORE_SYSROOT_S3,
    DEFAULT_METADATA,
    SERVICE_HEADERS,
    AWS_SECRETS_MANAGER_DEFAULT_REGION,
    BATCH_DEFAULT_TAGS,
    BATCH_EMIT_TAGS,
    CARD_S3ROOT,
    S3_ENDPOINT_URL,
    DATASTORE_SYSROOT_S3,
    DATATOOLS_S3ROOT,
    DEFAULT_METADATA,
    DEFAULT_SECRETS_BACKEND_TYPE,
    AWS_SECRETS_MANAGER_DEFAULT_REGION,
    S3_ENDPOINT_URL,
    S3_SERVER_SIDE_ENCRYPTION,
    SERVICE_HEADERS,
    SERVICE_INTERNAL_URL,
)

from metaflow.metaflow_config_funcs import config_values

from metaflow.mflog import (
    export_mflog_env_vars,
    BASH_SAVE_LOGS,
    bash_capture_logs,
    export_mflog_env_vars,
    tail_logs,
    BASH_SAVE_LOGS,
)
from metaflow.plugins.aws.aws_utils import sanitize_batch_tag
from metaflow.plugins.datatools.s3.s3tail import S3Tail

from .batch_client import BatchClient

@@ -63,7 +62,7 @@ def _command(self, environment, code_package_url, step_name, step_cmds, task_spe
            datastore_type="s3",
            stdout_path=STDOUT_PATH,
            stderr_path=STDERR_PATH,
            **task_spec
            **task_spec,
        )
        init_cmds = environment.get_package_commands(code_package_url, "s3")
        init_expr = " && ".join(init_cmds)
@@ -186,6 +185,7 @@ def create_job(
        attrs={},
        host_volumes=None,
        use_tmpfs=None,
        tags=None,
        tmpfs_tempdir=None,
        tmpfs_size=None,
        tmpfs_path=None,
@@ -317,6 +317,20 @@ def create_job(
                if key in attrs:
                    k, v = sanitize_batch_tag(key, attrs.get(key))
                    job.tag(k, v)

        if not isinstance(BATCH_DEFAULT_TAGS, dict):
            raise BatchException(
                "The BATCH_DEFAULT_TAGS config option must be a dictionary of key-value tags."
            )
        for name, value in BATCH_DEFAULT_TAGS.items():
            job.tag(name, value)

        # add custom tags last to allow override of defaults
        if tags is not None:
            if not isinstance(tags, dict):
                raise BatchException("tags must be a dictionary of key-value tags.")
            for name, value in tags.items():
                job.tag(name, value)
        return job

    def launch_job(
@@ -342,6 +356,7 @@
        efa=None,
        host_volumes=None,
        use_tmpfs=None,
        tags=None,
        tmpfs_tempdir=None,
        tmpfs_size=None,
        tmpfs_path=None,
@@ -380,6 +395,7 @@
            attrs=attrs,
            host_volumes=host_volumes,
            use_tmpfs=use_tmpfs,
            tags=tags,
            tmpfs_tempdir=tmpfs_tempdir,
            tmpfs_size=tmpfs_size,
            tmpfs_path=tmpfs_path,
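The tagging order in create_job above implies a precedence: system tags emitted under BATCH_EMIT_TAGS, then BATCH_DEFAULT_TAGS, then the per-job tags argument. Since the "add custom tags last to allow override of defaults" comment indicates that a later job.tag call wins for a duplicate key, a dict merge models the effective result (all values below are illustrative, not taken from the diff):

system_tags = {"metaflow.flow_name": "MyFlow"}      # emitted when BATCH_EMIT_TAGS is true
default_tags = {"env": "prod", "team": "data-eng"}  # BATCH_DEFAULT_TAGS
job_tags = {"env": "staging"}                       # tags= argument to create_job

# job.tag(k, v) is called in the order above, so a later value replaces
# an earlier one for the same key; the merge below models that behavior.
effective = {**system_tags, **default_tags, **job_tags}
assert effective["env"] == "staging"  # the per-job tag overrides the default
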
14 changes: 8 additions & 6 deletions metaflow/plugins/aws/batch/batch_cli.py
@@ -1,12 +1,11 @@
from metaflow._vendor import click
import os
import sys
import time
import traceback

from metaflow import util
from metaflow import R
from metaflow.exception import CommandException, METAFLOW_EXIT_DISALLOW_RETRY
from metaflow import R, util
from metaflow._vendor import click
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
from metaflow.metadata.util import sync_local_metadata_from_datastore
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
from metaflow.mflog import TASK_LOG_SOURCE
@@ -146,6 +145,7 @@ def kill(ctx, run_id, user, my_runs):
    help="Activate designated number of elastic fabric adapter devices. "
    "EFA driver must be installed and instance type compatible with EFA",
)
@click.option("--aws-tags", multiple=True, default=None, help="AWS tags.")
@click.option("--use-tmpfs", is_flag=True, help="tmpfs requirement for AWS Batch.")
@click.option("--tmpfs-tempdir", is_flag=True, help="tmpfs requirement for AWS Batch.")
@click.option("--tmpfs-size", help="tmpfs requirement for AWS Batch.")
@@ -179,13 +179,14 @@ def step(
    swappiness=None,
    inferentia=None,
    efa=None,
    aws_tags=None,
    use_tmpfs=None,
    tmpfs_tempdir=None,
    tmpfs_size=None,
    tmpfs_path=None,
    host_volumes=None,
    num_parallel=None,
    **kwargs
    **kwargs,
):
    def echo(msg, stream="stderr", batch_id=None, **kwargs):
        msg = util.to_unicode(msg)
@@ -311,12 +312,13 @@ def _sync_metadata():
            attrs=attrs,
            host_volumes=host_volumes,
            use_tmpfs=use_tmpfs,
            tags=aws_tags,
            tmpfs_tempdir=tmpfs_tempdir,
            tmpfs_size=tmpfs_size,
            tmpfs_path=tmpfs_path,
            num_parallel=num_parallel,
        )
    except Exception as e:
    except Exception:
        traceback.print_exc()
        _sync_metadata()
        sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
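One detail the diff leaves open: --aws-tags is declared with multiple=True, so click delivers a tuple of strings, while create_job validates tags as a dict. A hypothetical helper (its name, signature, and the key=value format are all assumptions, not part of this commit) suggests the shape of the conversion that would be needed in between:

from typing import Dict, Optional, Tuple

def parse_aws_tags(values: Optional[Tuple[str, ...]]) -> Optional[Dict[str, str]]:
    # Hypothetical: turn repeated --aws-tags "key=value" options into the
    # dict that create_job expects.
    if not values:
        return None
    tags = {}
    for item in values:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError("expected key=value, got %r" % (item,))
        tags[key] = value
    return tags

# parse_aws_tags(("team=data-eng", "env=prod"))
# -> {"team": "data-eng", "env": "prod"}
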
24 changes: 13 additions & 11 deletions metaflow/plugins/aws/batch/batch_decorator.py
@@ -1,34 +1,32 @@
import os
import sys
import platform
import requests
import sys
import time

from metaflow import util
from metaflow import R, current
import requests

from metaflow import R, current
from metaflow.decorators import StepDecorator
from metaflow.plugins.resources_decorator import ResourcesDecorator
from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
from metaflow.metadata import MetaDatum
from metaflow.metadata.util import sync_local_metadata_to_datastore
from metaflow.metaflow_config import (
    ECS_S3_ACCESS_IAM_ROLE,
    BATCH_JOB_QUEUE,
    BATCH_CONTAINER_IMAGE,
    BATCH_CONTAINER_REGISTRY,
    ECS_FARGATE_EXECUTION_ROLE,
    BATCH_JOB_QUEUE,
    DATASTORE_LOCAL_DIR,
    ECS_FARGATE_EXECUTION_ROLE,
    ECS_S3_ACCESS_IAM_ROLE,
)
from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
from metaflow.sidecar import Sidecar
from metaflow.unbounded_foreach import UBF_CONTROL

from .batch import BatchException
from ..aws_utils import (
    compute_resource_attributes,
    get_docker_registry,
    get_ec2_instance_metadata,
)
from .batch import BatchException


class BatchDecorator(StepDecorator):
@@ -73,6 +71,9 @@ class BatchDecorator(StepDecorator):
        aggressively. Accepted values are whole numbers between 0 and 100.
    use_tmpfs: bool, default: False
        This enables an explicit tmpfs mount for this step.
    tags: dict, optional
        Arbitrary AWS tags to attach to the AWS Batch job, as string
        key-value pairs.
    tmpfs_tempdir: bool, default: True
        Sets METAFLOW_TEMPDIR to tmpfs_path if set for this step.
    tmpfs_size: int, optional
@@ -103,6 +104,7 @@ class BatchDecorator(StepDecorator):
"efa": None,
"host_volumes": None,
"use_tmpfs": False,
"tags": None,
"tmpfs_tempdir": True,
"tmpfs_size": None,
"tmpfs_path": "/metaflow_temp",
@@ -346,7 +348,7 @@ def _wait_for_mapper_tasks(self, flow, step_name):
                        len(flow._control_mapper_tasks),
                    )
                )
            except Exception as e:
            except Exception:
                pass
        raise Exception(
            "Batch secondary workers did not finish in %s seconds" % TIMEOUT
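Putting the pieces together, the new tags attribute on the decorator would presumably be used per step like this. A hedged sketch, assuming the decorator forwards tags through to create_job as the plumbing above suggests (tag names and resource values are placeholders):

from metaflow import FlowSpec, batch, step

class TaggedFlow(FlowSpec):
    # These tags are applied after BATCH_DEFAULT_TAGS, so a matching key
    # (e.g. "env") overrides the configured default.
    @batch(cpu=1, memory=4000, tags={"env": "staging", "team": "data-eng"})
    @step
    def start(self):
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    TaggedFlow()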
