Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev committed May 6, 2024
1 parent 5f56c4a commit ea05389
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
2 changes: 1 addition & 1 deletion modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3418,7 +3418,7 @@ def broadcast_apply_select_indices(
new_partitions, index=new_index, columns=new_columns
)

def construct_dtype(dtype: str, backend: Optional[str]):
def construct_dtype(self, dtype: str, backend: Optional[str]):
if backend is None:
return pandas.api.types.pandas_dtype(dtype)
elif backend == "pyarrow":
Expand Down
14 changes: 13 additions & 1 deletion modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import abc
import warnings
from typing import Callable, Hashable, List, Optional
from typing import TYPE_CHECKING, Callable, Hashable, List, Optional

import numpy as np
import pandas
Expand Down Expand Up @@ -52,6 +52,10 @@

from . import doc_utils

if TYPE_CHECKING:
# TODO: should be ModinDataframe
from modin.core.dataframe.pandas.dataframe.dataframe import PandasDataframe


def _get_axis(axis):
"""
Expand Down Expand Up @@ -126,6 +130,8 @@ class BaseQueryCompiler(
for a list of requirements for subclassing this object.
"""

_modin_frame: PandasDataframe

def __wrap_in_qc(self, obj):
"""
Wrap `obj` in query compiler.
Expand Down Expand Up @@ -6747,6 +6753,12 @@ def case_when(self, caselist): # noqa: PR01, RT01, D200
]
return SeriesDefault.register(pandas.Series.case_when)(self, caselist=caselist)

def construct_dtype(self, dtype: str, backend: Optional[str]):
return self._modin_frame.construct_dtype(dtype, backend)

def get_backend(self) -> str:
return self._modin_frame._pandas_backend

def repartition(self, axis=None):
"""
Repartitioning QueryCompiler objects to get ideal partitions inside.
Expand Down
6 changes: 4 additions & 2 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,8 @@ def prod(
skipna is not False
and numeric_only is False
and min_count > len(axis_to_apply)
# Type inference is not so simple for pyarrow
and self._query_compiler.get_backend() == "default"
):
new_index = self.columns if not axis else self.index
# >>> pd.DataFrame([1,2,3,4], dtype="int64[pyarrow]").prod(min_count=10)
Expand All @@ -1630,7 +1632,6 @@ def prod(
return Series(
[np.nan] * len(new_index),
index=new_index,
# TODO: pyarrow backend?
dtype=pandas.api.types.pandas_dtype("float64"),
)

Expand Down Expand Up @@ -2151,12 +2152,13 @@ def sum(
skipna is not False
and numeric_only is False
and min_count > len(axis_to_apply)
# Type inference is not so simple for pyarrow
and self._query_compiler.get_backend() == "default"
):
new_index = self.columns if not axis else self.index
return Series(
[np.nan] * len(new_index),
index=new_index,
# TODO: pyarrow backend?
dtype=pandas.api.types.pandas_dtype("float64"),
)

Expand Down

0 comments on commit ea05389

Please sign in to comment.