Skip to content

Commit

Permalink
FEAT-#6990: Implement lazy execution for the Ray virtual partitions.
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreyPavlenko committed Mar 16, 2024
1 parent c753436 commit 2e3390b
Show file tree
Hide file tree
Showing 7 changed files with 560 additions and 318 deletions.
194 changes: 154 additions & 40 deletions modin/core/execution/ray/common/deferred_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@

from modin.core.execution.ray.common import MaterializationHook, RayWrapper
from modin.logging import get_logger
from modin.utils import _inherit_docstrings

ObjectRefType = Union[ray.ObjectRef, ClientObjectRef, None]
ObjectRefType = Union[ray.ObjectRef, ClientObjectRef]
ObjectRefOrListType = Union[ObjectRefType, List[ObjectRefType]]
ListOrTuple = (list, tuple)

Expand Down Expand Up @@ -68,16 +69,18 @@ class DeferredExecution:
Attributes
----------
data : ObjectRefType or DeferredExecution
data : object
The execution input.
func : callable or ObjectRefType
A function to be executed.
args : list or tuple
args : list or tuple, optional
Additional positional arguments to be passed in `func`.
kwargs : dict
kwargs : dict, optional
Additional keyword arguments to be passed in `func`.
num_returns : int
num_returns : int, default: 1
The number of the return values.
flat_data : bool
True means that the data is neither DeferredExecution nor list.
flat_args : bool
True means that there are no lists or DeferredExecution objects in `args`.
In this case, no arguments processing is performed and `args` is passed
Expand All @@ -88,26 +91,29 @@ class DeferredExecution:

def __init__(
self,
data: Union[
ObjectRefType,
"DeferredExecution",
List[Union[ObjectRefType, "DeferredExecution"]],
],
data: Any,
func: Union[Callable, ObjectRefType],
args: Union[List[Any], Tuple[Any]],
kwargs: Dict[str, Any],
args: Union[List[Any], Tuple[Any]] = None,
kwargs: Dict[str, Any] = None,
num_returns=1,
):
if isinstance(data, DeferredExecution):
data.subscribe()
self.flat_data = self._flat_args((data,))
self.data = data
self.func = func
self.args = args
self.kwargs = kwargs
self.num_returns = num_returns
self.flat_args = self._flat_args(args)
self.flat_kwargs = self._flat_args(kwargs.values())
self.subscribers = 0
if args is not None:
self.args = args
self.flat_args = self._flat_args(args)
else:
self.args = ()
self.flat_args = True
if kwargs is not None:
self.kwargs = kwargs
self.flat_kwargs = self._flat_args(kwargs.values())
else:
self.kwargs = {}
self.flat_kwargs = True

@classmethod
def _flat_args(cls, args: Iterable):
Expand All @@ -134,7 +140,7 @@ def _flat_args(cls, args: Iterable):

def exec(
self,
) -> Tuple[ObjectRefOrListType, Union["MetaList", List], Union[int, List[int]]]:
) -> Tuple[ObjectRefOrListType, "MetaList", Union[int, List[int]]]:
"""
Execute this task, if required.
Expand All @@ -150,11 +156,29 @@ def exec(
return self.data, self.meta, self.meta_offset

if (
not isinstance(self.data, DeferredExecution)
self.flat_data
and self.flat_args
and self.flat_kwargs
and self.num_returns == 1
):
# self.data = RayWrapper.materialize(self.data)
# self.args = [
# RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for o in self.args
# ]
# self.kwargs = {
# k: RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for k, o in self.kwargs.items()
# }
# obj = _REMOTE_EXEC.exec_func(
# RayWrapper.materialize(self.func), self.data, self.args, self.kwargs
# )
# result, length, width, ip = (
# obj,
# len(obj) if hasattr(obj, "__len__") else 0,
# len(obj.columns) if hasattr(obj, "columns") else 0,
# "",
# )
result, length, width, ip = remote_exec_func.remote(
self.func, self.data, *self.args, **self.kwargs
)
Expand All @@ -166,19 +190,28 @@ def exec(
# it back. After the execution, the result is saved and the counter has no effect.
self.subscribers += 2
consumers, output = self._deconstruct()

# assert not any(isinstance(o, ListOrTuple) for o in output)
# tmp = [
# RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for o in output
# ]
# list(_REMOTE_EXEC.construct(tmp))

# The last result is the MetaList, so adding +1 here.
num_returns = sum(c.num_returns for c in consumers) + 1
results = self._remote_exec_chain(num_returns, *output)
meta = MetaList(results.pop())
meta_offset = 0
results = iter(results)
for de in consumers:
if de.num_returns == 1:
num_returns = de.num_returns
if num_returns == 1:
de._set_result(next(results), meta, meta_offset)
meta_offset += 2
else:
res = list(islice(results, num_returns))
offsets = list(range(0, 2 * num_returns, 2))
offsets = list(range(meta_offset, meta_offset + 2 * num_returns, 2))
de._set_result(res, meta, offsets)
meta_offset += 2 * num_returns
return self.data, self.meta, self.meta_offset
Expand Down Expand Up @@ -318,6 +351,7 @@ def _deconstruct_chain(
break
elif not isinstance(data := de.data, DeferredExecution):
if isinstance(data, ListOrTuple):
out_append(_Tag.LIST)
yield cls._deconstruct_list(
data, output, stack, result_consumers, out_append
)
Expand Down Expand Up @@ -394,7 +428,13 @@ def _deconstruct_list(
if out_pos := getattr(obj, "out_pos", None):
obj.unsubscribe()
if obj.has_result:
out_append(obj.data)
if isinstance(obj.data, ListOrTuple):
out_append(_Tag.LIST)
yield cls._deconstruct_list(

Check warning on line 433 in modin/core/execution/ray/common/deferred_execution.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/deferred_execution.py#L432-L433

Added lines #L432 - L433 were not covered by tests
obj.data, output, stack, result_consumers, out_append
)
else:
out_append(obj.data)
else:
out_append(_Tag.REF)
out_append(out_pos)
Expand Down Expand Up @@ -432,13 +472,13 @@ def _remote_exec_chain(num_returns: int, *args: Tuple) -> List[Any]:
list
The execution results. The last element of this list is the ``MetaList``.
"""
# Prefer _remote_exec_single_chain(). It has fewer arguments and
# does not require the num_returns to be specified in options.
# Prefer _remote_exec_single_chain(). It does not require the num_returns
# to be specified in options.
if num_returns == 2:
return _remote_exec_single_chain.remote(*args)
else:
return _remote_exec_multi_chain.options(num_returns=num_returns).remote(
num_returns, *args
*args
)

def _set_result(
Expand All @@ -456,7 +496,7 @@ def _set_result(
meta : MetaList
meta_offset : int or list of int
"""
del self.func, self.args, self.kwargs, self.flat_args, self.flat_kwargs
del self.func, self.args, self.kwargs
self.data = result
self.meta = meta
self.meta_offset = meta_offset
Expand All @@ -466,6 +506,78 @@ def __reduce__(self):
raise NotImplementedError("DeferredExecution is not serializable!")


ObjectRefOrDeType = Union[ObjectRefType, DeferredExecution]


class DeferredGetItem(DeferredExecution):
"""
Deferred execution task that returns an item at the specified index.
Parameters
----------
data : ObjectRefOrDeType
The object to get the item from.
idx : int
The item index.
"""

def __init__(self, data: ObjectRefOrDeType, idx: int):
super().__init__(data, self._remote_fn(), [idx])
self.index = idx

@_inherit_docstrings(DeferredExecution.exec)
def exec(self) -> Tuple[ObjectRefType, "MetaList", int]:
if self.has_result:
return self.data, self.meta, self.meta_offset

if not isinstance(self.data, DeferredExecution) or self.data.num_returns == 1:
return super().exec()

Check warning on line 534 in modin/core/execution/ray/common/deferred_execution.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/deferred_execution.py#L534

Added line #L534 was not covered by tests

# If `data` is a `DeferredExecution`, that returns multiple results,
# it's not required to execute `_remote_fn()`. We can only execute
# `data` and get the result by index.
self._data_exec()
return self.data, self.meta, self.meta_offset

@property
@_inherit_docstrings(DeferredExecution.has_result)
def has_result(self):
if super().has_result:
return True

if (
isinstance(self.data, DeferredExecution)
and self.data.has_result
and self.data.num_returns != 1
):
self._data_exec()
return True

return False

def _data_exec(self):
"""Execute the `data` task and get the result."""
obj, meta, offsets = self.data.exec()
self._set_result(obj[self.index], meta, offsets[self.index])

@classmethod
def _remote_fn(cls) -> ObjectRefType:
"""
Return the remote function reference.
Returns
-------
ObjectRefType
"""
if (fn := getattr(cls, "_GET_ITEM", None)) is None:

def get_item(obj, index): # pragma: no cover
return obj[index]

cls._GET_ITEM = fn = RayWrapper.put(get_item)
return fn


class MetaList:
"""
Meta information, containing the result lengths and the worker address.
Expand All @@ -478,6 +590,10 @@ class MetaList:
def __init__(self, obj: Union[ray.ObjectID, ClientObjectRef, List]):
self._obj = obj

def materialize(self):
"""Materialized the list, if required."""
self._obj = RayWrapper.materialize(self._obj)

Check warning on line 595 in modin/core/execution/ray/common/deferred_execution.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/deferred_execution.py#L595

Added line #L595 was not covered by tests

def __getitem__(self, index):
"""
Get item at the specified index.
Expand Down Expand Up @@ -508,7 +624,7 @@ def __setitem__(self, index, value):
obj[index] = value


class MetaListHook(MaterializationHook):
class MetaListHook(MaterializationHook, DeferredGetItem):
"""
Used by MetaList.__getitem__() for lazy materialization and getting a single value from the list.
Expand All @@ -521,6 +637,7 @@ class MetaListHook(MaterializationHook):
"""

def __init__(self, meta: MetaList, idx: int):
super().__init__(meta._obj, idx)
self.meta = meta
self.idx = idx

Expand Down Expand Up @@ -605,7 +722,7 @@ def exec_func(fn: Callable, obj: Any, args: Tuple, kwargs: Dict) -> Any:
raise err

@classmethod
def construct(cls, num_returns: int, args: Tuple): # pragma: no cover
def construct(cls, args: Tuple): # pragma: no cover
"""
Construct and execute the specified chain.
Expand All @@ -615,7 +732,6 @@ def construct(cls, num_returns: int, args: Tuple): # pragma: no cover
Parameters
----------
num_returns : int
args : tuple
Yields
Expand Down Expand Up @@ -687,7 +803,7 @@ def construct_chain(

while chain:
fn = pop()
if fn == tg_e:
if fn is tg_e:
lst.append(obj)
break

Expand Down Expand Up @@ -717,10 +833,10 @@ def construct_chain(

itr = iter([obj] if num_returns == 1 else obj)
for _ in range(num_returns):
obj = next(itr)
meta.append(len(obj) if hasattr(obj, "__len__") else 0)
meta.append(len(obj.columns) if hasattr(obj, "columns") else 0)
yield obj
o = next(itr)
meta.append(len(o) if hasattr(o, "__len__") else 0)
meta.append(len(o.columns) if hasattr(o, "columns") else 0)
yield o

@classmethod
def construct_list(
Expand Down Expand Up @@ -834,20 +950,18 @@ def _remote_exec_single_chain(
-------
Generator
"""
return remote_executor.construct(num_returns=2, args=args)
return remote_executor.construct(args=args)


@ray.remote
def _remote_exec_multi_chain(
num_returns: int, *args: Tuple, remote_executor=_REMOTE_EXEC
*args: Tuple, remote_executor=_REMOTE_EXEC
) -> Generator: # pragma: no cover
"""
Execute the deconstructed chain with a multiple return values in a worker process.
Parameters
----------
num_returns : int
The number of return values.
*args : tuple
A deconstructed chain to be executed.
remote_executor : _RemoteExecutor, default: _REMOTE_EXEC
Expand All @@ -857,4 +971,4 @@ def _remote_exec_multi_chain(
-------
Generator
"""
return remote_executor.construct(num_returns, args)
return remote_executor.construct(args)
4 changes: 2 additions & 2 deletions modin/core/execution/ray/common/engine_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import asyncio
import os
from types import FunctionType
from typing import Sequence
from typing import Iterable, Sequence

import ray
from ray.util.client.common import ClientObjectRef
Expand Down Expand Up @@ -214,7 +214,7 @@ def wait(cls, obj_ids, num_returns=None):
num_returns : int, optional
"""
if not isinstance(obj_ids, Sequence):
obj_ids = list(obj_ids)
obj_ids = list(obj_ids) if isinstance(obj_ids, Iterable) else [obj_ids]

ids = set()
for obj in obj_ids:
Expand Down
Loading

0 comments on commit 2e3390b

Please sign in to comment.