Skip to content

Commit f5ed98c

Browse files
committed
Add trainer container
1 parent 09cd030 commit f5ed98c

8 files changed

Lines changed: 263 additions & 4 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ og
1313
/src/trainer/outputs
1414
/src/trainer/mlruns
1515

16+
# Docker Training
17+
outputs
18+
1619
# Temp uv setup
1720
.python-version
1821
pyproject.toml

compose.yaml

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,27 @@ services:
2626
command: python -m simulator.diffraction_generator --config /app/configs/simulator.yaml
2727
restart: "no"
2828

29+
trainer:
30+
build:
31+
context: .
32+
dockerfile: docker/trainer.Dockerfile
33+
volumes:
34+
- ./data/:/data/
35+
- ./configs:/configs:ro
36+
- ./outputs:/outputs
37+
- ./src/trainer:/app:ro
38+
environment:
39+
- PYTHONUNBUFFERED=1
40+
- MPLCONFIGDIR=/tmp/matplotlib
41+
- MKL_THREADING_LAYER=GNU
42+
- MKL_SERVICE_FORCE_INTEL=1
43+
- NVIDIA_VISIBLE_DEVICES=all
44+
- NVIDIA_DRIVER_CAPABILITIES=compute,utility
45+
runtime: nvidia
46+
user: "${UID}:${GID}"
47+
command: python /app/run_train_with_manifests.py /configs/trainer.docker.yaml
48+
restart: "no"
49+
2950
ui:
3051
build:
3152
context: .
@@ -38,4 +59,5 @@ services:
3859
environment:
3960
- PYTHONUNBUFFERED=1
4061
- PORT=7860
41-
restart: unless-stopped
62+
restart: unless-stopped
63+

configs/trainer.docker.yaml

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
data:
2+
manifest_dir: "/data/manifests"
3+
dataset_root: "/data/dataset"
4+
auto_generate_manifests: true
5+
train_ratio: 0.8
6+
val_ratio: 0.1
7+
test_ratio: 0.1
8+
seed: 42
9+
10+
loader:
11+
# --- DataLoader ---
12+
batch_size: 64
13+
num_workers: 8
14+
pin_memory: true
15+
persistent_workers: true
16+
prefetch_factor: 2
17+
train_file: "train.jsonl"
18+
val_file: "val.jsonl"
19+
test_file: "test.jsonl"
20+
21+
preprocessing:
22+
validate_paths: false
23+
extract_labels: true
24+
allow_pickle: true
25+
labels_key_map:
26+
x: "dp"
27+
cs: "cs"
28+
sg: "sg"
29+
lattice_params: null
30+
lp_a: "_cell_length_a"
31+
lp_b: "_cell_length_b"
32+
lp_c: "_cell_length_c"
33+
lp_alpha: "_cell_angle_alpha"
34+
lp_beta: "_cell_angle_beta"
35+
lp_gamma: "_cell_angle_gamma"
36+
dtype: "float32"
37+
mmap_mode: null
38+
floor_at_zero: true
39+
normalize_log1p: false
40+
shift_labels: true
41+
42+
augmentation:
43+
noise_poisson_range: [1.0, 100.0]
44+
noise_gaussian_range: [0.001, 0.1]
45+
standardize_to: [0.0, 100.0]
46+
47+
model:
48+
type: "multiscale"
49+
50+
backbone:
51+
dim_in: 8192
52+
dims: [80, 80, 80]
53+
kernel_sizes: [100, 50, 25]
54+
strides: [5, 5, 5]
55+
dropout_rate: 0.3
56+
layer_scale_init_value: 0.0
57+
drop_path_rate: 0.3
58+
ramped_dropout_rate: false
59+
block_type: "convnext"
60+
pooling_type: "average"
61+
final_pool: true
62+
use_batchnorm: false
63+
activation: "leaky_relu"
64+
output_type: "flatten"
65+
66+
heads:
67+
head_dropout: 0.5
68+
cs_hidden: [2300, 1150]
69+
sg_hidden: [2300, 1150]
70+
lp_hidden: [512, 256]
71+
72+
tasks:
73+
num_cs_classes: 7
74+
num_sg_classes: 230
75+
num_lp_outputs: 6
76+
77+
lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
78+
lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0]
79+
bound_lp_with_sigmoid: true
80+
81+
loss:
82+
lambda_cs: 1.0
83+
lambda_sg: 1.0
84+
lambda_lp: 1.0
85+
86+
gemd_mu: 0.0
87+
gemd_distance_matrix_path: null
88+
89+
optimizer:
90+
lr: 0.0002
91+
weight_decay: 0.01
92+
use_adamw: true
93+
gradient_clip_val: 1.0
94+
gradient_clip_algorithm: "norm"
95+
96+
trainer:
97+
default_root_dir: "/outputs/convnext_paper"
98+
max_epochs: 100
99+
accumulate_grad_batches: 1
100+
precision: "32"
101+
accelerator: "gpu"
102+
devices: 1
103+
log_every_n_steps: 200
104+
deterministic: false
105+
benchmark: true
106+
107+
logging:
108+
logger: "mlflow"
109+
csv_logger_name: "model_logs_convnext_paper"
110+
mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt"
111+
mlflow_tracking_uri: "file:/outputs/mlruns"
112+
mlflow_run_name: "ConvNeXt_Paper_Run"
113+
114+
checkpointing:
115+
monitor: "val/loss"
116+
mode: "min"
117+
save_top_k: 1
118+
every_n_epochs: 1
119+
120+
resume_from: null
121+
test_after_train: true
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
data:
2-
manifest_dir: "../../../ad_data/manifests"
3-
dataset_root: "../../../ad_data/data/dataset"
2+
manifest_dir: "data/manifests"
3+
dataset_root: "data/dataset"
44
extra_val_file: "rruff.jsonl"
55
auto_generate_manifests: true
66
train_ratio: 0.8

docker/trainer.Dockerfile

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime
2+
3+
ENV PYTHONDONTWRITEBYTECODE=1
4+
ENV PYTHONUNBUFFERED=1
5+
ENV MPLCONFIGDIR=/tmp/matplotlib
6+
7+
WORKDIR /app
8+
9+
# Install Python dependencies for the trainer
10+
COPY src/trainer/requirements.txt /tmp/requirements.txt
11+
RUN pip install --no-cache-dir -r /tmp/requirements.txt
12+
13+
# Trainer source (can be overridden by bind-mount in compose)
14+
COPY src/trainer /app
15+
16+
ENV PYTHONPATH=/app
17+
18+
CMD ["python", "/app/run_train_with_manifests.py", "/configs/trainer.docker.yaml"]

src/trainer/requirements.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
pytorch-lightning>=2.1,<3
2+
PyYAML>=6.0
3+
numpy>=1.24
4+
tqdm>=4.66
5+
matplotlib>=3.7
6+
scikit-learn>=1.3
7+
mlflow>=2.8
src/trainer/run_train_with_manifests.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import os
5+
import subprocess
6+
import sys
7+
from typing import Any, Dict
8+
9+
import yaml
10+
11+
from dataset.manifest_utils import generate_manifests
12+
13+
14+
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for the manifest-then-train wrapper.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            so calling ``_parse_args()`` behaves exactly as before.

    Returns:
        A namespace with ``config`` (path to the trainer YAML) and the
        boolean flags ``skip_manifests`` and ``only_manifests``.
    """
    parser = argparse.ArgumentParser(
        description="Generate dataset manifests (if needed) then run training."
    )
    parser.add_argument(
        "config",
        type=str,
        help="Path to trainer config YAML (e.g., /configs/trainer.docker.yaml)",
    )
    parser.add_argument(
        "--skip-manifests",
        action="store_true",
        help="Skip manifest generation step.",
    )
    parser.add_argument(
        "--only-manifests",
        action="store_true",
        help="Only generate manifests, then exit.",
    )
    # NOTE: the two flags are intentionally not declared mutually exclusive;
    # the parser accepts both together, as it always has.
    return parser.parse_args(argv)
34+
35+
36+
def _load_config(path: str) -> Dict[str, Any]:
    """Read the YAML file at *path* and return its top-level mapping.

    Args:
        path: Filesystem path of the config file.

    Returns:
        The parsed YAML document as a dictionary.

    Raises:
        FileNotFoundError: If *path* is not an existing regular file.
        ValueError: If the parsed document is not a dict (for example an
            empty file, which ``yaml.safe_load`` parses to ``None``).
    """
    if os.path.isfile(path):
        with open(path, "r", encoding="utf-8") as handle:
            loaded = yaml.safe_load(handle)
        if isinstance(loaded, dict):
            return loaded
        raise ValueError(f"Config must be a mapping (YAML dict), got: {type(loaded)}")
    raise FileNotFoundError(f"Config file not found: {path}")
44+
45+
46+
def _resolve_from_cwd(path: str) -> str:
    """Return *path* anchored at the current working directory.

    Absolute paths are returned unchanged (they are not normalized).
    Relative paths are joined onto ``os.getcwd()`` and then normalized.
    """
    if os.path.isabs(path):
        return path
    anchored = os.path.join(os.getcwd(), path)
    return os.path.normpath(anchored)
48+
49+
50+
def _generate_from_config(cfg: Dict[str, Any]) -> None:
    """Generate dataset manifests from the ``data`` section of *cfg*.

    Relative ``dataset_root`` / ``manifest_dir`` values are resolved against
    the current working directory before being handed to
    ``generate_manifests``.

    Args:
        cfg: Parsed trainer config; must contain a ``data`` mapping with the
            keys ``dataset_root``, ``manifest_dir``, ``train_ratio``,
            ``val_ratio``, ``test_ratio`` and ``seed``.

    Raises:
        KeyError: If the ``data`` section or one of its required keys is
            missing.
        ValueError: If the ``data`` section is present but is not a mapping
            (e.g. ``data: null`` in the YAML).
    """
    if "data" not in cfg:
        raise KeyError("Config missing required 'data' section")

    data_cfg = cfg["data"]
    # Fail with a clear message instead of letting the membership tests
    # below blow up (None) or silently mis-check (str / list).
    if not isinstance(data_cfg, dict):
        raise ValueError(
            f"Config 'data' section must be a mapping (YAML dict), got: {type(data_cfg)}"
        )

    required = ["dataset_root", "manifest_dir", "train_ratio", "val_ratio", "test_ratio", "seed"]
    for key in required:
        if key not in data_cfg:
            raise KeyError(f"Config data.{key} is required")

    dataset_root = _resolve_from_cwd(str(data_cfg["dataset_root"]))
    manifest_dir = _resolve_from_cwd(str(data_cfg["manifest_dir"]))

    # Coerce explicitly so values written as strings in the YAML still work.
    generate_manifests(
        dataset_root=dataset_root,
        manifest_dir=manifest_dir,
        train_ratio=float(data_cfg["train_ratio"]),
        val_ratio=float(data_cfg["val_ratio"]),
        test_ratio=float(data_cfg["test_ratio"]),
        seed=int(data_cfg["seed"]),
    )
71+
72+
73+
def main() -> None:
    """Script entry point: optionally build manifests, then launch training.

    Manifest generation runs unless ``--skip-manifests`` was given. With
    ``--only-manifests`` the function returns before training. Otherwise
    ``train.py`` from this file's directory is run with the current
    interpreter and the same config path; ``subprocess.check_call`` raises
    ``CalledProcessError`` if that process exits non-zero.
    """
    cli = _parse_args()
    config = _load_config(cli.config)

    wants_manifests = not cli.skip_manifests
    if wants_manifests:
        _generate_from_config(config)

    if not cli.only_manifests:
        script_dir = os.path.dirname(__file__)
        trainer_script = os.path.join(script_dir, "train.py")
        subprocess.check_call([sys.executable, trainer_script, cli.config])
85+
86+
87+
if __name__ == "__main__":
88+
main()

src/trainer/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule:
8383
val_ratio=data_cfg["val_ratio"],
8484
test_ratio=data_cfg["test_ratio"],
8585
seed=data_cfg["seed"],
86-
extra_val_file=data_cfg["extra_val_file"],
86+
extra_val_file=data_cfg.get("extra_val_file"),
8787
# Optional noise augmentation: apply to training split only
8888
noise_poisson_range=tuple(aug_cfg["noise_poisson_range"]) if aug_cfg["noise_poisson_range"] is not None else None,
8989
noise_gaussian_range=tuple(aug_cfg["noise_gaussian_range"]) if aug_cfg["noise_gaussian_range"] is not None else None,

0 commit comments

Comments
 (0)