[performance] investigating FluxPipeline for recompilations on resolution changes #11360


Open
sayakpaul opened this issue Apr 18, 2025 · 1 comment
Labels: performance, torch.compile

sayakpaul (Member) commented on Apr 18, 2025

Similar to #11297, I was investigating potential recompilations for Flux on resolution changes.
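For reference, the three resolutions below map to three different packed sequence lengths for the transformer (the same arithmetic as in prepare_latents / _pack_latents further down), so every call sees a new token count. A quick sketch of that math, spelled out:

# Packed sequence length per resolution (vae_scale_factor = 8, then 2x2 patch packing).
# Same arithmetic as prepare_latents below, just printed for reference.
for h, w in [(1024, 1024), (1536, 768), (2048, 2048)]:
    lh, lw = 2 * (h // 16), 2 * (w // 16)
    print(f"{h}x{w} -> {(lh // 2) * (lw // 2)} tokens")  # 4096, 4608, 16384

The script compiles with dynamic=True and asserts (via error_on_recompile) that switching between these shapes does not trigger a recompile.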

Code
from diffusers import FluxTransformer2DModel, FluxPipeline
from diffusers.utils.torch_utils import randn_tensor
import torch.utils.benchmark as benchmark
from contextlib import nullcontext
import argparse

import torch 
torch.fx.experimental._config.use_duck_shape = False  # give each dynamic dim its own symbol instead of reusing one for dims that happen to be equal

HEIGHT_WIDTH = [(1024, 1024), (1536, 768), (2048, 2048)]

def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=1,
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def prepare_latents(
    batch_size=1,
    num_channels_latents=16,
    height=1024,
    width=1024,
    dtype=torch.bfloat16,
    device="cuda",
):
    vae_scale_factor = 8
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))

    shape = (batch_size, num_channels_latents, height, width)

    latents = randn_tensor(shape, device=device, dtype=dtype)
    latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)

    latent_image_ids = FluxPipeline._prepare_latent_image_ids(
        batch_size, height // 2, width // 2, device, dtype
    )

    return latents, latent_image_ids

def get_conditional_inputs(batch_size, dtype=torch.bfloat16, device="cuda"):
    prompt_embeds = torch.randn(batch_size, 512, 4096, dtype=dtype, device=device)
    pooled_prompt_embeds = torch.randn(batch_size, 768, dtype=dtype, device=device)
    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
    return prompt_embeds, pooled_prompt_embeds, text_ids

def load_transformer(do_compile=False):
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    ).to("cuda")
    if do_compile:
        transformer = torch.compile(transformer, fullgraph=True, dynamic=True)
    return transformer

def run_inference(transformer, **kwargs):
    _ = transformer(**kwargs)

@torch.no_grad()
def main(transformer, batch_size, height, width):
    latents, latent_image_ids = prepare_latents(batch_size=batch_size, height=height, width=width)
    prompt_embeds, pooled_prompt_embeds, text_ids = get_conditional_inputs(batch_size=batch_size)
    
    timestep = torch.full([1], 1.0, device="cuda", dtype=torch.float32)
    timestep = timestep.expand(latents.shape[0]).to(latents.dtype)
    timestep = timestep / 1000
    guidance = torch.full([1], 4.5, device="cuda", dtype=torch.float32)
    guidance = guidance.expand(latents.shape[0])

    input_dict = {
        "hidden_states": latents,
        "timestep": timestep,
        "guidance": guidance,
        "pooled_projections": pooled_prompt_embeds,
        "encoder_hidden_states": prompt_embeds,
        "txt_ids": text_ids,
        "img_ids": latent_image_ids
    }

    run_inference(transformer, **input_dict)
    # time = benchmark_fn(run_inference, transformer, **input_dict)
    # print(f"{height}x{width}: {time} secs")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--compile", action="store_true")
    args = parser.parse_args()

    transformer = load_transformer(args.compile)
    context = torch._dynamo.config.patch(error_on_recompile=True) if args.compile else nullcontext()
    with context:
        for height, width in HEIGHT_WIDTH:
            main(transformer=transformer, batch_size=args.batch_size, height=height, width=width)
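
(As an aside, and not required for the repro: recompilations and their reasons can also be surfaced as logs rather than hard errors, which I find handy when bisecting. Assuming a recent PyTorch, either of the following works.)

# Log recompilations and the guard that triggered them, instead of erroring out.
torch._logging.set_logs(recompiles=True)
# or from the shell:
#   TORCH_LOGS="recompiles" python check_flux_recompilation.py --compile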

The script currently fails when run with python check_flux_recompilation.py --compile:

Trace
Traceback (most recent call last):
  File "/fsx/sayak/diffusers/check_flux_recompilation.py", line 99, in <module>
    main(transformer=transformer, batch_size=args.batch_size, height=height, width=width)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/sayak/diffusers/check_flux_recompilation.py", line 82, in main
    run_inference(transformer, **input_dict)
  File "/fsx/sayak/diffusers/check_flux_recompilation.py", line 59, in run_inference
    _ = transformer(**kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 675, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1583, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1558, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/__init__.py", line 2365, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2199, in compile_fx
    return aot_autograd(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 106, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1175, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 850, in load
    compiled_fn = dispatch_and_compile()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1160, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 574, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 834, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 240, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 483, in __call__
    return self.compiler_fn(gm, example_inputs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1980, in fw_compiler_base
    _recursive_joint_graph_passes(gm)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 402, in _recursive_joint_graph_passes
    joint_graph_passes(gm)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 544, in joint_graph_passes
    GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/fx/passes/graph_transform_observer.py", line 85, in apply_graph_pass
    return pass_fn(self.gm.graph)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/fx_passes/post_grad.py", line 975, in remove_noop_ops
    if same_meta(node, src) and cond(*args, **kwargs):
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_inductor/fx_passes/post_grad.py", line 804, in same_meta
    and statically_known_true(sym_eq(val1.size(), val2.size()))
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'SymFloat' object has no attribute 'size'
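
My (unverified) read of the trace: same_meta in Inductor's remove_noop_ops pass calls .size() on both nodes' meta vals, which assumes they are fake tensors, but here one of them ends up being a SymFloat, which behaves like a plain Python float and has no tensor API. I haven't bisected which op produces that SymFloat; the snippet below only illustrates how such values arise under dynamic compilation and does not reproduce the crash:

import torch

@torch.compile(dynamic=True, fullgraph=True)
def f(x):
    scale = x.shape[0] / 1000.0  # SymInt / float -> SymFloat under dynamic shapes
    return x * scale

f(torch.randn(8))  # compiles fine on its own; the problem is such a value reaching the .size() check in remove_noop_ops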

My env:

- 🤗 Diffusers version: 0.34.0.dev0
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.8.0.dev20250417+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.2
- Transformers version: 4.52.0.dev0
- Accelerate version: 1.4.0.dev0
- PEFT version: 0.15.2.dev0
- Bitsandbytes version: 0.45.3
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

@StrongerXi, @anijain2305 would you have any pointers?

sayakpaul added the performance and torch.compile labels on Apr 18, 2025
@anijain2305

cc @StrongerXi, can you please do the initial debug? cc @eellison
