Further training with heterogeneous datasets #423

Closed
danielphan2003 opened this issue Mar 23, 2024 · 4 comments
Labels
wontfix This will not be worked on

Comments

@danielphan2003

Is your feature request related to a problem? Please describe.
Alternative title: a federated learning use case with the same model but different parameter shapes due to training on heterogeneous datasets

I'm using Flower, and my workflow is as follows:

  1. The server pushes a common model to the clients (trained on a dataset with the full label set).
  2. Clients start training on their own datasets, where they may only have data for certain labels (there is no unlabeled data).
  3. The server collects the model parameters (NumPy arrays) from each client.
  4. The server aggregates them into a new set of model parameters.
  5. The server pushes the aggregated parameters back to the clients.
  6. Clients restore the aggregated parameters and return to step 2 (a minimal client-side sketch of this round follows the list).
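
For reference, here is a minimal client-side sketch of one such round, assuming Flower's NumPyClient API and the set_model_parameters helper shown later in this issue; get_model_parameters and train_fn are hypothetical stand-ins for exporting the weights and for the local training loop:

import flwr as fl
import numpy as np
import torch


def get_model_parameters(model: torch.nn.Module) -> list[np.ndarray]:
    # Hypothetical counterpart to set_model_parameters: export the state_dict as NumPy arrays.
    return [v.cpu().numpy() for v in model.state_dict().values()]


class TabularClient(fl.client.NumPyClient):
    def __init__(self, model: torch.nn.Module, train_fn) -> None:
        self.model = model
        self.train_fn = train_fn  # the client-local training loop (step 2)

    def get_parameters(self, config):
        # Step 3: the server collects these arrays from every client.
        return get_model_parameters(self.model)

    def fit(self, parameters, config):
        # Step 6: restore the aggregated parameters, then train again (step 2).
        set_model_parameters(self.model, parameters)
        self.train_fn(self.model)
        # 1 is a placeholder for the number of local training examples.
        return get_model_parameters(self.model), 1, {}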

I'm used to training simple models with simple data preprocessing (sklearn's StandardScaler), where the parameter shapes always remain the same regardless of the dataset. This is not the case for PyTorch Tabular: whenever I save a model trained on one dataset and reuse it to train on a slightly different one (e.g. one missing all rows of a label), both models share the same layers but the shapes do not match.

Describe the solution you'd like
I would like the model parameter shapes to be consistent when training on different datasets, so that aggregation works without needing to pad the smaller parameter arrays to fit the bigger ones.

Describe alternatives you've considered
Padding only works for aggregation: restoring the padded parameters on the clients works, but training on them causes a shape mismatch, because the original model's parameter shapes differ from the aggregated ones.

Additional context
Training with the client shape mismatch causes: RuntimeError: The size of tensor a (85) must match the size of tensor b (84) at non-singleton dimension 1

I used the following to restore parameters to the model:

from collections import OrderedDict
from io import BytesIO
from pathlib import Path
from typing import Any

import flwr.common.typing as flt  # flt.NDArrays = list[np.ndarray]
import torch


def set_pruned_model_parameters(
    model: torch.nn.Module, path: Path | BytesIO | flt.NDArrays
) -> None:
    """Set model parameters directly, without torch.nn.Module.load_state_dict, which errors out on shape mismatch."""
    if isinstance(path, (Path, BytesIO)):
        ckpt = torch.load(path, map_location=lambda storage, loc: storage)  # type: ignore
        state_dict: OrderedDict[str, Any] = ckpt.get("state_dict") or ckpt
    else:
        params_dict = zip(model.state_dict().keys(), path)
        state_dict = OrderedDict(
            {k: torch.tensor(v, dtype=torch.float, device="cpu") for k, v in params_dict}
        )

    for name, param in state_dict.items():
        submodules = name.split(".")
        param = torch.nn.Parameter(param, requires_grad=False)
        # Walk the dotted parameter path and assign the parameter on the innermost module,
        # bypassing load_state_dict's shape checks (a sketch of set_attr is shown below).
        set_attr(model, submodules, param)


def set_model_parameters(model: torch.nn.Module, parameters: flt.NDArrays) -> None:
    """Set model parameters by relying on torch.nn.Module.load_state_dict."""
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)
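
set_attr is not shown above; a minimal sketch of such a helper, assuming it takes the dotted parameter name split on "." and assigns the value on the leaf module, could look like this:

import torch


def set_attr(obj: torch.nn.Module, names: list[str], val: torch.nn.Parameter) -> None:
    # Recursively descend into submodules and set the final attribute directly.
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)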

Reproducing notebook: https://colab.research.google.com/drive/11BD1XqSgc5a4k0YxxfyeFGfr0_KOUj0h?usp=sharing

@danielphan2003
Author

danielphan2003 commented Mar 24, 2024

Specifying embedding_dims in {GANDALF,CategoryEmbeddingModel}Config and saving the model as inference-only appears to produce the same parameter shapes when fitting. I will try it out today to see if it works...

UPDATE: Actually, pinning model_config.embedding_dims does not fully solve this issue yet. A comparison of head.layers.0.weight and head.layers.0.bias shows that these layers still depend on the unique labels available on each client, although this is still an improvement over not specifying embedding_dims at all:

[
    ('change', ['head.layers.0.weight'], ((34, 512), (21, 512))),
    ('change', ['head.layers.0.bias'], ((34,), (21,)))
]

(I still recommend saving the model as inference-only, since it would otherwise bundle the server-side training dataset, which we don't need on the clients):

import copy
import joblib
from omegaconf import DictConfig, OmegaConf
from pytorch_tabular.tabular_datamodule import TabularDatamodule

model_config: DictConfig = OmegaConf.load(f"{model_path}/config.yml")
datamodule: TabularDatamodule = joblib.load(f"{model_path}/datamodule.sav")
model_config.embedding_dims = copy.copy(datamodule._inferred_config.embedding_dims)  # pin the server-inferred dims
del datamodule  # since we don't need it anymore

@danielphan2003
Author

For now you need to extend TabularModel.prepare_model to prevent it from regenerating InferredConfig based on the client's data module, and make it use the server's instead. It also wasn't very obvious how we can 1) restore a saved model and reuse it for further training, since every time fit is called the model is recreated from scratch, and 2) force it to use transform instead of fit_transform in preprocess_data. I will post a workaround for this in the next few days.

@manujosephv
Owner

I guess the library was designed to be easy for non-DL users as well, so we infer a lot of things from the data. While this is useful for many use cases, in use cases such as yours it's definitely a problem.

I did introduce a low-level API to give users more flexibility, where fit is split into three sub-tasks: prepare_dataloader, prepare_model, and train. Maybe you can explore that for your use case?
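
For illustration, a rough sketch of that three-step split, assuming the low-level API as documented (exact keyword arguments may differ between versions; the configs and client_train_df / client_val_df are placeholders):

import pandas as pd
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

client_train_df = pd.DataFrame()  # placeholder for the client's training data
client_val_df = pd.DataFrame()    # placeholder for the client's validation data

# Placeholder configs; in the federated setup these would come from the server.
tabular_model = TabularModel(
    data_config=DataConfig(target=["target"], continuous_cols=["feat_a"], categorical_cols=["feat_b"]),
    model_config=CategoryEmbeddingModelConfig(task="classification"),
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=5),
)

# The single fit() call split into three explicit steps.
datamodule = tabular_model.prepare_dataloader(train=client_train_df, validation=client_val_df)
model = tabular_model.prepare_model(datamodule)  # this step regenerates InferredConfig from the client data
tabular_model.train(model, datamodule)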

Also, there are different ways to save and load the model. Can you check the tutorials and documentation for those methods?

But in the long run, the library code may need some tweaks for this to work seamlessly. If you can figure out a neat way (without disrupting existing workflows), then nothing like it!


stale bot commented May 31, 2024

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix This will not be worked on label May 31, 2024
@stale stale bot closed this as completed Jun 7, 2024