From 0d09a760614ee0bdace99dbb987981b51fdccb1e Mon Sep 17 00:00:00 2001 From: kmonte Date: Fri, 29 Mar 2024 13:21:03 -0700 Subject: [PATCH] Add pipeline_start_post_processing to Orchestrator env PiperOrigin-RevId: 620324534 --- tfx/orchestration/experimental/core/env.py | 13 +++++++++++++ tfx/orchestration/experimental/core/env_test.py | 4 ++++ 2 files changed, 17 insertions(+) diff --git a/tfx/orchestration/experimental/core/env.py b/tfx/orchestration/experimental/core/env.py index 41407bd6a0..5cd9f69714 100644 --- a/tfx/orchestration/experimental/core/env.py +++ b/tfx/orchestration/experimental/core/env.py @@ -71,6 +71,16 @@ def set_health_status(self, status: status_lib.Status) -> None: def check_if_can_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> None: """Check if this orchestrator is capable of orchestrating the pipeline.""" + @abc.abstractmethod + def pipeline_start_post_process(self, pipeline: pipeline_pb2.Pipeline): + """Method for processing a pipeline at the end of its initialization, before it starts running. + + This *will* mutate the provided IR in-place. + + Args: + pipeline: The pipeline IR to process. + """ + class _DefaultEnv(Env): """Default environment.""" @@ -104,6 +114,9 @@ def set_health_status(self, status: status_lib.Status) -> None: def check_if_can_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> None: pass + def pipeline_start_post_process(self, pipeline: pipeline_pb2.Pipeline): + pass + _ENV = _DefaultEnv() diff --git a/tfx/orchestration/experimental/core/env_test.py b/tfx/orchestration/experimental/core/env_test.py index 4dce1b191c..9e04673eb6 100644 --- a/tfx/orchestration/experimental/core/env_test.py +++ b/tfx/orchestration/experimental/core/env_test.py @@ -16,6 +16,7 @@ import tensorflow as tf from tfx.orchestration.experimental.core import env from tfx.orchestration.experimental.core import test_utils +from tfx.proto.orchestration import pipeline_pb2 from tfx.utils import status as status_lib @@ -45,6 +46,9 @@ def set_health_status(self, status: status_lib.Status) -> None: def check_if_can_orchestrate(self, pipeline) -> None: raise NotImplementedError() + def pipeline_start_post_process(self, pipeline: pipeline_pb2.Pipeline): + raise NotImplementedError() + class EnvTest(test_utils.TfxTest):