Skip to content

Commit

Permalink
add simplified model manager install API to InvocationContext
Browse files Browse the repository at this point in the history
  • Loading branch information
Lincoln Stein authored and lstein committed Apr 12, 2024
1 parent 24f2cde commit af1b57a
Showing 1 changed file with 97 additions and 1 deletion.
98 changes: 97 additions & 1 deletion invokeai/app/services/shared/invocation_context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from torch import Tensor

from invokeai.app.invocations.constants import IMAGE_MODES
Expand Down Expand Up @@ -426,6 +427,101 @@ def search_by_attrs(
model_format=format,
)

def install_model(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
timeout: Optional[int] = 0,
) -> str:
"""Install and register a model in the database.
Args:
source: String source; see below
config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
model's config record.
access_token: Optional access token for remote sources.
inplace: If true, installs a local model in place rather than copying
it into the models directory
timeout: How long to wait on install (in seconds). A value of 0 (default)
blocks indefinitely
The source can be:
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
2. An http or https URL (`https://foo.bar/foo`)
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
We extend the HuggingFace repo_id syntax to include the variant and the
subfolder or path. The following are acceptable alternatives:
stabilityai/stable-diffusion-v4
stabilityai/stable-diffusion-v4:fp16
stabilityai/stable-diffusion-v4:fp16:vae
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
stabilityai/stable-diffusion-v4:onnx:vae
Because a local file path can look like a huggingface repo_id, the logic
first checks whether the path exists on disk, and if not, it is treated as
a parseable huggingface repo.
Returns:
Key to the newly installed model.
May Raise:
ValueError -- bad source
UnknownModelException -- remote model not found
InvalidModelException -- what was retrieved from remote is not a model
TimeoutError -- model could not be installed within timeout
Exception -- another error condition
"""
installer = self._services.model_manager.install
job = installer.heuristic_import(
source=source,
config=config,
access_token=access_token,
inplace=inplace,
)
installer.wait_for_job(job, timeout)
if job.errored:
raise Exception(job.error)
key: str = job.config_out.key
return key

def download_and_cache_model(
self,
source: Union[str, AnyHttpUrl],
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
) -> Path:
"""Download the model file located at source to the models cache and return its Path.
This can be used to single-file install models and other resources of arbitrary types
which should not get registered with the database. If the model is already
installed, the cached path will be returned. Otherwise it will be downloaded.
Args:
source: A URL or a string that can be converted in one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
Result:
Path of the downloaded model
May Raise:
HTTPError
TimeoutError
"""
installer = self._services.model_manager.install
path: Path = installer.download_and_cache(
source=source,
access_token=access_token,
timeout=timeout,
)
return path


class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:
Expand Down

0 comments on commit af1b57a

Please sign in to comment.