diff --git a/src/diffusers/modular_pipelines/ernie_image/encoders.py b/src/diffusers/modular_pipelines/ernie_image/encoders.py index 24e9622c9422..161646d181be 100644 --- a/src/diffusers/modular_pipelines/ernie_image/encoders.py +++ b/src/diffusers/modular_pipelines/ernie_image/encoders.py @@ -15,16 +15,23 @@ import json import torch -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer, Mistral3Model from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance from ...utils import logging +from ...utils.import_utils import is_transformers_version from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import ErnieImageModularPipeline +if is_transformers_version("<", "5.0.0"): + raise ImportError("`ErnieImageModularPipeline` requires `transformers>=5.0.0` for `Ministral3ForCausalLM`.") + +from transformers import Ministral3ForCausalLM # noqa: E402 + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -38,7 +45,7 @@ def description(self) -> str: @property def expected_components(self) -> list[ComponentSpec]: return [ - ComponentSpec("pe", AutoModelForCausalLM), + ComponentSpec("pe", Ministral3ForCausalLM), ComponentSpec("pe_tokenizer", AutoTokenizer), ] @@ -83,7 +90,7 @@ def intermediate_outputs(self) -> list[OutputParam]: @staticmethod def _enhance_prompt( - pe: AutoModelForCausalLM, + pe: Ministral3ForCausalLM, pe_tokenizer: AutoTokenizer, prompt: str, device: torch.device, @@ -160,7 +167,7 @@ def description(self) -> str: @property def expected_components(self) -> list[ComponentSpec]: return [ - ComponentSpec("text_encoder", AutoModel), + ComponentSpec("text_encoder", Mistral3Model), ComponentSpec("tokenizer", AutoTokenizer), ComponentSpec( "guider", @@ -200,7 +207,7 @@ def intermediate_outputs(self) -> list[OutputParam]: @staticmethod def _encode( - text_encoder: AutoModel, + text_encoder: Mistral3Model, tokenizer: AutoTokenizer, prompt: list[str], device: torch.device, diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index e0231c4620c5..11fce6a204bf 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -20,7 +20,7 @@ from typing import Callable, List, Optional, Union import torch -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer, Mistral3Model from ...image_processor import VaeImageProcessor from ...loaders import ErnieImageLoraLoaderMixin @@ -28,10 +28,17 @@ from ...models.transformers import ErnieImageTransformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils.import_utils import is_transformers_version from ...utils.torch_utils import randn_tensor from .pipeline_output import ErnieImagePipelineOutput +if is_transformers_version("<", "5.0.0"): + raise ImportError("`ErnieImagePipeline` requires `transformers>=5.0.0` for `Ministral3ForCausalLM`.") + +from transformers import Ministral3ForCausalLM # noqa: E402 + + class ErnieImagePipeline(DiffusionPipeline, ErnieImageLoraLoaderMixin): """ Pipeline for text-to-image generation using ErnieImageTransformer2DModel. @@ -52,10 +59,10 @@ def __init__( self, transformer: ErnieImageTransformer2DModel, vae: AutoencoderKLFlux2, - text_encoder: AutoModel, + text_encoder: Mistral3Model, tokenizer: AutoTokenizer, scheduler: FlowMatchEulerDiscreteScheduler, - pe: Optional[AutoModelForCausalLM] = None, + pe: Optional[Ministral3ForCausalLM] = None, pe_tokenizer: Optional[AutoTokenizer] = None, ): super().__init__()