Skip to content

Commit

Permalink
add dataclass support to PyTree
Browse files Browse the repository at this point in the history
tests
  • Loading branch information
PhilipVinc committed Nov 19, 2023
1 parent 9fecbc4 commit f0297a5
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 4 deletions.
74 changes: 71 additions & 3 deletions netket/utils/struct/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from functools import partial

import dataclasses
import warnings
from dataclasses import MISSING

from flax import serialization
Expand All @@ -44,6 +45,7 @@

from .utils import _set_new_attribute, _create_fn, get_class_globals
from .fields import _cache_name, Uninitialized, field, CachedProperty
from .pytree import Pytree

try:
from dataclasses import _FIELDS
Expand Down Expand Up @@ -197,7 +199,7 @@ def purge_cache_fields(clz):


def attach_preprocess_init(
data_clz, *, globals=None, init_doc=MISSING, cache_hash=False
data_clz, *, globals=None, init_doc=MISSING, cache_hash=False, is_pytree=False
):
if globals is None:
globals = {}
Expand All @@ -211,7 +213,6 @@ def _preprocess_args_default(self, *args, **kwargs):
args, kwargs = getattr(super(data_clz, self), _PRE_INIT_NAME)(
*args, **kwargs
)

return args, kwargs

_set_new_attribute(data_clz, _PRE_INIT_NAME, _preprocess_args_default)
Expand All @@ -227,6 +228,10 @@ def _preprocess_args_default(self, *args, **kwargs):
body_lines = [
"if not __skip_preprocess:",
f"\targs, kwargs = {self_name}.{_PRE_INIT_NAME}(*args, **kwargs)",
"if True:" if is_pytree else "if False:",
"\t_args_pytree, _kwargs_pytree = kwargs['__base_init_args']",
"\tdel kwargs['__base_init_args']",
"\tsuper(data_class, self).__init__(*_args_pytree, **_kwargs_pytree)",
f"{self_name}.{_DATACLASS_INIT_NAME}(*args, **kwargs)",
"if __precompute_cached_properties:",
f"\t{self_name}.{PRECOMPUTE_CACHED_PROPERTY_NAME}()",
Expand All @@ -238,6 +243,8 @@ def _preprocess_args_default(self, *args, **kwargs):
f"BUILTINS.object.__setattr__({self_name},{_hash_cache_name(data_clz.__name__)!r},Uninitialized)"
)

globals["data_class"] = data_clz

fun = _create_fn(
"__init__",
[
Expand Down Expand Up @@ -293,6 +300,13 @@ def dataclass(clz=None, *, init_doc=MISSING, cache_hash=False, _frozen=True):
This behaves as a flax dataclass, that is a Frozen python dataclass, with a twist!
See their documentation for standard behaviour.
.. warning::
This decorator should not be used together with classes inheriting from
:ref:`netket.utils.struct.Pytree`. While simple cases will work
for now, it is not guaranteed that the behaviour will be always correct
and stable.
The new functionalities added by NetKet are:
- it is possible to define a method `__pre_init__(*args, **kwargs) ->
Tuple[Tuple,Dict]` that processes the arguments and keyword arguments provided
Expand Down Expand Up @@ -324,21 +338,75 @@ def dataclass(clz=None, *, init_doc=MISSING, cache_hash=False, _frozen=True):
dataclass, init_doc=init_doc, cache_hash=cache_hash, _frozen=_frozen
)

is_pytree = Pytree in clz.__mro__

if is_pytree:
if not (clz._pytree__class_is_mutable ^ _frozen):
raise ValueError(
f"Inheriting from a mutable={clz._pytree__class_is_mutable} but _frozen={_frozen}"
)
# let the base class handle the frozeness
_frozen = False

if _PRE_INIT_NAME in clz.__dict__:
msg = f"""
You defined `__pre_init__(*args, **kwargs)` in a netket
dataclass (a class decorated with @nk.utils.struct.dataclass) which
inherits from a `nk.utils.struct.Pytree`.
The class is {type(clz)}.
This behaviour is not supported and might break. Please remove
the decorator and just inherit from the base class, defining
a standard `__init__` method which calls `super().__init__(...)`
as usual.
If you need help, reach out with us.
"""
warnings.warn(msg, category=FutureWarning, stacklevel=1)

if "__post_init__" in clz.__dict__:
msg = f"""
You defined `__post_init__(self)` in a netket
dataclass (a class decorated with @nk.utils.struct.dataclass) which
inherits from a `nk.utils.struct.Pytree`.
The class is {type(clz)}.
This behaviour is not supported and might break. Please remove
the decorator and just inherit from the base class, defining
a standard `__init__` method which calls `super().__init__(...)`
as usual.
If you need help, reach out with us.
"""
warnings.warn(msg, category=FutureWarning, stacklevel=1)

# get globals of the class to put generated methods in there
_globals = get_class_globals(clz)
_globals["Uninitialized"] = Uninitialized
# process all cached properties
process_cached_properties(clz, globals=_globals)
# create the dataclass
data_clz = dataclasses.dataclass(frozen=_frozen)(clz)

purge_cache_fields(data_clz)
# attach the custom preprocessing of init arguments
attach_preprocess_init(
data_clz, globals=_globals, init_doc=init_doc, cache_hash=cache_hash
data_clz,
globals=_globals,
init_doc=init_doc,
cache_hash=cache_hash,
is_pytree=is_pytree,
)
if cache_hash:
replace_hash_method(data_clz, globals=_globals)

# if it's an 'auto-style PyTree', use standard dataclass-logic
# and do not register it with jax/flax
if is_pytree:
return data_clz

# flax stuff: identify states
meta_fields = []
data_fields = []
Expand Down
29 changes: 29 additions & 0 deletions netket/utils/struct/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,35 @@ def __init_subclass__(cls, mutable: bool = False, dynamic_nodes: bool = False):
partial(cls._from_flax_state_dict, cls._pytree__static_fields),
)

def __pre_init__(self, *args, **kwargs):
    """Default ``__pre_init__`` hook: split init arguments for dataclass subclasses.

    Used by netket's dataclass machinery when a ``Pytree`` is subclassed by
    a class decorated with the netket ``dataclass`` decorator (e.g. a
    user-implemented sampler relying on the legacy logic).

    Keyword arguments whose name is one of ``self.__dataclass_fields__``
    are kept for the dataclass initialiser. Every positional argument and
    every other keyword argument is assumed to be meant for the ``Pytree``
    base class, and is hidden inside the reserved ``"__base_init_args"``
    keyword as the pair ``(args, kwargs_pytree)`` so that the dataclass
    init is only called with the arguments it needs.

    Args:
        *args: Positional arguments given to the constructor. They are all
            routed to the base-class signature.
        **kwargs: Keyword arguments given to the constructor.

    Returns:
        A tuple ``((), kwargs_dataclass)``: an empty tuple of positional
        arguments, and the keyword arguments for the dataclass init with
        the extra ``"__base_init_args"`` entry added.
    """
    kwargs_dataclass = {}
    kwargs_pytree = {}
    for k, v in kwargs.items():
        # membership test directly on the mapping, no `.keys()` view needed
        if k in self.__dataclass_fields__:
            kwargs_dataclass[k] = v
        else:
            kwargs_pytree[k] = v

    # Stash the base-class call signature under a reserved keyword.
    kwargs_dataclass["__base_init_args"] = (args, kwargs_pytree)

    return (), kwargs_dataclass

def __post_init__(self):
    """No-op default ``__post_init__`` hook; does nothing and returns ``None``."""

@classmethod
def _pytree__flatten(
cls,
Expand Down
109 changes: 108 additions & 1 deletion test/utils/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from flax import serialization

from netket.utils.struct import Pytree, field, static_field
from netket.utils.struct import Pytree, dataclass, field, static_field


class TestPytree:
Expand Down Expand Up @@ -55,6 +55,28 @@ def __init__(self, y) -> None:
):
pytree.x = 4

def test_immutable_pytree_dataclass(self):
    """Check a frozen ``@dataclass`` Pytree: leaves, mapping, replace, immutability."""

    @dataclass(_frozen=True)
    class Foo(Pytree):
        y: int = field()
        x: int = static_field(default=2)

    pytree = Foo(y=3)

    # Only `y` is expected among the leaves; `x` is declared as a static field.
    leaves = jax.tree_util.tree_leaves(pytree)
    assert leaves == [3]

    # Use the `jax.tree_util` namespace (as for `tree_leaves` above) instead
    # of the deprecated top-level `jax.tree_map` alias.
    pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree)
    assert pytree.x == 2
    assert pytree.y == 6

    pytree = pytree.replace(x=3)
    assert pytree.x == 3
    assert pytree.y == 6

    # Attribute assignment must be rejected on the frozen instance.
    with pytest.raises(AttributeError):
        pytree.x = 4

def test_jit(self):
class Foo(Pytree):
a: int
Expand Down Expand Up @@ -181,6 +203,20 @@ class Foo(Pytree):
with pytest.raises(ValueError, match="Trying to replace unknown fields"):
Foo().replace(y=1)

def test_dataclass_inheritance(self):
    """Check the leaves of a dataclass Pytree that subclasses another one."""

    @dataclass
    class A(Pytree):
        a: int = 1
        b: int = static_field(default=2)

    @dataclass
    class B(A):
        c: int = 3

    # Expect the non-static fields `a` (inherited) and `c`, in that order.
    assert jax.tree_util.tree_leaves(B()) == [1, 3]

def test_pytree_with_new(self):
class A(Pytree):
a: int
Expand Down Expand Up @@ -259,3 +295,74 @@ def __init__(self, x):

with pytest.raises(AttributeError, match=r"Cannot add new fields to"):
foo.y = 2

def test_pytree_dataclass(self):
    """Check a mutable Pytree decorated with the netket ``dataclass``."""
    # A mutable Pytree combined with the decorator's default `_frozen`
    # setting is expected to be rejected with a ValueError.
    with pytest.raises(ValueError):

        @dataclass
        class _Foo(Pytree, mutable=True):
            y: int = field()
            x: int = static_field(default=2)

    @dataclass(_frozen=False)
    class Foo(Pytree, mutable=True):
        y: int = field()
        x: int = static_field(default=2)

    pytree: Foo = Foo(y=3)

    leaves = jax.tree_util.tree_leaves(pytree)
    assert leaves == [3]

    # Use the `jax.tree_util` namespace (as for `tree_leaves` above) instead
    # of the deprecated top-level `jax.tree_map` alias.
    pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree)
    assert pytree.x == 2
    assert pytree.y == 6

    pytree = pytree.replace(x=3)
    assert pytree.x == 3
    assert pytree.y == 6

    # test mutation
    pytree.x = 4
    assert pytree.x == 4

def test_dataclass_inheritance_custom_init(self):
    """Check a ``@dataclass`` subclass of a Pytree that defines its own ``__init__``.

    Renamed from ``test_dataclass_inheritance``: another test with that
    name appears earlier, and a second definition with the same name in the
    same class replaces the first, so the earlier test would never run.
    """

    class A(Pytree):
        y: int = field()
        x: int = static_field(default=2)

        def __init__(self, x, y):
            self.x = x
            self.y = y

    @dataclass
    class B(A):
        z: int

    # Positional arguments go to `A.__init__`; `z` is the dataclass field.
    b = B(1, 2, z=5)

    assert b.x == 1
    assert b.y == 2
    assert b.z == 5

    assert jax.tree_util.tree_leaves(b) == [2, 5]

    # pre init: defining `__pre_init__` on such a class must emit a FutureWarning
    with pytest.warns(FutureWarning):

        @dataclass
        class B(A):
            z: int

            def __pre_init__(self, x, y, kk):
                args, kwargs = super().__pre_init__(x, y)
                kwargs["z"] = kk
                return args, kwargs

    b = B(1, 2, kk=5)

    assert b.x == 1
    assert b.y == 2
    assert b.z == 5

    assert jax.tree_util.tree_leaves(b) == [2, 5]

0 comments on commit f0297a5

Please sign in to comment.