
set_adapters not supported for compiled model #11408


Open
songh11 opened this issue Apr 24, 2025 · 7 comments
Labels
bug Something isn't working

Comments

@songh11

songh11 commented Apr 24, 2025

Describe the bug

Thank you for making LoRA hotswap compatible with torch.compile. However, when I try to modify the adapter weights via set_adapters after hotswapping a LoRA, the compiled model's type is no longer ModelMixin, which causes the error below. Could you please look into this issue? Thanks.

Reproduction

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')

# before load_lora

pipe.enable_lora_hotswap(target_rank=256)

pipe.load_lora_weights("/lora_path", weight_name="pytorch_lora_weights_1.safetensors", adapter_name="lora1")
pipe.set_adapters(["lora1"], adapter_weights=[1.0])

# use torch.compile

pipe.transformer = torch.compile(pipe.transformer)

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-dev.png")

pipe.load_lora_weights("/lora_path", weight_name="pytorch_lora_weights_2.safetensors", adapter_name="lora1")
pipe.set_adapters(["lora1"], adapter_weights=[0.5])

Logs

Traceback (most recent call last):
  File "/root/workspace/aigoodsfix/plugins/aigoodsfix/api.py", line 616, in <module>
    Init("")
  File "/root/workspace/aigoodsfix/plugins/aigoodsfix/api.py", line 136, in Init
    pipe.set_adapters(["lora1"], adapter_weights=[0.5])
  File "/root/workspace/diffusers/src/diffusers/loaders/lora_base.py", line 703, in set_adapters
    raise ValueError(
ValueError: Adapter name(s) {'lora1'} not in the list of present adapters: set().

System Info

ᐅ diffusers-cli env


  • 🤗 Diffusers version: 0.34.0.dev0
  • Platform: Linux-5.4.250-2-velinux1u3-amd64-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.13
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.30.2
  • Transformers version: 4.47.0
  • Accelerate version: 1.2.0
  • PEFT version: 0.15.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA L20, 49140 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul

songh11 added the bug (Something isn't working) label on Apr 24, 2025
@songh11
Author

songh11 commented Apr 24, 2025

Hi, regarding this issue, I tried debugging locally and found that modifying the code in both diffusers and peft can resolve the problem. Here are the specific code changes:

index 280a9fa6e..1111a56ab 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -735,6 +735,8 @@ class LoraBaseMixin:
 
             if issubclass(model.__class__, ModelMixin):
                 model.set_adapters(adapter_names, _component_adapter_weights[component])
+            elif issubclass(model.__class__, torch._dynamo.eval_frame.OptimizedModule):
+                model.set_adapters(adapter_names, _component_adapter_weights[component])
             elif issubclass(model.__class__, PreTrainedModel):
                 set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
 
@@ -832,7 +834,7 @@ class LoraBaseMixin:
             model = getattr(self, component, None)
             if (
                 model is not None
-                and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
+                and issubclass(model.__class__, (ModelMixin, PreTrainedModel, torch._dynamo.eval_frame.OptimizedModule))
                 and hasattr(model, "peft_config")
             ):
                 set_adapters[component] = list(model.peft_config.keys())
index 4429b0b..897ff14 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -451,7 +451,7 @@ class LoraLayer(BaseTunerLayer):
         if adapter not in self.scaling:
             # Ignore the case where the adapter is not in the layer
             return
-        self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter]
+        self.scaling[adapter] = torch.tensor(scale * self.lora_alpha[adapter] / self.r[adapter])
 
     def scale_layer(self, scale: float) -> None:
         if scale == 1:

I believe the issue stems from torch.compile converting the model type into torch._dynamo.eval_frame.OptimizedModule, which prevents modification of the LoRA scale. Additionally, the scale needs to be converted to a torch.tensor; otherwise, adjusting the LoRA scale would trigger recompilation. However, I'm uncertain whether these modifications might introduce side effects elsewhere. Could you please review this approach or suggest alternative solutions? Thank you.
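
For reference, a minimal sketch (plain PyTorch, not diffusers code) of why class-based checks like the issubclass test against ModelMixin stop matching after compilation: torch.compile returns a torch._dynamo.eval_frame.OptimizedModule wrapper, and the original module is only reachable through its _orig_mod attribute.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

print(type(compiled))                    # torch._dynamo.eval_frame.OptimizedModule
print(isinstance(compiled, nn.Linear))   # False, so class checks on the wrapper fail
print(type(compiled._orig_mod))          # nn.Linear, the original wrapped module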

@sayakpaul
Member

Thank you so much for initiating this conversation! Your changes look quite reasonable to me. Cc: @BenjaminBossan here.

I think we can try adding a test for your use case here and see if there's any compilation/re-compilation issues?
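
One way such a test could check for recompilation (a sketch only, assuming the pipe and prompt from the reproduction above; this is not the actual diffusers test suite) is to set torch._dynamo.config.error_on_recompile, which makes any recompilation raise instead of happening silently:

import torch

# Sketch of a recompilation check; assumes `pipe` and `prompt` are set up as in
# the reproduction above. With this flag, torch._dynamo raises on recompilation.
torch._dynamo.config.error_on_recompile = True

_ = pipe(prompt, num_inference_steps=2).images[0]   # first call compiles the transformer

pipe.load_lora_weights(
    "/lora_path", weight_name="pytorch_lora_weights_2.safetensors",
    adapter_name="lora1", hotswap=True,
)
pipe.set_adapters(["lora1"], adapter_weights=[0.5])

_ = pipe(prompt, num_inference_steps=2).images[0]   # should run without recompiling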

@songh11
Author

songh11 commented Apr 24, 2025

Thank you so much for initiating this conversation! Your changes look quite reasonable to me. Cc: @BenjaminBossan here.

I think we can try adding a test for your use case here and see if there's any compilation/re-compilation issues?

Thanks for your reply. The code formatting was messed up in the previous version, so I'm re-uploading a clean copy and looking forward to future updates.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')

# before load_lora
pipe.enable_lora_hotswap(target_rank=256) # According to your lora

pipe.load_lora_weights("/lora_path", weight_name="pytorch_lora_weights_1.safetensors", adapter_name="lora1")
pipe.set_adapters(["lora1"], adapter_weights=[1.0])

# use torch.compile
pipe.transformer = torch.compile(pipe.transformer)

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
    ).images[0]
image.save("flux-dev.png")

pipe.load_lora_weights("/lora_path", weight_name="pytorch_lora_weights_2.safetensors", adapter_name="lora1", hotswap=True)
pipe.set_adapters(["lora1"], adapter_weights=[0.5])

image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
    ).images[0]
image.save("flux-dev_2.png")

@BenjaminBossan
Member

BenjaminBossan commented Apr 24, 2025

Thanks for reporting and investigating this @songh11. Regarding the diff for LoRA you proposed above: We had considered this in the past but can't make this change that easily. For instance, a tensor has a device, while normal floats don't, so converting the scale to a tensor would require numerous other changes to ensure that there are no device errors.

The solution we came up with is the prepare_model_for_compiled_hotswap function in PEFT, which is triggered by enable_lora_hotswap in diffusers. This function takes care of converting the scales to tensors, and as it's opt-in, we don't run into the trouble of breaking existing code. I'm not quite sure yet why the scalings are still floats in your case, but I think a similar approach to the one just explained should work.
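
If it helps with debugging, a quick way to check whether that conversion happened on a given model is to look at the scaling dict on PEFT's LoRA layers (a sketch only; the attribute name follows PEFT's LoraLayer and could change between versions):

import torch

def lora_scalings_are_tensors(model):
    """Return True if every LoRA scaling found on the model is a torch.Tensor."""
    found_any = False
    for module in model.modules():
        scaling = getattr(module, "scaling", None)
        if isinstance(scaling, dict):
            for value in scaling.values():
                found_any = True
                if not isinstance(value, torch.Tensor):
                    return False
    return found_any

# Expected True after pipe.enable_lora_hotswap(...) followed by load_lora_weights(...)
print(lora_scalings_are_tensors(pipe.transformer))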

@songh11
Author

songh11 commented Apr 25, 2025

Thanks for reporting and investigating this @songh11. Regarding the diff for LoRA you proposed above: We had considered this in the past but can't make this change that easily. For instance, a tensor has a device, while normal floats don't, so converting the scale to a tensor would require numerous other changes to ensure that there are no device errors.

The solution we came up with is the prepare_model_for_compiled_hotswap function in PEFT, which is triggered by enable_lora_hotswap in diffusers. This function takes care of converting the scales to tensors, and as it's opt-in, we don't run into the trouble of breaking existing code. I'm not quite sure yet why the scalings are still floats in your case, but I think a similar approach to the one just explained should work.

Thank you❤️, and I have another question. The scale is directly passed through the set_adapters function, where the input scale is a floating-point number. What if I pass a torch.tensor instead? Will it work?

@BenjaminBossan
Member

The scale is directly passed through the set_adapters function, where the input scale is a floating-point number. What if I pass a torch.tensor instead? Will it work?

Honestly, I'm not sure, did you give it a try?

@songh11
Author

songh11 commented Apr 29, 2025

The scale is directly passed through the set_adapters function, where the input scale is a floating-point number. What if I pass a torch.tensor instead? Will it work?

Honestly, I'm not sure, did you give it a try?

It works fine on my side. If a torch.tensor is passed, recompilation isn't triggered.
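
For completeness, the variant being described would look like this (a sketch of the workaround, not documented API; whether set_adapters keeps accepting tensor weights may depend on the diffusers/PEFT versions in use):

# Pass the adapter weight as a 0-dim tensor so that later scale changes update a
# tensor value instead of a Python float, which is what avoids the recompilation.
pipe.set_adapters(["lora1"], adapter_weights=[torch.tensor(0.5)])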
