set_adapters not supported for compiled model #11408
Hi, regarding this issue, I tried debugging locally and found that modifying the code in both diffusers and peft can resolve the problem. Here are the specific code changes:

```diff
index 280a9fa6e..1111a56ab 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -735,6 +735,8 @@ class LoraBaseMixin:
             if issubclass(model.__class__, ModelMixin):
                 model.set_adapters(adapter_names, _component_adapter_weights[component])
+            elif issubclass(model.__class__, torch._dynamo.eval_frame.OptimizedModule):
+                model.set_adapters(adapter_names, _component_adapter_weights[component])
             elif issubclass(model.__class__, PreTrainedModel):
                 set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
@@ -832,7 +834,7 @@ class LoraBaseMixin:
             model = getattr(self, component, None)
             if (
                 model is not None
-                and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
+                and issubclass(model.__class__, (ModelMixin, PreTrainedModel, torch._dynamo.eval_frame.OptimizedModule))
                 and hasattr(model, "peft_config")
             ):
                 set_adapters[component] = list(model.peft_config.keys())
```

```diff
index 4429b0b..897ff14 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -451,7 +451,7 @@ class LoraLayer(BaseTunerLayer):
         if adapter not in self.scaling:
             # Ignore the case where the adapter is not in the layer
             return
-        self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter]
+        self.scaling[adapter] = torch.tensor(scale * self.lora_alpha[adapter] / self.r[adapter])

     def scale_layer(self, scale: float) -> None:
         if scale == 1:
```

I believe the issue stems from torch.compile converting the model type into torch._dynamo.eval_frame.OptimizedModule, which prevents modification of the LoRA scale. Additionally, the scale needs to be converted to a torch.tensor; otherwise, adjusting the LoRA scale triggers recompilation. However, I'm uncertain whether these modifications might introduce side effects elsewhere. Could you please review this approach or suggest alternative solutions? Thank you.
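To make the two observations above concrete, here is a minimal standalone sketch (not diffusers/peft code; the `Scaled` module and its `scale` attribute are made up for illustration) of how torch.compile wraps the model in an OptimizedModule and why a Python-float attribute behaves differently from a tensor one:

```python
import torch

class Scaled(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = 1.0  # plain Python float attribute

    def forward(self, x):
        # Dynamo specializes on the value of a float attribute and guards on it;
        # a tensor attribute is guarded on shape/dtype/device, not on its value.
        return x * self.scale

model = Scaled()
compiled = torch.compile(model)
print(type(compiled))  # <class 'torch._dynamo.eval_frame.OptimizedModule'>, not Scaled

x = torch.randn(4)
compiled(x)        # first call: graph compiled with scale == 1.0 baked in
model.scale = 0.5  # changing the float invalidates the guard ...
compiled(x)        # ... so this call triggers a recompilation
```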
Thank you so much for initiating this conversation! Your changes look quite reasonable to me. Cc: @BenjaminBossan here. I think we can try adding a test for your use case and see if there are any compilation/recompilation issues?
Thanks for your reply. The code formatting was messed up in the previous version; I'm re-uploading a clean copy and looking forward to future updates.

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')

# before load_lora
pipe.enable_lora_hotswap(target_rank=256)  # according to your LoRA
pipe.load_lora_weights("/lora_path", weight_name="pytorch_lora_weights_1.safetensors", adapter_name="lora1")
pipe.set_adapters(["lora1"], adapter_weights=[1.0])

# use torch.compile
pipe.transformer = torch.compile(pipe.transformer)

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux-dev.png")

pipe.load_lora_weights("/lora_path", weight_name="pytorch_lora_weights_2.safetensors", adapter_name="lora1", hotswap=True)
pipe.set_adapters(["lora1"], adapter_weights=[0.5])
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux-dev_2.png")
```
Thanks for reporting and investigating this @songh11. Regarding the diff for LoRA you proposed above: we had considered this in the past but can't make this change that easily. For instance, a tensor has a device, while normal floats don't, so converting the scale to a tensor would require numerous other changes to ensure that there are no device errors. The solution we came up with is the hotswap feature.
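To sketch what compiled hotswapping has to guarantee (a conceptual illustration only, not the actual peft/diffusers implementation; `pad_lora_to_rank` is a made-up helper): the swapped-in adapter must keep the same tensor shapes, which is why a target rank is declared up front and lower-rank adapters are zero-padded to it, and the scalings likewise have to live in tensors so that swapping changes tensor values rather than guarded Python scalars.

```python
import torch

def pad_lora_to_rank(lora_A: torch.Tensor, lora_B: torch.Tensor, target_rank: int):
    """Zero-pad a rank-r LoRA pair to `target_rank` so its shapes never change.

    lora_A has shape (r, in_features), lora_B has shape (out_features, r);
    the padded rows/columns are zero, so lora_B @ lora_A stays the same.
    """
    r = lora_A.shape[0]
    if r > target_rank:
        raise ValueError("adapter rank exceeds the declared target rank")
    A = lora_A.new_zeros(target_rank, lora_A.shape[1])
    B = lora_B.new_zeros(lora_B.shape[0], target_rank)
    A[:r] = lora_A
    B[:, :r] = lora_B
    return A, B
```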
Thank you❤️, and I have another question. The scale is passed directly through the set_adapters function, where the input scale is a floating-point number. What if I pass a torch.tensor instead? Will it work?

Honestly, I'm not sure, did you give it a try?

It's fine on my side. If a torch.tensor is passed, recompilation isn't triggered.
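For reference, here is roughly what that looks like, continuing the reproduction script above (a hedged sketch; it assumes set_adapters passes the weight through to the LoRA scaling unchanged, which is what this thread suggests):

```python
import torch

# Pass the adapter weight as a 0-dim tensor instead of a Python float, so a
# later weight change mutates a tensor value rather than a guarded scalar.
pipe.set_adapters(["lora1"], adapter_weights=[torch.tensor(1.0)])

# ... generate, hotswap the second LoRA as above, then rescale without recompiling:
pipe.set_adapters(["lora1"], adapter_weights=[torch.tensor(0.5)])
```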
Describe the bug
Thank you for making LoRA hotswap compatible with torch.compile. However, when I try to modify weights via set_adapters while hotswapping a LoRA, the model type after torch.compile is no longer ModelMixin, which causes the error. Could you please help look into this issue? Thanks.
Reproduction
```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')

# before load_lora
pipe.enable_lora_hotswap(target_rank=256)
pipe.load_lora_weights("/lora_path", weight_name="pytorch_lora_weights_1.safetensors", adapter_name="lora1")
pipe.set_adapters(["lora1"], adapter_weights=[1.0])

# use torch.compile
pipe.transformer = torch.compile(pipe.transformer)

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux-dev.png")

pipe.load_lora_weights("/lora_path", weight_name="pytorch_lora_weights_2.safetensors", adapter_name="lora1")
pipe.set_adapters(["lora1"], adapter_weights=[0.5])
```
Logs
System Info
ᐅ diffusers-cli env
Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.
Who can help?
@sayakpaul