|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import argparse |
| 4 | +import os |
| 5 | +import subprocess |
| 6 | +import sys |
| 7 | +from typing import Any, Dict |
| 8 | + |
| 9 | +import yaml |
| 10 | + |
| 11 | +from dataset.manifest_utils import generate_manifests |
| 12 | + |
| 13 | + |
def _parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the manifest-then-train launcher.

    Arguments are read from ``sys.argv``. One positional ``config`` path is
    required; ``--skip-manifests`` and ``--only-manifests`` are optional.

    The two flags are mutually exclusive: given together they would disable
    both the manifest step and the training step, so the program would exit
    successfully without doing any work. argparse now rejects that
    combination with a usage error (exit status 2).

    Returns:
        Namespace with ``config`` (str), ``skip_manifests`` (bool) and
        ``only_manifests`` (bool).
    """
    p = argparse.ArgumentParser(
        description="Generate dataset manifests (if needed) then run training."
    )
    p.add_argument(
        "config",
        type=str,
        help="Path to trainer config YAML (e.g., /configs/trainer.docker.yaml)",
    )
    # Group the step-selection flags so contradictory requests fail loudly.
    steps = p.add_mutually_exclusive_group()
    steps.add_argument(
        "--skip-manifests",
        action="store_true",
        help="Skip manifest generation step.",
    )
    steps.add_argument(
        "--only-manifests",
        action="store_true",
        help="Only generate manifests, then exit.",
    )
    return p.parse_args()
| 34 | + |
| 35 | + |
def _load_config(path: str) -> Dict[str, Any]:
    """Read the YAML config at *path* and return it as a dictionary.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        The parsed top-level mapping.

    Raises:
        FileNotFoundError: If *path* is not an existing regular file.
        ValueError: If the parsed document is not a dict.
    """
    if os.path.isfile(path):
        with open(path, "r", encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
        if isinstance(loaded, dict):
            return loaded
        # Anything other than a top-level mapping cannot be used as a config.
        raise ValueError(f"Config must be a mapping (YAML dict), got: {type(loaded)}")
    raise FileNotFoundError(f"Config file not found: {path}")
| 44 | + |
| 45 | + |
def _resolve_from_cwd(path: str) -> str:
    """Anchor a relative *path* at the current working directory.

    Absolute paths are returned exactly as given (they are not normalized).
    Relative paths are joined onto ``os.getcwd()`` and then normalized.
    """
    if os.path.isabs(path):
        return path
    anchored = os.path.join(os.getcwd(), path)
    return os.path.normpath(anchored)
| 48 | + |
| 49 | + |
def _generate_from_config(cfg: Dict[str, Any]) -> None:
    """Validate the ``data`` section of *cfg* and call ``generate_manifests``.

    The ``data`` section must be a mapping containing ``dataset_root``,
    ``manifest_dir``, ``train_ratio``, ``val_ratio``, ``test_ratio`` and
    ``seed``. The two path entries are resolved against the current working
    directory; the ratios are converted with ``float`` and the seed with
    ``int`` before being passed on.

    Args:
        cfg: Parsed trainer configuration.

    Raises:
        KeyError: If the ``data`` section or one of its required keys is
            missing.
        ValueError: If the ``data`` section is not a mapping (for example an
            empty ``data:`` entry, which YAML parses as ``None``).
    """
    if "data" not in cfg:
        raise KeyError("Config missing required 'data' section")

    data_cfg = cfg["data"]
    # Without this check an empty section fails with an opaque TypeError and
    # a string value would be substring-matched by the key checks below.
    if not isinstance(data_cfg, dict):
        raise ValueError(
            f"Config 'data' section must be a mapping (YAML dict), got: {type(data_cfg)}"
        )

    required = ["dataset_root", "manifest_dir", "train_ratio", "val_ratio", "test_ratio", "seed"]
    for key in required:
        if key not in data_cfg:
            raise KeyError(f"Config data.{key} is required")

    dataset_root = _resolve_from_cwd(str(data_cfg["dataset_root"]))
    manifest_dir = _resolve_from_cwd(str(data_cfg["manifest_dir"]))

    generate_manifests(
        dataset_root=dataset_root,
        manifest_dir=manifest_dir,
        train_ratio=float(data_cfg["train_ratio"]),
        val_ratio=float(data_cfg["val_ratio"]),
        test_ratio=float(data_cfg["test_ratio"]),
        seed=int(data_cfg["seed"]),
    )
| 71 | + |
| 72 | + |
def main() -> None:
    """Run the launcher: optional manifest generation, then training.

    The config is always loaded first. Manifests are generated unless
    ``--skip-manifests`` was given. Unless ``--only-manifests`` was given,
    ``train.py`` from this file's directory is then executed with the current
    interpreter and the config path as its only argument.

    Raises:
        subprocess.CalledProcessError: If the training process exits with a
            non-zero status (raised by ``subprocess.check_call``).
    """
    cli = _parse_args()
    config = _load_config(cli.config)

    wants_manifests = not cli.skip_manifests
    if wants_manifests:
        _generate_from_config(config)

    if not cli.only_manifests:
        # train.py is looked up next to this script, not in the cwd.
        script_dir = os.path.dirname(__file__)
        command = [sys.executable, os.path.join(script_dir, "train.py"), cli.config]
        subprocess.check_call(command)
| 85 | + |
| 86 | + |
if __name__ == "__main__":
    # Only run the CLI when executed as a script, never on import.
    main()
0 commit comments