forked from NVIDIA-NeMo/Curator
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
115 lines (93 loc) · 4.06 KB
/
main.py
File metadata and controls
115 lines (93 loc) · 4.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ALM (Audio Language Model) Data Pipeline for NeMo Curator.
This script processes audio manifests to create training windows for
Audio Language Models using YAML-based configuration with Hydra.
The pipeline starts with ALMManifestReader (CompositeStage that decomposes into
FilePartitioningStage + ALMManifestReaderStage for line-by-line JSONL reading
on workers) followed by configurable processing stages.
Features:
- YAML-driven pipeline configuration using nemo_curator.config.run
- Command-line parameter overrides
- Extensible stage chain
- Manifest I/O on workers via EmptyTask pattern
- Selectable execution backend via ``backend=xenna`` (default) or ``backend=ray_data``
Usage:
# Run with sample data (from Curator repo root)
python tutorials/audio/alm/main.py \\
--config-path . \\
--config-name pipeline \\
manifest_path=tests/fixtures/audio/alm/sample_input.jsonl
# Override parameters
python tutorials/audio/alm/main.py \\
--config-path . \\
--config-name pipeline \\
manifest_path=/data/input.jsonl \\
output_dir=./my_output \\
stages.1.min_speakers=3 \\
stages.2.overlap_percentage=30
"""
import importlib
import hydra
from loguru import logger
from omegaconf import DictConfig
from nemo_curator.config.run import create_pipeline_from_yaml
from nemo_curator.tasks.utils import TaskPerfUtils
# Registry of supported ``backend`` config values. Each entry is written as
# "<module path>:<class name>" so the executor class is only imported once the
# backend has actually been chosen.
_EXECUTOR_FACTORIES = {
    "xenna": "nemo_curator.backends.xenna:XennaExecutor",
    "ray_data": "nemo_curator.backends.ray_data:RayDataExecutor",
}


def _create_executor(backend: str) -> object:
    """Import and instantiate the executor class registered for ``backend``.

    Args:
        backend: A key of ``_EXECUTOR_FACTORIES``.

    Returns:
        A new instance of the registered executor class, constructed with no
        arguments.

    Raises:
        KeyError: If ``backend`` is not registered in ``_EXECUTOR_FACTORIES``.
    """
    target = _EXECUTOR_FACTORIES[backend]
    # Split on the last ":" only: everything before it is the dotted module path.
    module_name, attr_name = target.rsplit(":", maxsplit=1)
    # Imported on demand, so only the selected backend's module is loaded here.
    executor_cls = getattr(importlib.import_module(module_name), attr_name)
    return executor_cls()
# Optional per-stage custom counters: (metric key, log label, value template).
# Not every stage reports every counter, so each key is checked before use.
_OPTIONAL_STAGE_METRICS = (
    ("custom.windows_created", "windows_created", "{:.0f}"),
    ("custom.output_windows", "output_windows (after overlap)", "{:.0f}"),
    ("custom.filtered_dur", "filtered_audio_duration", "{:.1f}s"),
)


def _log_output_files(results: list) -> None:
    """Log the completion banner and the sorted, de-duplicated output paths.

    Args:
        results: Tasks returned by ``pipeline.run``. Each task's ``data`` is
            iterated and its items are reported as output file paths.
            NOTE(review): assumes ``data`` is a list of path strings; confirm
            against the final stage's output task type.
    """
    unique_files = sorted({fp for task in results for fp in task.data})
    logger.info("\n" + "=" * 50)
    logger.info("PIPELINE COMPLETE")
    logger.info("=" * 50)
    logger.info(f" Output files written: {len(unique_files)}")
    for fp in unique_files:
        logger.info(f" - {fp}")


def _log_stage_metrics(results: list) -> None:
    """Log per-stage timing, processed-item counts, and any ALM custom counters.

    Each metric value is reduced with ``.mean()`` / ``.sum()``. Presumably these
    are NumPy arrays produced by ``TaskPerfUtils.collect_stage_metrics``
    (TODO confirm).

    Args:
        results: Tasks returned by ``pipeline.run`` (never ``None`` here).
    """
    stage_metrics = TaskPerfUtils.collect_stage_metrics(results)
    for stage_name, metrics in stage_metrics.items():
        process_time = metrics["process_time"]
        logger.info(f" [{stage_name}]")
        logger.info(f" process_time: mean={process_time.mean():.4f}s, total={process_time.sum():.2f}s")
        logger.info(f" items_processed: {metrics['num_items_processed'].sum():.0f}")
        for key, label, template in _OPTIONAL_STAGE_METRICS:
            if key in metrics:
                value = template.format(metrics[key].sum())
                logger.info(f" {label}: {value}")


@hydra.main(version_base=None)
def main(cfg: DictConfig) -> None:
    """Run ALM pipeline using Hydra configuration.

    Steps:
    1. Build the pipeline from ``cfg`` via ``create_pipeline_from_yaml``.
    2. Run it on the executor selected by ``cfg.backend`` (default ``"xenna"``).
    3. Log the written output files and per-stage metrics.

    Args:
        cfg: Hydra-composed configuration. The config path/name and any
            ``key=value`` overrides come from the command line.

    Raises:
        ValueError: If ``cfg.backend`` is not a key of ``_EXECUTOR_FACTORIES``.
    """
    pipeline = create_pipeline_from_yaml(cfg)
    logger.info(pipeline.describe())
    logger.info("\n" + "=" * 50 + "\n")

    backend = cfg.get("backend", "xenna")
    if backend not in _EXECUTOR_FACTORIES:
        msg = f"Unknown backend '{backend}'. Choose from: {list(_EXECUTOR_FACTORIES)}"
        raise ValueError(msg)
    logger.info(f"Using backend: {backend}")
    executor = _create_executor(backend)

    logger.info("Starting pipeline execution...")
    # Treat a None/empty return as "no output tasks". Materialize the result
    # once so the file summary and the metrics pass see the same tasks, even if
    # run() hands back a one-shot iterable.
    results = list(pipeline.run(executor) or [])

    _log_output_files(results)
    _log_stage_metrics(results)
if __name__ == "__main__":
    # Script entry point: the Hydra-decorated main() builds `cfg` from the
    # --config-path / --config-name flags and key=value overrides shown in the
    # module docstring's usage examples.
    main()