11from typing import List
22
3+ import torch
4+
35from diffusers import FluxTransformer2DModel
4- from diffusers .modular_pipelines import ComponentSpec , InputParam , ModularPipelineBlocks , OutputParam , PipelineState
6+ from diffusers .modular_pipelines import (
7+ ComponentSpec ,
8+ InputParam ,
9+ ModularPipelineBlocks ,
10+ OutputParam ,
11+ PipelineState ,
12+ WanModularPipeline ,
13+ )
14+
15+ from ..testing_utils import nightly , require_torch , slow
516
617
718class DummyCustomBlockSimple (ModularPipelineBlocks ):
@@ -81,10 +92,7 @@ def test_custom_block_supported_components(self):
8192
8293 def test_custom_block_loads_from_hub (self ):
8394 repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
84- block = ModularPipelineBlocks .from_pretrained (
85- repo_id ,
86- trust_remote_code = True ,
87- )
95+ block = ModularPipelineBlocks .from_pretrained (repo_id , trust_remote_code = True )
8896 self ._test_block_properties (block )
8997
9098 pipe = block .init_pipeline ()
@@ -93,3 +101,19 @@ def test_custom_block_loads_from_hub(self):
93101 output = pipe (prompt = prompt )
94102 output_prompt = output .values ["output_prompt" ]
95103 assert output_prompt .startswith ("Modular diffusers + " )
104+
105+
106+ @slow
107+ @nightly
108+ @require_torch
109+ class TestModularCustomBlocksIntegration :
110+ def test_krea_realtime_video_loading (self ):
111+ repo_id = "krea/krea-realtime-video"
112+ blocks = ModularPipelineBlocks .from_pretrained (repo_id , trust_remote_code = True )
113+
114+ pipe = WanModularPipeline (blocks , repo_id )
115+ pipe .load_components (
116+ trust_remote_code = True ,
117+ device_map = "cuda" ,
118+ torch_dtype = {"default" : torch .bfloat16 , "vae" : torch .float16 },
119+ )
0 commit comments