Importing a fine-tuned model #3379
-
ChatGLM3-6B was fine-tuned with LoRA via llama-factory. How can the Langchain-Chatchat project load the resulting model weights and parameter files?
Replies: 2 comments
-
Hey, @liuyuhao412! Great to see your question again 🚀 Based on the related issues and solutions I found in the Langchain-Chatchat repository, you can follow the steps below to load and use a ChatGLM3-6B model fine-tuned with LoRA via llama-factory:
```python
import os

import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


def load_ptuning_model(self, model_path: str, ptuning_checkpoint: str, from_pretrained_kwargs: dict):
    revision = from_pretrained_kwargs.get("revision", "main")
    # Prefer a fast tokenizer; fall back to the slow one if unsupported.
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=self.use_fast_tokenizer,
            revision=revision,
            trust_remote_code=True,
        )
    except TypeError:
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, use_fast=False, revision=revision, trust_remote_code=True
        )
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            **from_pretrained_kwargs,
        )
    except NameError:
        model = AutoModel.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            **from_pretrained_kwargs,
        )
    # Load the P-Tuning v2 prefix-encoder weights and strip the
    # "transformer.prefix_encoder." key prefix before loading them.
    prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
    new_prefix_state_dict = {}
    for k, v in prefix_state_dict.items():
        if k.startswith("transformer.prefix_encoder."):
            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    model.transformer.prefix_encoder.float()
    return model, tokenizer
```
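For illustration, this is how the loader above might be invoked from a model adapter instance; `adapter`, the model path, and the checkpoint path are all hypothetical placeholders:

```python
# Hypothetical invocation: adapter is an instance of the class that
# defines load_ptuning_model, and both paths are placeholders.
model, tokenizer = adapter.load_ptuning_model(
    model_path="THUDM/chatglm3-6b",
    ptuning_checkpoint="/path/to/ptuning-checkpoint",
    from_pretrained_kwargs={"revision": "main"},
)
```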
The FastChat `ModelWorker` constructor is then extended with a `ptuning_checkpoint` argument:

```python
from typing import List, Optional

import torch
# BaseModelWorker and the GptqConfig / AWQConfig / ExllamaConfig / XftConfig
# classes come from FastChat's serving and quantization modules.


class ModelWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        device: str,
        num_gpus: int,
        max_gpu_memory: str,
        revision: str = None,
        ptuning_checkpoint: str = None,  # path to a P-Tuning v2 checkpoint, if any
        dtype: Optional[torch.dtype] = None,
        load_8bit: bool = False,
        cpu_offloading: bool = False,
        gptq_config: Optional[GptqConfig] = None,
        awq_config: Optional[AWQConfig] = None,
        exllama_config: Optional[ExllamaConfig] = None,
        xft_config: Optional[XftConfig] = None,
        stream_interval: int = 2,
        conv_template: Optional[str] = None,
        embed_in_truncate: bool = False,
        seed: Optional[int] = None,
        debug: bool = False,
    ):
        # The body was omitted in the original reply; presumably it stores
        # ptuning_checkpoint and routes loading through load_ptuning_model
        # whenever a checkpoint path is provided.
        ...
```

These modifications let Langchain-Chatchat load and use a P-Tuning v2 fine-tuned model; see the closed issue "V0.2.4以上版本如何加载ptunig-v2的微调模型?" (how to load a ptuning-v2 fine-tuned model in versions above V0.2.4). Note, however, that Langchain-Chatchat does not directly support P-Tuning v2 fine-tuning of GLM3, because it does not follow the peft structure. Once the ChatGLM3 team rewrites the fine-tuning code, models trained with the new code will be usable; see the closed issue "[FEATURE] GLM3 的 P-Tuning v2 微调能支持吗" (can P-Tuning v2 fine-tuning for GLM3 be supported?). This should help you get a fine-tuned ChatGLM3-6B model working in Langchain-Chatchat.
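Since the original question is about a LoRA checkpoint from llama-factory rather than P-Tuning v2, a common alternative is to merge the LoRA adapter into the base model with the peft library and then point Langchain-Chatchat at the merged directory like any local model. A minimal sketch, assuming the adapter was saved in standard peft format and using hypothetical paths:

```python
import torch
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer

base_path = "THUDM/chatglm3-6b"            # base model (hypothetical path)
adapter_path = "/path/to/lora-checkpoint"  # llama-factory LoRA output (hypothetical)
merged_path = "/path/to/chatglm3-6b-merged"

# Load the base model and apply the LoRA adapter on top of it.
base_model = AutoModel.from_pretrained(
    base_path, trust_remote_code=True, torch_dtype=torch.float16
)
model = PeftModel.from_pretrained(base_model, adapter_path)

# Fold the LoRA weights into the base weights and save a plain
# Hugging Face checkpoint that Langchain-Chatchat can load directly.
model = model.merge_and_unload()
model.save_pretrained(merged_path)

tokenizer = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True)
tokenizer.save_pretrained(merged_path)
```

After merging, register `merged_path` in Langchain-Chatchat's model configuration (e.g. the model path settings in `configs/model_config.py` for the 0.2.x series) and start the service with that model name selected.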