Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

qlora merge lora weights error #350

Open
zousss opened this issue Apr 10, 2024 · 0 comments
Open

qlora merge lora weights error #350

zousss opened this issue Apr 10, 2024 · 0 comments

Comments

@zousss
Copy link

zousss commented Apr 10, 2024

RuntimeError Traceback (most recent call last)
Cell In[9], line 4
1 from finetune_visualglm import FineTuneVisualGLMModel
2 import argparse
----> 4 model, args = FineTuneVisualGLMModel.from_pretrained('/kaggle/working/checkpoints/finetune-visualglm-6b-04-09-09-10',
5 args=argparse.Namespace(
6 fp16=True,
7 skip_init=True,
8 use_gpu_initialization=True,
9 device='cuda',
10 ))
11 model.get_mixin('lora').merge_lora()
12 args.layer_range = []

File /opt/conda/lib/python3.10/site-packages/sat/model/base_model.py:207, in BaseModel.from_pretrained(cls, name, args, home_path, url, prefix, build_only, overwrite_args, **kwargs)
205 model = get_model(args, cls, **kwargs)
206 if not build_only:
--> 207 load_checkpoint(model, args, load_path=model_path, prefix=prefix)
208 return model, args

File /opt/conda/lib/python3.10/site-packages/sat/training/model_io.py:238, in load_checkpoint(model, args, load_path, prefix)
235 module = model
237 # only load module, other hyperparameters are just for recording.
--> 238 missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
239 if len(unexpected_keys) > 0:
240 print_rank0(
241 f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2138, in Module.load_state_dict(self, state_dict, strict, assign)
2131 out = hook(module, incompatible_keys)
2132 assert out is None, (
2133 "Hooks registered with register_load_state_dict_post_hook are not"
2134 "expected to return new values, if incompatible_keys need to be modified,"
2135 "it should be done inplace."
2136 )
-> 2138 load(self, state_dict)
2139 del load
2141 if strict:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

[... skipping similar frames: Module.load_state_dict.<locals>.load at line 2126 (3 times)]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2126, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2124 child_prefix = prefix + name + '.'
2125 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2126 load(child, child_state_dict, child_prefix)
2128 # Note that the hook can modify missing_keys and unexpected_keys.
2129 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2120, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2118 if assign:
2119 local_metadata['assign_to_params_buffers'] = assign
-> 2120 module._load_from_state_dict(
2121 local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
2122 for name, child in module._modules.items():
2123 if child is not None:

File /opt/conda/lib/python3.10/site-packages/sat/model/finetune/lora2.py:47, in HackLinearNF4._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
     45 def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
     46     if prefix + 'weight' in state_dict:
---> 47         self.weight.data.copy_(state_dict[prefix+'weight'])
     48         if self.weight.data.dtype == torch.uint8:
     49             copy_nested_list(state_dict[prefix+'quant_state'], self.weight.quant_state)

RuntimeError: output with shape [25165824, 1] doesn't match the broadcast shape [25165824, 0]

How can I solve this problem?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant