extend AutoModel to be able to load transformer models & custom diffusers models #11388

Open
yiyixuxu opened this issue Apr 22, 2025 · 3 comments · May be fixed by #11401


@yiyixuxu
Collaborator

The current AutoModel implementation only supports models that are importable from diffusers:

model_cls = getattr(library, orig_class_name, None)

It uses the _class_name field in the model's config.json: https://huggingface.co/HiDream-ai/HiDream-I1-Full/blob/main/transformer/config.json#L2

We can use the info in model_index.json instead: https://huggingface.co/HiDream-ai/HiDream-I1-Full/blob/main/model_index.json
This way we can load anything that can be loaded via from_pretrained, including transformers models & custom diffusers modules (not directly importable from the top level).
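
Roughly, the two files carry this information (shapes illustrative; the names below are made up, not the actual HiDream entries). Only model_index.json records which library each component comes from:

# per-component config.json: names only the class, no library
config_json = {"_class_name": "SomeTransformer2DModel"}

# repo-root model_index.json: maps each component to a (library, class) pair
model_index = {
    "_class_name": "SomeDiffusionPipeline",
    "transformer": ["diffusers", "SomeTransformer2DModel"],
    "text_encoder": ["transformers", "SomeTextEncoderModel"],
}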

You can reference the code/logic in from_pretrained:

init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
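
For context, a rough sketch of what that filtering step does (approximate, not the exact diffusers source): init_dict holds the model_index.json entries, and load_module drops entries without a usable (library, class) pair:

import importlib

# sketch only; entry names are illustrative
init_dict = {
    "transformer": ["diffusers", "SomeTransformer2DModel"],
    "text_encoder": ["transformers", "SomeTextEncoderModel"],
    "scheduler": [None, None],  # components can be null in model_index.json
}

def load_module(name, value):
    # skip entries whose library/class pair is missing
    return isinstance(value, (list, tuple)) and value[0] is not None

init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}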

I will work on this shortly, but feel free to pick this up if anyone in the community is interested!

@afafelwafi

Hi @yiyixuxu, I’d love to contribute to this enhancement if you haven’t started already!

Just to clarify before I dive in:

The current AutoModel implementation only supports models that are importable directly from the diffusers top-level namespace (e.g., via _class_name in config.json).

For models like HiDream-I1-Full, which define components only in model_index.json, this approach breaks because the model class isn't found via getattr(diffusers, class_name).
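
For example (class name chosen just for illustration), the lookup that comes back empty today is essentially:

import importlib

library = importlib.import_module("diffusers")
# CLIPTextModel lives in transformers, so it is not exported from the
# diffusers top level and the lookup returns None
model_cls = getattr(library, "CLIPTextModel", None)
print(model_cls)  # None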

Here's what I’m proposing at the AutoModel.from_pretrained level:

import importlib  # module-level import in the real file

# (sketch of the body of AutoModel.from_pretrained)
try:
    config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
except EnvironmentError as e:
    # fall back to the repo-root model_index.json
    try:
        from ..configuration_utils import ConfigMixin

        # minimal ConfigMixin subclass whose config file is model_index.json
        backup_cls = type(
            "AutoModelIndexFallback",
            (ConfigMixin,),
            {"config_name": "model_index.json"},
        )
        config = backup_cls.load_config(
            pretrained_model_or_path, **load_config_kwargs
        )
    except Exception as fallback_error:
        raise EnvironmentError(
            f"Failed to load both config.json and model_index.json for '{pretrained_model_or_path}'.\n"
            f"Original error: {e}\nFallback error: {fallback_error}"
        )

orig_class_name = config.get("_class_name", None)
model_cls = None

# current behavior: look the class up on the diffusers top level
if orig_class_name:
    try:
        library = importlib.import_module("diffusers")
        model_cls = getattr(library, orig_class_name, None)
    except Exception:
        model_cls = None

# fallback: resolve a (library, class) pair from the loaded config via
# _fetch_class_library_tuple (defined in diffusers' pipeline utilities)
if model_cls is None:
    try:
        fetched = _fetch_class_library_tuple(config)
        if fetched:
            library_name, class_name = fetched
            imported_lib = importlib.import_module(library_name)
            model_cls = getattr(imported_lib, class_name, None)
    except Exception as e:
        raise ImportError(
            f"AutoModel fallback failed to import model class: {e}"
        )

if model_cls is None:
    raise ValueError(
        "AutoModel could not resolve the model class. Check `_class_name` or model_index.json."
    )

This would allow AutoModel to load both standard and custom model components defined in model_index.json, even if they’re not importable from diffusers directly.

Let me know if you think this is the right approach or if you had a different level in mind (e.g., deeper integration into the loader logic in pipeline_utils.py)!

@yiyixuxu
Collaborator Author

hi @afafelwafi
thanks for your interest in this issue! your overall proposal sounds good to me: try the current approach first, and fall back to an alternative that uses model_index.json if it does not work

a few references that might be helpful for using model_index.json:

  1. download only the model index (see some reference code here https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/single_file.py#L242)
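     For instance, a minimal sketch (the real implementation should also forward cache_dir, proxies, etc., the way the single-file loader does):

        from huggingface_hub import hf_hub_download

        # fetch only the index file; no model weights are downloaded
        index_path = hf_hub_download("HiDream-ai/HiDream-I1-Full", filename="model_index.json")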
  2. import the model class (see get_class_obj_and_candidates); we don't have to handle custom code here just yet, so we can maybe simplify a bit. The main logic is:

        from diffusers import pipelines

        is_pipeline_module = hasattr(pipelines, library_name)
        if is_pipeline_module:
            pipeline_module = getattr(pipelines, library_name)
            class_obj = getattr(pipeline_module, class_name)
        else:
            library = importlib.import_module(library_name)
            class_obj = getattr(library, class_name)
  3. testing script for the use cases we want to support; currently we fail on tests 2 and 3:
from diffusers import AutoModel
import torch


# test1: load a diffusers model
try:
    model = AutoModel.from_pretrained(
        "HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16,
    )
    print("test1 passed!")
except Exception as e:
    print(f"test1 failed: {e}")


# test2: load a non-diffusers model
try:
    model = AutoModel.from_pretrained(
        "HiDream-ai/HiDream-I1-Full", subfolder="text_encoder", torch_dtype=torch.bfloat16,
    )
    print("test2 passed!")
except Exception as e:
    print(f"test2 failed: {e}")

# test3: load a custom diffusers model
# https://huggingface.co/Kwai-Kolors/Kolors-diffusers/blob/main/model_index.json#L9
try:
    model = AutoModel.from_pretrained("Kwai-Kolors/Kolors-diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16)
    print("test3 passed!")
except Exception as e:
    print(f"test3 failed: {e}")

# test4: load a model directly (no subfolder)
controlnet_repo = "InstantX/SD3-Controlnet-Canny"
try:
    controlnet_model = AutoModel.from_pretrained(
        controlnet_repo, revision="refs/pr/12"
    )
    print("test4 passed!")
except Exception as e:
    print(f"test4 failed: {e}")

@yiyixuxu
Collaborator Author

ohh @afafelwafi sorry,
I just saw that a PR is already in

thanks for your interest though! we will keep posting issues like this
