Commit 2cb6d20

wip: support serializing sample

wjones127 committed Nov 29, 2023
1 parent 4f0fbd1 commit 2cb6d20

Showing 4 changed files with 445 additions and 14 deletions.
5 changes: 4 additions & 1 deletion python/python/lance/lance/__init__.pyi
@@ -1,4 +1,4 @@
-from typing import Any, Iterator, List, Optional, Union
+from typing import IO, Any, Iterator, List, Optional, Union

import pyarrow as pa

@@ -30,3 +30,6 @@ class DatasetSample:
    def __len__(self) -> int: ...
    def __iter__(self) -> Iterator[pa.Array]: ...
    def __getitem__(self, item: Any) -> Union[pa.Array, DatasetSample]: ...
+    def serialize_into(self, path_or_file: Union[str, IO[bytes]]) -> None: ...
+    @staticmethod
+    def deserialize_from(path_or_file: Union[str, IO[bytes]]) -> DatasetSample: ...
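
The new methods accept either a filesystem path or a binary file object. A
minimal usage sketch, assuming only the stub signatures above and the
round-trip behavior exercised in the tests below (the dataset path is
illustrative):

import lance
from lance.lance import DatasetSample
from lance.sampler import build_shuffle_sample

dataset = lance.dataset("data.lance")  # assumes an existing Lance dataset
sample = build_shuffle_sample(dataset)

# Serialize to a path (the tests below use a .tar.gz suffix)...
sample.serialize_into("sample.tar.gz")

# ...or into any writable binary-mode file object.
with open("sample.tar.gz", "wb") as f:
    sample.serialize_into(f)

# Deserialize from a path or a readable binary file object.
restored = DatasetSample.deserialize_from("sample.tar.gz")
assert restored == sample
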
49 changes: 46 additions & 3 deletions python/python/lance/sampler.py
@@ -238,11 +238,54 @@ class SampleMetrics(NamedTuple):
def build_shuffle_sample(
    dataset: lance.LanceDataset,
    params: Optional[SampleParams] = None,
-    **kwargs,
+    *,
+    predicate: Optional[str] = None,
+    batch_size: int = 32,
+    shuffle: bool = True,
+    sample_rate: Optional[float] = None,
+    seed: Optional[int] = None,
) -> DatasetSample:
-    """Build a pre-computed sample from the dataset."""
+    """
+    Build a pre-computed sample from the dataset.
+
+    Parameters
+    ----------
+    dataset : lance.LanceDataset
+        The dataset to sample.
+    params : Optional[SampleParams], optional
+        The parameters to use for sampling, by default None. The parameter
+        object has the same fields as this function's keyword arguments. If it
+        is provided, the keyword arguments are ignored.
+    predicate : Optional[str], optional
+        A SQL filter to apply to the dataset prior to sampling and shuffling.
+    batch_size : int, default 32
+        The maximum number of contiguous rows to read per batch. Smaller values
+        mean more randomization, but also more IO overhead.
+    shuffle : bool, default True
+        Whether to shuffle the batches after sampling to randomize the order.
+        This is useful for training if the order of rows isn't already random.
+    sample_rate : Optional[float], optional
+        The fraction of rows to sample, between 0 and 1. If None, all rows are
+        sampled. The rate is applied *after* the predicate, so fewer rows may
+        be returned if the predicate filters the dataset below the requested
+        count.
+    seed : Optional[int], optional
+        The random seed to use for sampling. Set this to ensure the same
+        sample is generated across multiple runs.
+
+    Returns
+    -------
+    DatasetSample
+        A pre-computed sample of the dataset.
+    """
    if params is None:
-        params = SampleParams(**kwargs)
+        params = SampleParams(
+            predicate=predicate,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            sample_rate=sample_rate,
+            seed=seed,
+        )

    # Implemented in Rust: python/src/sampler.rs
    return _build_shuffle_sample(dataset._ds, params)
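
A usage sketch of the new signature (values are illustrative; as the docstring
notes, an explicit SampleParams takes precedence over the keyword arguments):

import lance
from lance.sampler import SampleParams, build_shuffle_sample

dataset = lance.dataset("data.lance")  # assumes an existing Lance dataset

# Keyword-argument form.
sample = build_shuffle_sample(
    dataset,
    predicate="id > 20",
    batch_size=1024,
    shuffle=True,
    sample_rate=0.5,
    seed=42,
)

# Equivalent explicit-params form; if `params` is given, the keyword
# arguments are ignored.
params = SampleParams(
    predicate="id > 20",
    batch_size=1024,
    shuffle=True,
    sample_rate=0.5,
    seed=42,
)
sample = build_shuffle_sample(dataset, params)
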
142 changes: 140 additions & 2 deletions python/python/tests/test_sampler.py
@@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import dataclasses
+from itertools import product
from pathlib import Path
+from typing import Optional

-import lance
import numpy as np
import pyarrow as pa
-from lance.sampler import maybe_sample
+import pytest
+
+import lance
+from lance.dataset import LanceDataset
+from lance.lance import DatasetSample
+from lance.sampler import SampleParams, build_shuffle_sample, maybe_sample


def test_sample_dataset(tmp_path: Path):
@@ -42,3 +49,134 @@ def test_sample_dataset(tmp_path: Path):
    assert isinstance(large_scan[0], pa.RecordBatch)
    assert large_scan[0].schema == pa.schema([pa.field("vec", fsl.type)])
    assert large_scan[0].num_rows == 128


@pytest.fixture(scope="module")
def readonly_dataset(tmpdir_factory):
    tmp_path = Path(tmpdir_factory.mktemp("data"))

    nrows = 10240
    ndims = 32

    ids = pa.array(np.arange(nrows))
    data = np.random.random(nrows * ndims).astype("f")

    fsl = pa.FixedSizeListArray.from_arrays(data, ndims)
    tbl = pa.Table.from_arrays([ids, fsl], ["id", "vec"])

    return lance.write_dataset(tbl, tmp_path / "data.lance")


def test_sample_params(readonly_dataset: LanceDataset):
    params = SampleParams(
        "id > 20",
        batch_size=1024,
        shuffle=True,
        sample_rate=0.5,
        seed=42,
    )

    sample = build_shuffle_sample(readonly_dataset, params)
    assert sample.params == params

    sample = build_shuffle_sample(readonly_dataset, **dataclasses.asdict(params))
    assert sample.params == params

    assert repr(sample.params) in repr(sample)


def test_sample_num_rows(readonly_dataset: LanceDataset):
    sample = build_shuffle_sample(readonly_dataset)
    assert sample.num_rows == len(readonly_dataset)
    assert sample.metrics.dataset_size == len(readonly_dataset)
    assert sample.metrics.matched_rows == len(readonly_dataset)
    assert sample.metrics.sampled_rows == len(readonly_dataset)

    sample = build_shuffle_sample(
        readonly_dataset, predicate="id >= 20", batch_size=1024
    )
    assert sample.num_rows == len(readonly_dataset) - 20
    assert sample.metrics.dataset_size == len(readonly_dataset)
    assert sample.metrics.matched_rows == len(readonly_dataset) - 20
    assert sample.metrics.sampled_rows == len(readonly_dataset) - 20

    for i in range(len(sample)):
        indices = sample[i]
        assert len(indices) <= 1024
        assert all(20 <= idx.as_py() < len(readonly_dataset) for idx in indices)

    sample_sliced = sample[1:]
    assert len(sample_sliced) == len(sample) - 1
    assert sample_sliced.num_rows == len(readonly_dataset) - 20 - len(sample[0])


@pytest.mark.parametrize("batch_size", [32, 1024])
@pytest.mark.parametrize("sample_rate", [None, 0.5, 1.0])
@pytest.mark.parametrize("shuffle", [True, False])
def test_shuffle_sample_slice(
    batch_size: int,
    sample_rate: Optional[float],
    shuffle: bool,
    readonly_dataset: LanceDataset,
):
    params = SampleParams(
        "id > 20",
        batch_size=batch_size,
        shuffle=shuffle,
        sample_rate=sample_rate,
        seed=42,
    )
    sample = build_shuffle_sample(readonly_dataset, params)

    materialized = list(iter(sample))
    assert len(materialized) == len(sample)
    assert sum(len(arr) for arr in materialized) == sample.num_rows

    # Check that materializing a slice of a sample gives the same result as
    # slicing the full materialized sample.
    starts = range(len(sample))
    lengths = [1, 4, len(sample)]
    steps = [None, 1, 2]
    for start, length, step in product(starts, lengths, steps):
        stop = start + length
        sample_sliced = sample[start:stop:step]
        materialized_sliced = list(iter(sample_sliced))

        assert materialized_sliced == materialized[start:stop:step]


def test_shuffle_sample_serialize(readonly_dataset: LanceDataset, tmp_path: Path):
    params = SampleParams(
        "id > 20",
        batch_size=1024,
        shuffle=True,
        sample_rate=0.5,
        seed=42,
    )
    sample = build_shuffle_sample(readonly_dataset, params)

    # Roundtrip with a file path.
    path = str(tmp_path / "sample.tar.gz")
    sample.serialize_into(path)
    sample_deserialized = DatasetSample.deserialize_from(path)
    assert sample == sample  # sanity check that equality is reflexive
    assert sample == sample_deserialized

    # Read with a file object.
    with open(path, "rb") as f:
        sample_deserialized = DatasetSample.deserialize_from(f)
    assert sample == sample_deserialized

    # Roundtrip with a file object.
    path = str(tmp_path / "sample2.tar.gz")
    with open(path, "wb") as f:
        sample.serialize_into(f)

    with open(path, "rb") as f:
        sample_deserialized = DatasetSample.deserialize_from(f)

    assert sample == sample_deserialized

    # Serializing into a text-mode file should fail.
    with pytest.raises(IOError):
        with open(path, "r") as f:
            sample.serialize_into(f)

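Putting the pieces together: per the __iter__ stub, each element of a sample is
a pyarrow array of row indices into the dataset, so a deserialized sample can
drive reads via LanceDataset.take. A hedged sketch; the loop body and column
name are illustrative, not part of this commit:

import lance
from lance.lance import DatasetSample

dataset = lance.dataset("data.lance")  # assumes an existing Lance dataset
sample = DatasetSample.deserialize_from("sample.tar.gz")

# Each batch is a pa.Array of row indices into the dataset.
for indices in sample:
    batch = dataset.take(indices.to_pylist(), columns=["vec"])
    ...  # hand `batch` off to a training loop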