Skip to content

Commit 47baa47

Browse files
committed
move node registry to mellon
1 parent 55d49d4 commit 47baa47

3 files changed

Lines changed: 10 additions & 258 deletions

File tree

src/diffusers/modular_pipelines/mellon_node_utils.py

Lines changed: 10 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,16 @@
217217
"display": "output",
218218
"type": "controlnet",
219219
},
220+
"doc": {
221+
"label": "Doc",
222+
"display": "output",
223+
"type": "string",
224+
},
225+
"latents_preview": {
226+
"label": "Latents Preview",
227+
"display": "output",
228+
"type": "latent",
229+
},
220230
}
221231

222232

@@ -697,67 +707,3 @@ def from_blocks(cls, blocks: ModularPipelineBlocks, node_type: str) -> "MellonNo
697707
blocks_names=blocks_names,
698708
node_type=node_type,
699709
)
700-
701-
702-
# Minimal modular registry for Mellon node configs
703-
class ModularMellonNodeRegistry:
704-
"""Registry mapping (pipeline class, blocks_name) -> list of MellonNodeConfig."""
705-
706-
def __init__(self):
707-
self._registry = {}
708-
self._initialized = False
709-
710-
def register(self, pipeline_cls: type, node_params: Dict[str, MellonNodeConfig]):
711-
if not self._initialized:
712-
_initialize_registry(self)
713-
self._registry[pipeline_cls] = node_params
714-
715-
def get(self, pipeline_cls: type) -> MellonNodeConfig:
716-
if not self._initialized:
717-
_initialize_registry(self)
718-
return self._registry.get(pipeline_cls, None)
719-
720-
def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]:
721-
if not self._initialized:
722-
_initialize_registry(self)
723-
return self._registry
724-
725-
726-
def _register_preset_node_types(
727-
pipeline_cls, params_map: Dict[str, Dict[str, Any]], registry: ModularMellonNodeRegistry
728-
):
729-
"""Register all node-type presets for a given pipeline class from a params map."""
730-
node_configs = {}
731-
for node_type, spec in params_map.items():
732-
node_config = MellonNodeConfig(
733-
inputs=spec.get("inputs", []),
734-
model_inputs=spec.get("model_inputs", []),
735-
outputs=spec.get("outputs", []),
736-
blocks_names=spec.get("block_names", []),
737-
node_type=node_type,
738-
)
739-
node_configs[node_type] = node_config
740-
registry.register(pipeline_cls, node_configs)
741-
742-
743-
def _initialize_registry(registry: ModularMellonNodeRegistry):
744-
"""Initialize the registry and register all available pipeline configs."""
745-
print("Initializing registry")
746-
747-
registry._initialized = True
748-
749-
try:
750-
from .qwenimage.modular_pipeline import QwenImageModularPipeline
751-
from .qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP
752-
753-
_register_preset_node_types(QwenImageModularPipeline, QwenImage_NODE_TYPES_PARAMS_MAP, registry)
754-
except Exception:
755-
raise Exception("Failed to register QwenImageModularPipeline")
756-
757-
try:
758-
from .stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline
759-
from .stable_diffusion_xl.node_utils import SDXL_NODE_TYPES_PARAMS_MAP
760-
761-
_register_preset_node_types(StableDiffusionXLModularPipeline, SDXL_NODE_TYPES_PARAMS_MAP, registry)
762-
except Exception:
763-
raise Exception("Failed to register StableDiffusionXLModularPipeline")

src/diffusers/modular_pipelines/qwenimage/node_utils.py

Lines changed: 0 additions & 95 deletions
This file was deleted.

src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

0 commit comments

Comments
 (0)