
Commit

Move to TF iterators
dirmeier committed Feb 26, 2024
1 parent f4f43d1 commit 7522789
Showing 2 changed files with 29 additions and 38 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -27,6 +27,8 @@ dependencies = [
     "optax>=0.1.3",
     "surjectors>=0.3.0",
     "tfp-nightly>=0.20.0.dev20230404",
+    "tensorflow==2.15.0",
+    "tensorflow-datasets==4.9.3",
     "tqdm>=4.64.1"
 ]
 dynamic = ["version"]
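
As a quick, illustrative aside (not part of the commit), the exact pins can be checked against the installed packages at runtime:

    import tensorflow as tf
    import tensorflow_datasets as tfds

    # versions pinned in pyproject.toml by this commit
    assert tf.__version__ == "2.15.0"
    assert tfds.__version__ == "4.9.3"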
65 changes: 27 additions & 38 deletions sbijax/_src/generator.py
@@ -1,7 +1,7 @@
 from collections import namedtuple

-import chex
-from jax import lax
+import tensorflow as tf
+from jax import Array
 from jax import numpy as jnp
 from jax import random as jr
@@ -34,20 +34,20 @@ def __call__(self, idx, idxs=None): # noqa: D102

 # pylint: disable=missing-function-docstring
 def as_batch_iterators(
-    rng_key: chex.PRNGKey, data: named_dataset, batch_size, split, shuffle
+    rng_key: Array, data: named_dataset, batch_size, split, shuffle
 ):
     """Create two data batch iterators from a data set.

     Args:
-        rng_key: random key
-        data: a named tuple containing all dat
-        batch_size: batch size
+        rng_key: a jax random key
+        data: a named tuple with elements 'y' and 'theta' containing all data
+        batch_size: size of each batch
         split: fraction of data to use for the training data set; the rest
             is used for the validation data set.
         shuffle: whether to shuffle the data set

     Returns:
-        two iterators
+        two iterators, one for training and one for validation
     """
     n = data.y.shape[0]
     n_train = int(n * split)
@@ -67,41 +67,30 @@


 # pylint: disable=missing-function-docstring
-def as_batch_iterator(
-    rng_key: chex.PRNGKey, data: named_dataset, batch_size, shuffle
-):
+def as_batch_iterator(rng_key: Array, data: named_dataset, batch_size, shuffle):
     """Create a data batch iterator from a data set.

     Args:
-        rng_key: random key
-        data: a named tuple containing all dat
-        batch_size: batch size
+        rng_key: a jax random key
+        data: a named tuple with elements 'y' and 'theta' containing all data
+        batch_size: size of each batch
         shuffle: whether to shuffle the data set

     Returns:
-        an iterator
+        a tensorflow iterator
     """
-    n = data.y.shape[0]
-    if n < batch_size:
-        num_batches = 1
-        batch_size = n
-    elif n % batch_size == 0:
-        num_batches = int(n // batch_size)
-    else:
-        num_batches = int(n // batch_size) + 1
-
-    idxs = jnp.arange(n)
-    if shuffle:
-        idxs = jr.permutation(rng_key, idxs)
-
-    def get_batch(idx, idxs=idxs):
-        start_idx = idx * batch_size
-        step_size = jnp.minimum(n - start_idx, batch_size)
-        ret_idx = lax.dynamic_slice_in_dim(idxs, idx * batch_size, step_size)
-        batch = {
-            name: lax.index_take(array, (ret_idx,), axes=(0,))
-            for name, array in zip(data._fields, data)
-        }
-        return batch
-
-    return DataLoader(num_batches, idxs, get_batch)
+    # hack, cause the tf stuff doesn't support jax keys :)
+    max_int32 = jnp.iinfo(jnp.int32).max
+    seed = jr.randint(rng_key, shape=(), minval=0, maxval=max_int32)
+    itr = tf.data.Dataset.from_tensor_slices(data)
+    itr = (
+        itr.shuffle(
+            10 * batch_size,
+            seed=int(seed),
+            reshuffle_each_iteration=shuffle,
+        )
+        .batch(batch_size)
+        .prefetch(buffer_size=batch_size)
+        .as_numpy_iterator()
+    )
+    return itr
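
For orientation, a minimal usage sketch of the new iterator (not part of the commit; the toy shapes are assumptions, and named_dataset mirrors the module's namedtuple with fields 'y' and 'theta'):

    from collections import namedtuple

    from jax import numpy as jnp
    from jax import random as jr

    from sbijax._src.generator import as_batch_iterator

    # mirrors the namedtuple defined in sbijax/_src/generator.py
    named_dataset = namedtuple("named_dataset", "y theta")

    # toy data: 100 draws of a 2-dimensional y and a 3-dimensional theta
    data = named_dataset(y=jnp.ones((100, 2)), theta=jnp.zeros((100, 3)))
    itr = as_batch_iterator(jr.PRNGKey(0), data, batch_size=32, shuffle=True)

    # each batch is a namedtuple of numpy arrays with leading dimension <= 32
    for batch in itr:
        y, theta = batch.y, batch.theta

Note that .as_numpy_iterator() yields a single-pass Python iterator, so a training loop that runs several epochs has to rebuild it each epoch.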
