Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDK] Gen meta before update signatures #3195

Merged
merged 4 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Union

from promptflow._constants import SystemMetricKeys
from promptflow._proxy import ProxyFactory
from promptflow._sdk._constants import REMOTE_URI_PREFIX, ContextAttributeKey, FlowRunProperties
from promptflow._sdk.entities._flows import Flow, Prompty
from promptflow._sdk.entities._run import Run
Expand Down Expand Up @@ -108,7 +107,9 @@ def _run_bulk(self, run: Run, stream=False, **kwargs):
local_storage = LocalStorageOperations(run, stream=stream, run_mode=RunMode.Batch)
with local_storage.logger:
flow_obj = load_flow(source=run.flow)
with flow_overwrite_context(flow_obj, tuning_node, variant, connections=run.connections) as flow:
with flow_overwrite_context(
flow_obj, tuning_node, variant, connections=run.connections, init_kwargs=run.init
) as flow:
self._submit_bulk_run(flow=flow, run=run, local_storage=local_storage)

@classmethod
Expand All @@ -122,12 +123,6 @@ def _submit_bulk_run(
) -> dict:
logger.info(f"Submitting run {run.name}, log path: {local_storage.logger.file_path}")
run_id = run.name
# TODO: unify the logic for prompty and other flows
if not isinstance(flow, Prompty):
# variants are resolved in the context, so we can't move this logic to Operations for now
ProxyFactory().create_inspector_proxy(flow.language).prepare_metadata(
flow_file=Path(flow.path), working_dir=Path(flow.code), init_kwargs=run.init
)

with _change_working_dir(flow.code):
# resolve connections with environment variables overrides to avoid getting unused connections
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,7 @@ def target_node(self) -> Optional[str]:
return self._target_node

@contextlib.contextmanager
def _resolve_variant(self):
# TODO(2901096): validate invalid configs like variant & connections
# no variant overwrite for eager flow
# no connection overwrite for eager flow
def _resolve_variant(self, init_kwargs=None):
if self.flow_context.variant:
tuning_node, node_variant = parse_variant(self.flow_context.variant)
else:
Expand All @@ -150,6 +147,7 @@ def _resolve_variant(self):
variant=node_variant,
connections=self.flow_context.connections,
overrides=self.flow_context.overrides,
init_kwargs=init_kwargs,
) as temp_flow:
# TODO execute flow test in a separate process.

Expand Down Expand Up @@ -222,18 +220,9 @@ def init(
:return: TestSubmitter instance.
:rtype: TestSubmitter
"""
with self._resolve_variant():
with self._resolve_variant(init_kwargs=init_kwargs):
# temp flow is generated, will use self.flow instead of self._origin_flow in the following context
self._within_init_context = True

if not isinstance(self.flow, Prompty):
# variant is resolve in the context, so we can't move this to Operations for now
ProxyFactory().create_inspector_proxy(self.flow.language).prepare_metadata(
flow_file=self.flow.path,
working_dir=self.flow.code,
init_kwargs=init_kwargs,
)

self._target_node = target_node
self._enable_stream_output = stream_output

Expand Down
15 changes: 15 additions & 0 deletions src/promptflow-devkit/promptflow/_sdk/_orchestrator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,21 @@ def override_flow_yaml(
*,
overrides: dict = None,
drop_node_variants: bool = False,
init_kwargs: dict = None,
):
# generate meta before updating signatures since update signatures requires it.
if not isinstance(flow, Prompty):
D-W- marked this conversation as resolved.
Show resolved Hide resolved
ProxyFactory().create_inspector_proxy(flow.language).prepare_metadata(
flow_file=Path(flow.path), working_dir=Path(flow.code), init_kwargs=init_kwargs
)
if isinstance(flow, FlexFlow):
# update signatures for flex flow
# no variant overwrite for eager flow
for param in [tuning_node, variant, connections, overrides]:
if param:
logger.warning(
"Eager flow does not support tuning node, variant, connection override. " f"Dropping params {param}"
)
update_signatures(code=flow_dir_path, data=flow_dag)
else:
# always overwrite variant since we need to overwrite default variant if not specified.
Expand All @@ -233,6 +245,7 @@ def flow_overwrite_context(
*,
overrides: dict = None,
drop_node_variants: bool = False,
init_kwargs: dict = None,
):
"""Override variant and connections in the flow."""
flow_dag = flow._data
Expand All @@ -252,6 +265,7 @@ def flow_overwrite_context(
connections=connections,
overrides=overrides,
drop_node_variants=drop_node_variants,
init_kwargs=init_kwargs,
)
flow_dag.pop("additional_includes", None)
dump_flow_dag_according_to_content(flow_dag=flow_dag, flow_path=Path(temp_dir))
Expand All @@ -270,6 +284,7 @@ def flow_overwrite_context(
connections=connections,
overrides=overrides,
drop_node_variants=drop_node_variants,
init_kwargs=init_kwargs,
)
flow_path = dump_flow_dag_according_to_content(flow_dag=flow_dag, flow_path=Path(temp_dir))
if isinstance(flow, FlexFlow):
Expand Down