Skip to content

Commit

Permalink
Clean up typing of function signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Jul 7, 2023
1 parent 5a28b40 commit 69810f9
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions ludwig/data/dataframe/daft.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Callable, Dict, List
from typing import Callable, Dict, List, overload, Union

import daft
import pandas as pd
Expand All @@ -11,6 +11,10 @@
logger = logging.getLogger(__name__)


DataFrameMapFn = Callable[[pd.DataFrame], pd.DataFrame]
SeriesMapFn = Callable[[pd.Series], pd.Series]


class LudwigDaftDataframe:
"""Shim layer on top of a daft.Dataframe to make it behave more like a Pandas Dataframe object.
Expand Down Expand Up @@ -129,9 +133,22 @@ def map_objects(self, series: LudwigDaftSeries, map_fn: Callable[[object], objec
# can be much more optimized in terms of memory usage
return LudwigDaftSeries(series.expr.apply(map_fn, return_dtype=daft.DataType.python()))

def map_partitions(self, obj: LudwigDaftSeries, map_fn: Callable[[pd.Series], pd.Series], meta=None):
# NOTE: Although the function signature indicates that this function takes in a Series, in practice
# it appears that this function is often used interchangeably to run on both Series and DataFrames
# NOTE: Although the base class' function signature indicates that this function takes in a Series, in practice
# it appears that this function is often used interchangeably to run functions on both Series and DataFrames
@overload
def map_partitions(self, obj: LudwigDaftDataframe, map_fn: DataFrameMapFn):
...

@overload
def map_partitions(self, obj: LudwigDaftSeries, map_fn: SeriesMapFn, meta=None):
...

def map_partitions(
self,
obj: Union[LudwigDaftSeries, LudwigDaftDataframe],
map_fn: Union[DataFrameMapFn, SeriesMapFn],
meta=None,
):
if isinstance(obj, LudwigDaftDataframe):
raise NotImplementedError("TODO: Implementation")
elif isinstance(obj, LudwigDaftSeries):
Expand Down

0 comments on commit 69810f9

Please sign in to comment.