SDXL Serverless Worker: How to Cache LoRA models
In this code from the GitHub SDXL serverless worker repo, how do I cache LoRA models and get their paths to use in my handler function?
# builder/model_fetcher.py
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, AutoencoderKL


def fetch_pretrained_model(model_class, model_name, **kwargs):
    '''
    Fetches a pretrained model from the HuggingFace model hub.
    '''
    max_retries = 3
    for attempt in range(max_retries):
        try:
            return model_class.from_pretrained(model_name, **kwargs)
        except OSError as err:
            if attempt < max_retries - 1:
                print(
                    f"Error encountered: {err}. Retrying attempt {attempt + 1} of {max_retries}...")
            else:
                raise


def get_diffusion_pipelines():
    '''
    Fetches the Stable Diffusion XL pipelines from the HuggingFace model hub.
    '''
    common_args = {
        "torch_dtype": torch.float16,
        "variant": "fp16",
        "use_safetensors": True
    }
    pipe = fetch_pretrained_model(StableDiffusionXLPipeline,
                                  "stabilityai/stable-diffusion-xl-base-1.0", **common_args)
    vae = fetch_pretrained_model(
        AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", **{"torch_dtype": torch.float16}
    )
    print("Loaded VAE")
    refiner = fetch_pretrained_model(StableDiffusionXLImg2ImgPipeline,
                                     "stabilityai/stable-diffusion-xl-refiner-1.0", **common_args)
    return pipe, refiner, vae


if __name__ == "__main__":
    get_diffusion_pipelines()
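What I have in mind is adding something like this to the builder step, so the LoRA weights land in the HuggingFace cache at build time (the repo_id and filename below are just placeholders, not a real repo):

# builder/lora_fetcher.py -- sketch, not part of the worker repo
from huggingface_hub import hf_hub_download


def fetch_lora_weights():
    '''
    Downloads LoRA weights into the local HuggingFace cache at build time
    so the worker can resolve them later without a network call.
    '''
    lora_path = hf_hub_download(
        repo_id="your-username/your-sdxl-lora",        # placeholder
        filename="pytorch_lora_weights.safetensors",   # placeholder
    )
    print(f"Cached LoRA weights at {lora_path}")
    return lora_path


if __name__ == "__main__":
    fetch_lora_weights()

Is that the right approach, and how do I get that path back inside the handler?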
4 Replies
It doesn't support LoRA
You need to use something like A1111 or ComfyUI if you want to use LoRA
In theory it's supported: https://huggingface.co/docs/diffusers/v0.13.0/en/training/lora (see the sketch at the end of the thread)
Yeah, but you have to implement it yourself; there is no RunPod support for it.
Yup, now that all workers are self-managed.
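To make the replies concrete: reasonably recent diffusers versions expose load_lora_weights on the SDXL pipelines, so a self-managed worker can wire this up itself. A minimal sketch, assuming the placeholder LoRA from the question was cached at build time; calling hf_hub_download again in the handler resolves to the already-cached file without a new download:

# handler.py -- sketch, assumes the placeholder LoRA from the question
import os
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

# hf_hub_download is idempotent: with the file already in the local cache
# from the build step, this returns its path without hitting the network.
lora_path = hf_hub_download(
    repo_id="your-username/your-sdxl-lora",        # placeholder
    filename="pytorch_lora_weights.safetensors",   # placeholder
)

# load_lora_weights accepts a local directory plus a weight_name.
pipe.load_lora_weights(
    os.path.dirname(lora_path),
    weight_name=os.path.basename(lora_path),
)

None of this is provided by the worker out of the box; as noted above, you have to add it to your own fork.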