torch.save does not directly work with unexpected key(s) #258

chuanqixu opened this issue Apr 19, 2024 · 0 comments

Problem

For a model containing a QuantumModule, saving with torch.save(model.state_dict(), "model.pt") and then loading with model.load_state_dict(torch.load("model.pt")) may fail, because some state_dict keys are only created lazily during the forward pass.

Example

I used Model1 from the Quantum Convolution (Quanvolution) example. The full code is below:

import torchquantum as tq

import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random

from torchquantum.dataset import MNIST
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchquantum.layer import U3CU3Layer0


class TrainableQuanvFilter(tq.QuantumModule):
    def __init__(self):
        super().__init__()
        self.n_wires = 4
        self.encoder = tq.GeneralEncoder(
            [
                {"input_idx": [0], "func": "ry", "wires": [0]},
                {"input_idx": [1], "func": "ry", "wires": [1]},
                {"input_idx": [2], "func": "ry", "wires": [2]},
                {"input_idx": [3], "func": "ry", "wires": [3]},
            ]
        )

        self.arch = {"n_wires": self.n_wires, "n_blocks": 5, "n_layers_per_block": 2}
        self.q_layer = U3CU3Layer0(self.arch)
        self.measure = tq.MeasureAll(tq.PauliZ)

    def forward(self, x, use_qiskit=False):
        bsz = x.shape[0]
        qdev = tq.QuantumDevice(self.n_wires, bsz=bsz, device=x.device)
        x = F.avg_pool2d(x, 6).view(bsz, 4, 4)
        size = 4
        stride = 2
        x = x.view(bsz, size, size)

        data_list = []

        for c in range(0, size, stride):
            for r in range(0, size, stride):
                data = torch.transpose(
                    torch.cat(
                        (x[:, c, r], x[:, c, r + 1], x[:, c + 1, r], x[:, c + 1, r + 1])
                    ).view(4, bsz),
                    0,
                    1,
                )
                if use_qiskit:
                    data = self.qiskit_processor.process_parameterized(
                        qdev, self.encoder, self.q_layer, self.measure, data
                    )
                else:
                    self.encoder(qdev, data)
                    self.q_layer(qdev)
                    data = self.measure(qdev)

                data_list.append(data.view(bsz, 4))

        # transpose to (bsz, channel, 2x2)
        result = torch.transpose(
            torch.cat(data_list, dim=1).view(bsz, 4, 4), 1, 2
        ).float()

        return result

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qf = TrainableQuanvFilter()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x, use_qiskit=False):
        x = x.view(-1, 28, 28)
        x = self.qf(x)
        x = x.reshape(-1, 16)
        x = self.linear(x)
        return F.log_softmax(x, -1)


def train(dataflow, model, device, optimizer):
    for feed_dict in dataflow["train"]:
        inputs = feed_dict["image"].to(device)
        targets = feed_dict["digit"].to(device)

        outputs = model(inputs)
        loss = F.nll_loss(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item()}", end="\r")


def valid_test(dataflow, split, model, device, qiskit=False):
    target_all = []
    output_all = []
    with torch.no_grad():
        for feed_dict in dataflow[split]:
            inputs = feed_dict["image"].to(device)
            targets = feed_dict["digit"].to(device)

            outputs = model(inputs, use_qiskit=qiskit)

            target_all.append(targets)
            output_all.append(outputs)
        target_all = torch.cat(target_all, dim=0)
        output_all = torch.cat(output_all, dim=0)

    _, indices = output_all.topk(1, dim=1)
    masks = indices.eq(target_all.view(-1, 1).expand_as(indices))
    size = target_all.shape[0]
    corrects = masks.sum().item()
    accuracy = corrects / size
    loss = F.nll_loss(output_all, target_all).item()

    print(f"{split} set accuracy: {accuracy}")
    print(f"{split} set loss: {loss}")

    return accuracy, loss


n_epochs = 1

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

dataset = MNIST(
    root="./mnist_data",
    train_valid_split_ratio=[0.9, 0.1],
    digits_of_interest=[0, 1, 2, 3],
    n_test_samples=300,
    n_train_samples=500,
)

dataflow = dict()
for split in dataset:
    sampler = torch.utils.data.RandomSampler(dataset[split])
    dataflow[split] = torch.utils.data.DataLoader(
        dataset[split],
        batch_size=10,
        sampler=sampler,
        num_workers=8,
        pin_memory=True,
    )

device = torch.device("cpu")

model = Model().to(device)

# print(f"training model...")
# optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
# scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)
# for epoch in range(1, n_epochs + 1):
#     # train
#     print(f"Epoch {epoch}:")
#     train(dataflow, model, device, optimizer)
#     print(optimizer.param_groups[0]["lr"])
#     # valid
#     accu, loss = valid_test(dataflow, "test", model, device)
#     scheduler.step()

The save-and-load example in Save and Load QNN models may not work directly. After uncommenting the training part above, I trained the model and saved the state_dict with:

torch.save(model.state_dict(), "model.pt")

When I then tried to load the saved state_dict into a newly created model, loading failed: the fresh model does not yet contain some of the saved keys, because those keys are only created lazily during forward. For example, loading into a new model with:

model2 = Model().to(device)
model2.load_state_dict(torch.load("model.pt"))

The error is:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 2
      1 model2 = Model().to(device)
----> 2 model2.load_state_dict(torch.load("model.pt"))

File d:\Programming\Anaconda3\envs\torchquantum\lib\site-packages\torch\nn\modules\module.py:2153, in Module.load_state_dict(self, state_dict, strict, assign)
   2148         error_msgs.insert(
   2149             0, 'Missing key(s) in state_dict: {}. '.format(
   2150                 ', '.join(f'"{k}"' for k in missing_keys)))
   2152 if len(error_msgs) > 0:
-> 2153     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2154                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2155 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Model:
	Unexpected key(s) in state_dict: "qf.q_layer.q_device.state", "qf.q_layer.q_device.states".

Potential Solution

The cause is that the model creates new submodules, and therefore new state_dict entries, during the forward pass.

For example, in TrainableQuanvFilter above, self.q_layer is created inside __init__ with self.q_layer = U3CU3Layer0(self.arch) (defined in torchquantum/layer/layers/u3_layer.py).

U3CU3Layer0 inherits from LayerTemplate0, including its forward() method. Inside forward(), the quantum device is attached to the layer object as a new attribute, and this only happens when the model is forwarded (torchquantum/layer/layers/layers.py):

    @tq.static_support
    def forward(self, q_device: tq.QuantumDevice):
        self.q_device = q_device
        for k in range(len(self.layers_all)):
            self.layers_all[k](q_device)
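
To see the effect, one can compare the state_dict keys of a fresh model before and after a single forward pass (an illustrative check based on the classes defined above; the exact key names may differ across torchquantum versions):

model_demo = Model()
print([k for k in model_demo.state_dict() if "q_device" in k])
# [] -- no device-state keys yet
model_demo(torch.zeros(10, 28, 28))  # one forward pass attaches q_device to q_layer
print([k for k in model_demo.state_dict() if "q_device" in k])
# ['qf.q_layer.q_device.state', 'qf.q_layer.q_device.states']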

One solution is to create all such attributes during __init__, but I am not familiar with torchquantum's design principles. Since forward() takes q_device as an input and assigns it to the attribute, the layer may be intended to be reused across many devices. Creating everything in __init__ could therefore require a large interface change, and might mean creating one layer object per device instead of reusing a single layer object.

Another option is to run one forward pass before loading the state_dict, every time it is loaded, so that the lazily created keys already exist. This may be time-consuming when the model is large.
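
A minimal sketch of that workaround, assuming the dummy batch has the same batch size as the last batch used before saving (10 in this example), since the lazily created device-state buffers are batch-sized and strict loading may otherwise complain about shape mismatches:

model2 = Model().to(device)
with torch.no_grad():
    model2(torch.zeros(10, 28, 28, device=device))  # dummy forward creates the lazy keys
model2.load_state_dict(torch.load("model.pt"))

# Alternatively (not from the original example), the transient device-state keys
# could simply be ignored:
# model2.load_state_dict(torch.load("model.pt"), strict=False)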

Related Issues

The following issues may be related:

#210 is not resolved yet.
#49 provides the save-and-load example, but as explained above, it does not work in general. That example saves and loads with the same model object directly, whose state_dict already contains the lazily created keys, so no keys are missing or unexpected.

GenericP3rson self-assigned this Apr 19, 2024