Skip to content

Commit

Permalink
Lazy loading CSV files (or tables in general)
Browse files Browse the repository at this point in the history
Signed-off-by: Bram Stoeller <[email protected]>
  • Loading branch information
bramstoeller committed Mar 1, 2023
1 parent 0351485 commit f7f86f6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
17 changes: 12 additions & 5 deletions src/power_grid_model_io/data_stores/csv_dir_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List

import pandas as pd

Expand All @@ -26,17 +26,24 @@ class CsvDirStore(BaseDataStore[TabularData]):

def __init__(self, dir_path: Path, **csv_kwargs):
super().__init__()
self._dir_path = dir_path
self._dir_path = Path(dir_path)
self._csv_kwargs: Dict[str, Any] = csv_kwargs
self._header_rows: List[int] = [0]

def load(self) -> TabularData:
"""
Load all CSV files in a directory as tabular data.
Create a lazy loader for all CSV files in a directory and store them in a TabularData instance.
"""
data: Dict[str, pd.DataFrame] = {}

def lazy_csv_loader(csv_path: Path) -> Callable[[], pd.DataFrame]:
def csv_loader():
return pd.read_csv(filepath_or_buffer=csv_path, header=self._header_rows, **self._csv_kwargs)

return csv_loader

data: Dict[str, Callable[[], pd.DataFrame]] = {}
for path in self._dir_path.glob("*.csv"):
data[path.stem] = pd.read_csv(filepath_or_buffer=path, header=self._header_rows, **self._csv_kwargs)
data[path.stem] = lazy_csv_loader(path)

return TabularData(**data)

Expand Down
21 changes: 13 additions & 8 deletions src/power_grid_model_io/data_types/tabular_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
which supports unit conversions and value substitutions
"""

from typing import Dict, Iterable, Optional, Tuple, Union
from typing import Callable, Dict, Generator, Iterable, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -15,14 +15,16 @@
from power_grid_model_io.mappings.unit_mapping import UnitMapping
from power_grid_model_io.mappings.value_mapping import ValueMapping

LazyDataFrame = Callable[[], pd.DataFrame]


class TabularData:
"""
The TabularData class is a wrapper around Dict[str, Union[pd.DataFrame, np.ndarray]],
which supports unit conversions and value substitutions
"""

def __init__(self, **tables: Union[pd.DataFrame, np.ndarray]):
def __init__(self, **tables: Union[pd.DataFrame, np.ndarray, LazyDataFrame]):
"""
Tabular data can either be a collection of pandas DataFrames and/or numpy structured arrays.
The keyword arguments will define the keys of the data.
Expand All @@ -34,12 +36,12 @@ def __init__(self, **tables: Union[pd.DataFrame, np.ndarray]):
**tables: A collection of pandas DataFrames and/or numpy structured arrays
"""
for table_name, table_data in tables.items():
if not isinstance(table_data, (pd.DataFrame, np.ndarray)):
if not isinstance(table_data, (pd.DataFrame, np.ndarray)) and not callable(table_data):
raise TypeError(
f"Invalid data type for table '{table_name}'; "
f"expected a pandas DataFrame or NumPy array, got {type(table_data).__name__}."
)
self._data: Dict[str, Union[pd.DataFrame, np.ndarray]] = tables
self._data: Dict[str, Union[pd.DataFrame, np.ndarray, LazyDataFrame]] = tables
self._units: Optional[UnitMapping] = None
self._substitution: Optional[ValueMapping] = None
self._log = structlog.get_logger(type(self).__name__)
Expand Down Expand Up @@ -73,7 +75,7 @@ def get_column(self, table_name: str, column_name: str) -> pd.Series:
Returns:
The required column, with unit conversions and value substitutions applied
"""
table_data = self._data[table_name]
table_data = self[table_name]

# If the index 'column' is requested, but no column called 'index' exist,
# return the index of the dataframe as if it were an actual column.
Expand Down Expand Up @@ -176,6 +178,8 @@ def __getitem__(self, table_name: str) -> Union[pd.DataFrame, np.ndarray]:
Returns: The 'raw' table data
"""
if callable(self._data[table_name]):
self._data[table_name] = self._data[table_name]()
return self._data[table_name]

def keys(self) -> Iterable[str]:
Expand All @@ -187,13 +191,14 @@ def keys(self) -> Iterable[str]:

return self._data.keys()

def items(self) -> Iterable[Tuple[str, Union[pd.DataFrame, np.ndarray]]]:
def items(self) -> Generator[Tuple[str, Union[pd.DataFrame, np.ndarray]], None, None]:
"""
Mimic the dictionary .items() function
Returns: An iterator over the table names and the raw table data
Returns: A generator of the table names and the raw table data
"""

# Note: PyCharm complains about the type, but it is correct, as an ItemsView extends from
# AbstractSet[Tuple[_KT_co, _VT_co]], which actually is compatible with Iterable[_KT_co, _VT_co]
return self._data.items()
for key in self._data:
yield key, self[key]

0 comments on commit f7f86f6

Please sign in to comment.