noop #6706

Open · wants to merge 1 commit into base: master

1 change: 0 additions & 1 deletion RELEASE.md
@@ -77,7 +77,6 @@
| `packaging` | `>=20,<21` | `>=22` | |
| `attrs` | `19.3.0,<22` | `19.3.0,<24` | |
| `google-cloud-bigquery` | `>=2.26.0,<3` | `>=3,<4` | |
| `tensorflow` | `>=2.15,<2.16` | `>=2.13,<2.14` | |
| `tensorflow-hub` | `>=0.9.0,<0.14` | `>=0.15.0,<0.16` | |
## Documentation Updates

2 changes: 1 addition & 1 deletion tfx/dependencies.py
@@ -97,7 +97,7 @@ def make_required_install_packages():
# Pip might get stuck on a TF 1.15 dependency even though there is a working
# dependency set with TF 2.x without the sync.
# pylint: disable=line-too-long
'tensorflow' + select_constraint('>=2.15.0,<2.16'),
'tensorflow' + select_constraint('>=2.13.0,<2.14'),
# pylint: enable=line-too-long
'tensorflow-hub>=0.15.0,<0.16',
'tensorflow-data-validation'
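
For context on the select_constraint call above: it lets a single requirement string vary by build flavor. A sketch of the pattern, assuming an environment-variable selector as in TFX's own helper — the real definition lives earlier in tfx/dependencies.py, and the selector name and branches here are a best-effort sketch, not a quotation:

import os
from typing import Optional

def select_constraint(
    default: str,
    nightly: Optional[str] = None,
    git_master: Optional[str] = None,
) -> str:
  # Pick a version constraint based on an environment selector, e.g.
  # select_constraint('>=2.13.0,<2.14') -> '>=2.13.0,<2.14' by default.
  selector = os.environ.get('TFX_DEPENDENCY_SELECTOR')
  if selector == 'UNCONSTRAINED':
    return ''
  if selector == 'NIGHTLY' and nightly is not None:
    return nightly
  if selector == 'GIT_MASTER' and git_master is not None:
    return git_master
  return default
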
10 changes: 8 additions & 2 deletions tfx/dsl/component/experimental/component_utils.py
@@ -89,14 +89,20 @@ def _type_check_execution_function_params(
channel = channel_parameters[param_name]
allowed_param_types = []

if not issubclass(channel.type, standard_artifacts.ValueArtifact):
if param_name in spec.OUTPUTS or not issubclass(
channel.type, standard_artifacts.ValueArtifact
):
# For output channels, pass through the allowed type
# For input channels, pass through non-ValueArtifact types. We handle
# ValueArtifact types in the elif branch, because we want to allow users
# to specify the associated primitive types (e.g. str) instead.
allowed_param_types = [
list[channel.type],
Optional[channel.type] if channel.optional else channel.type,
]
elif param_name in spec.INPUTS:
if channel.type in _VALUE_ARTIFACT_TO_TYPE:
# Primitvie ValueArtifact input can be annotated as a primitive type.
# Primitive ValueArtifact input can be annotated as a primitive type.
primitive_type = _VALUE_ARTIFACT_TO_TYPE[channel.type]
allowed_param_types.append(
Optional[primitive_type] if channel.optional else primitive_type
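
The branch above changes which annotations pass the type check: outputs always expect the artifact type itself, while ValueArtifact inputs may instead use the associated primitive type. A self-contained sketch of that dispatch (names simplified for illustration; not the TFX source):

from typing import Optional

class ValueArtifact: ...
class String(ValueArtifact): ...
class Examples: ...  # stands in for a non-value artifact

_VALUE_ARTIFACT_TO_TYPE = {String: str}

def allowed_param_types(param_name, channel_type, optional, inputs, outputs):
  # Outputs, and non-ValueArtifact inputs, must be annotated with the
  # artifact type itself (or list/Optional variants of it).
  if param_name in outputs or not issubclass(channel_type, ValueArtifact):
    return [
        list[channel_type],
        Optional[channel_type] if optional else channel_type,
    ]
  # ValueArtifact inputs may instead use the associated primitive type.
  if param_name in inputs and channel_type in _VALUE_ARTIFACT_TO_TYPE:
    primitive_type = _VALUE_ARTIFACT_TO_TYPE[channel_type]
    return [Optional[primitive_type] if optional else primitive_type]
  return []

assert allowed_param_types('msg', String, False, {'msg'}, set()) == [str]
assert allowed_param_types('ex', Examples, False, {'ex'}, set()) == [
    list[Examples], Examples
]
assert allowed_param_types('out', String, False, set(), {'out'}) == [
    list[String], String
]
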
22 changes: 11 additions & 11 deletions tfx/orchestration/experimental/core/pipeline_ops.py
@@ -816,8 +816,6 @@ def _load_reused_pipeline_view(
mlmd_handle=mlmd_handle,
pipeline_id=pipeline_uid.pipeline_id,
pipeline_run_id=base_run_id,
# If current pipeline run is allowed and base_run_id is not specified,
# reuse the most recent completed run.
non_active_only=env.get_env().concurrent_pipeline_runs_enabled(),
)
except status_lib.StatusNotOkError as e:
@@ -842,10 +840,9 @@

if execution_lib.is_execution_active(reused_pipeline_view.execution):
raise status_lib.StatusNotOkError(
code=status_lib.Code.FAILED_PRECONDITION,
code=status_lib.Code.ALREADY_EXISTS,
message=(
'The previous pipeline run'
f' {reused_pipeline_view.pipeline_run_id} is still active.'
f'An active pipeline is already running with uid {pipeline_uid}.'
),
)

@@ -856,6 +853,7 @@
def resume_pipeline(
mlmd_handle: metadata.Metadata,
pipeline: pipeline_pb2.Pipeline,
pipeline_id: str,
run_id: Optional[str] = None,
) -> pstate.PipelineState:
"""Resumes a pipeline run from previously failed nodes.
@@ -865,6 +863,7 @@ def resume_pipeline(
Args:
mlmd_handle: A handle to the MLMD db.
pipeline: IR of the pipeline to resume.
pipeline_id: The id (name) of the pipeline to resume.
run_id: the run_id of the pipeline run to resume.

Returns:
@@ -877,6 +876,9 @@ def resume_pipeline(
is not found for resuming. With code 'INVALID_ARGUMENT' if concurrent
pipeline runs are enabled but pipeline run id is missing.
"""
reuse_pipeline_uid = task_lib.PipelineUid.from_pipeline_id_and_run_id(
pipeline_id, run_id
)
logging.info(
'Received request to resume pipeline; pipeline uid: %s',
task_lib.PipelineUid.from_pipeline(pipeline),
@@ -892,7 +894,7 @@

if (
env.get_env().concurrent_pipeline_runs_enabled()
and not run_id
and not reuse_pipeline_uid.pipeline_run_id
):
raise status_lib.StatusNotOkError(
code=status_lib.Code.INVALID_ARGUMENT,
Expand All @@ -902,10 +904,10 @@ def resume_pipeline(
),
)

if run_id:
if reuse_pipeline_uid.pipeline_run_id:
snapshot_settings = pipeline_pb2.SnapshotSettings()
partial_run_utils.set_base_pipeline_run_strategy(
snapshot_settings, run_id
snapshot_settings, reuse_pipeline_uid.pipeline_run_id
)
else:
snapshot_settings = partial_run_utils.latest_pipeline_snapshot_settings()
@@ -948,9 +950,7 @@ def resume_pipeline(
)
if pipeline.runtime_spec.HasField('snapshot_settings'):
try:
partial_run_utils.snapshot(
mlmd_handle, pipeline, latest_pipeline_view.pipeline_run_id
)
partial_run_utils.snapshot(mlmd_handle, pipeline)
except ValueError as e:
raise status_lib.StatusNotOkError(
code=status_lib.Code.INVALID_ARGUMENT, message=str(e)
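
Caller-side, the new signature means resuming a specific prior run now takes the pipeline id explicitly. An illustrative call, assuming an open MLMD handle m and a pipeline IR pipeline as in the tests below (a sketch, not part of this PR):

# Resume previously failed nodes, reusing artifacts snapshotted from 'run0'.
# Omitting run_id selects the latest-snapshot strategy instead, and raises
# INVALID_ARGUMENT when concurrent pipeline runs are enabled.
pipeline_state = pipeline_ops.resume_pipeline(
    m,
    pipeline,
    pipeline_id=pipeline.pipeline_info.id,
    run_id='run0',
)
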
41 changes: 22 additions & 19 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
@@ -162,9 +162,9 @@ def test_initiate_pipeline_start(self, pipeline):
@mock.patch.object(partial_run_utils, 'snapshot')
def test_resume_pipeline(self, mock_snapshot):
with self._mlmd_connection as m:
pipeline = _test_pipeline(
'test_pipeline', pipeline_pb2.Pipeline.SYNC, pipeline_run_id='run0'
)
pipeline = _test_pipeline('test_pipeline', pipeline_pb2.Pipeline.SYNC)
pipeline_id = pipeline.pipeline_info.id
pipeline_run_id = 'run1'
pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
node_example_gen = pipeline.nodes.add().pipeline_node
node_example_gen.node_info.id = 'ExampleGen'
@@ -176,47 +176,48 @@ def test_resume_pipeline(self, mock_snapshot):
# Error if attempt to resume the pipeline when there is no previous run.
with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
pipeline_ops.resume_pipeline(
m, pipeline, run_id='run0'
m, pipeline, pipeline_id=pipeline_id, run_id=pipeline_run_id
)
self.assertEqual(
status_lib.Code.NOT_FOUND, exception_context.exception.code
)

# Initiate a pipeline start.
pipeline_state_run0 = pipeline_ops.initiate_pipeline_start(m, pipeline)
pipeline_state_run1 = pipeline_ops.initiate_pipeline_start(m, pipeline)

# Error if attempt to resume the pipeline when the previous one is active.
run_id = 'run1'
pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run1'
with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
pipeline_ops.resume_pipeline(
m, pipeline, run_id='run0'
m, pipeline, pipeline_id=pipeline_id, run_id=pipeline_run_id
)
self.assertEqual(
status_lib.Code.FAILED_PRECONDITION, exception_context.exception.code
status_lib.Code.ALREADY_EXISTS, exception_context.exception.code
)

with pipeline_state_run0:
with pipeline_state_run1:
example_gen_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')
with pipeline_state_run0.node_state_update_context(
with pipeline_state_run1.node_state_update_context(
example_gen_node_uid
) as node_state:
node_state.update(pstate.NodeState.COMPLETE)
with pipeline_state_run0.node_state_update_context(
with pipeline_state_run1.node_state_update_context(
trainer_node_uid
) as node_state:
node_state.update(pstate.NodeState.FAILED)
pipeline_state_run0.set_pipeline_execution_state(
pipeline_state_run1.set_pipeline_execution_state(
metadata_store_pb2.Execution.COMPLETE
)
pipeline_state_run0.initiate_stop(
pipeline_state_run1.initiate_stop(
status_lib.Status(code=status_lib.Code.ABORTED)
)
# Only Trainer is marked to run since ExampleGen succeeded in previous
# run.
expected_pipeline = copy.deepcopy(pipeline)
partial_run_utils.set_base_pipeline_run_strategy(
expected_pipeline.runtime_spec.snapshot_settings, 'run0',
partial_run_utils.set_latest_pipeline_run_strategy(
expected_pipeline.runtime_spec.snapshot_settings
)
expected_pipeline.nodes[
0
Expand All @@ -230,10 +231,10 @@ def test_resume_pipeline(self, mock_snapshot):
1
].pipeline_node.execution_options.run.depends_on_snapshot = True
with pipeline_ops.resume_pipeline(
m, pipeline, run_id='run0'
) as pipeline_state_run1:
self.assertEqual(expected_pipeline, pipeline_state_run1.pipeline)
self.assertTrue(pipeline_state_run1.is_active())
m, pipeline, pipeline_id=pipeline_id, run_id=run_id
) as pipeline_state_run2:
self.assertEqual(expected_pipeline, pipeline_state_run2.pipeline)
pipeline_state_run2.is_active()
mock_snapshot.assert_called_once()

@mock.patch.object(partial_run_utils, 'snapshot')
@@ -243,6 +244,7 @@ def test_resume_pipeline_when_concurrent_pipeline_runs_enabled(
with test_utils.concurrent_pipeline_runs_enabled_env():
with self._mlmd_connection as m:
pipeline = _test_pipeline('test_pipeline', pipeline_pb2.Pipeline.SYNC)
pipeline_id = pipeline.pipeline_info.id
pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
node_example_gen = pipeline.nodes.add().pipeline_node
node_example_gen.node_info.id = 'ExampleGen'
@@ -307,6 +309,7 @@ def test_resume_pipeline_when_concurrent_pipeline_runs_enabled(
pipeline_ops.resume_pipeline(
m,
pipeline,
pipeline_id=pipeline_id,
)
self.assertEqual(
status_lib.Code.INVALID_ARGUMENT, exception_context.exception.code
Expand All @@ -315,7 +318,7 @@ def test_resume_pipeline_when_concurrent_pipeline_runs_enabled(
# Success if pipeline resumed with run id.
self.assertEqual('run0', pipeline_uid.pipeline_run_id)
with pipeline_ops.resume_pipeline(
m, pipeline, run_id='run0'
m, pipeline, pipeline_id=pipeline_id, run_id='run0'
) as pipeline_state:
pipeline_state.is_active()
mock_snapshot.assert_called_once()
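
The expectation change in the test above swaps the base-run strategy for the latest-run strategy. Side by side, using the two partial_run_utils helpers that appear in this diff:

snapshot_settings = pipeline_pb2.SnapshotSettings()

# Base-run strategy: reuse artifacts from an explicitly named previous run.
partial_run_utils.set_base_pipeline_run_strategy(snapshot_settings, 'run0')

# Latest-run strategy: reuse artifacts from the most recent applicable run,
# with no run id supplied by the caller.
partial_run_utils.set_latest_pipeline_run_strategy(snapshot_settings)
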
20 changes: 18 additions & 2 deletions tfx/orchestration/experimental/core/pipeline_state.py
@@ -1237,9 +1237,14 @@ def load(cls,
executions = mlmd_handle.store.get_executions_by_context(
context.id, list_options=list_options, **kwargs
)
if non_active_only:
executions = [
e for e in executions if not execution_lib.is_execution_active(e)
]

if pipeline_run_id is None and executions:
return cls(pipeline_id, context, executions[0])
execution = _get_latest_execution(executions)
return cls(pipeline_id, context, execution)

for execution in executions:
if execution.custom_properties[
Expand All @@ -1249,7 +1254,7 @@ def load(cls,
raise status_lib.StatusNotOkError(
code=status_lib.Code.NOT_FOUND,
message=(
f'No {non_active_msg} pipeline with run_id {pipeline_run_id} found.'
f'No {non_active_msg}pipeline with run_id {pipeline_run_id} found.'
),
)

@@ -1497,6 +1502,17 @@ def _get_pipeline_from_orchestrator_execution(
return _PipelineIRCodec.get().decode(pipeline_ir)


def _get_latest_execution(
executions: List[metadata_store_pb2.Execution]
) -> metadata_store_pb2.Execution:
"""gets a single latest execution from the executions."""

def _get_creation_time(execution):
return execution.create_time_since_epoch

return max(executions, key=_get_creation_time)


def _get_orchestrator_context(mlmd_handle: metadata.Metadata, pipeline_id: str,
**kwargs) -> metadata_store_pb2.Context:
"""Returns the orchestrator context of a particular pipeline."""
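
The new load() path filters to non-active executions when requested and then takes the newest by creation time. A self-contained sketch of that selection, mirroring load() plus _get_latest_execution above (Execution here is a stand-in dataclass, not the MLMD proto):

from dataclasses import dataclass

@dataclass
class Execution:
  create_time_since_epoch: int
  active: bool

def pick_execution(executions, non_active_only):
  # Optionally drop still-active runs, then return the most recently
  # created execution.
  if non_active_only:
    executions = [e for e in executions if not e.active]
  return max(executions, key=lambda e: e.create_time_since_epoch)

runs = [Execution(100, False), Execution(200, True), Execution(150, False)]
assert pick_execution(runs, non_active_only=True).create_time_since_epoch == 150
assert pick_execution(runs, non_active_only=False).create_time_since_epoch == 200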