-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow embeddings and loras to have different names #6053
base: main
Are you sure you want to change the base?
Changes from all commits
bbecb99
18098cc
befcf77
a7b8bbc
90f32d1
653a8f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import json | ||
import os | ||
import re | ||
from pathlib import Path | ||
from typing import Any, Dict, Literal, Optional, Union | ||
|
@@ -145,7 +146,8 @@ def probe( | |
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") | ||
|
||
probe = probe_class(model_path) | ||
|
||
model_path = probe.model_path | ||
format_type = probe.get_format() | ||
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path | ||
fields["source"] = fields.get("source") or model_path.as_posix() | ||
fields["key"] = fields.get("key", uuid_string()) | ||
|
@@ -159,7 +161,8 @@ def probe( | |
fields["description"] = ( | ||
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}" | ||
) | ||
fields["format"] = fields.get("format") or probe.get_format() | ||
fields["format"] = fields.get("format") or format_type | ||
|
||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path) | ||
|
||
fields["default_settings"] = fields.get("default_settings") | ||
|
@@ -643,17 +646,19 @@ def _guess_name(self) -> str: | |
return name | ||
|
||
|
||
class TextualInversionFolderProbe(FolderProbeBase): | ||
def get_format(self) -> ModelFormat: | ||
return ModelFormat.EmbeddingFolder | ||
|
||
def get_base_type(self) -> BaseModelType: | ||
path = self.model_path / "learned_embeds.bin" | ||
if not path.exists(): | ||
class TextualInversionFolderProbe(TextualInversionCheckpointProbe): | ||
def __init__(self, model_path: Path): | ||
files = os.scandir(model_path) | ||
files = [ | ||
Path(f.path) | ||
for f in files | ||
if f.is_file() and f.name.endswith((".ckpt", ".pt", ".pth", ".bin", ".safetensors")) | ||
] | ||
if len(files) != 1: | ||
raise InvalidModelConfigException( | ||
f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file" | ||
f"Unable to determine base type for {model_path}: expected exactly one valid model file, found {[f.name for f in files]}." | ||
) | ||
return TextualInversionCheckpointProbe(path).get_base_type() | ||
super().__init__(files.pop()) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What if the first model file we find isn't the correct one? I've seen a few TI models hosted on diffusers that have multiple checkpoints in them, presumably representing different stages of training. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Nevermind that, I misread the length check. The other issue still stands. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't believe any of the checkpoints are actually used for inference but I could be wrong. This only scans the root directory of the folder so as long as there's only 1 model there it loads it fine. It does move it out of the folder though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I mean is that this logic would appear to fail probing TIs that have multiple model files in the directory. Some look like this (struggling to find an example, filenames probably aren't accurate):
This is a valid TI and the model loader knows how to load it, but I believe the logic in this PR would see the multiple model files and consider it an invalid model. |
||
|
||
|
||
class ONNXFolderProbe(PipelineFolderProbe): | ||
|
@@ -699,17 +704,16 @@ def get_base_type(self) -> BaseModelType: | |
return base_model | ||
|
||
|
||
class LoRAFolderProbe(FolderProbeBase): | ||
def get_base_type(self) -> BaseModelType: | ||
model_file = None | ||
for suffix in ["safetensors", "bin"]: | ||
base_file = self.model_path / f"pytorch_lora_weights.{suffix}" | ||
if base_file.exists(): | ||
model_file = base_file | ||
break | ||
if not model_file: | ||
raise InvalidModelConfigException("Unknown LoRA format encountered") | ||
return LoRACheckpointProbe(model_file).get_base_type() | ||
class LoRAFolderProbe(LoRACheckpointProbe): | ||
def __init__(self, model_path: Path): | ||
files = os.scandir(model_path) | ||
files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".bin", ".safetensors"))] | ||
if len(files) != 1: | ||
raise InvalidModelConfigException( | ||
f"Unable to determine base type for lora {model_path}: expected exactly one valid model file, found {[f.name for f in files]}." | ||
) | ||
model_file = files.pop() | ||
super().__init__(model_file) | ||
|
||
|
||
class IPAdapterFolderProbe(FolderProbeBase): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't like how this results in a "folder" model having a path that points to a file...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I'm not really sure what the value of storing these as a folder even is, though. I haven't found any LoRAs or embeddings that have other contextual files that affect generations. It just seems to put us in a position where we have to enforce a folder structure for no reason.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The folder was introduced by Hugging Face. TIs usually contain two files — one with the weights and another with the name of the concept.
For LoRAs, the folder is used to give the model a name.