
Commit 7d15173
refactor: Move to nested config structure for organization
1 parent: 536b58d

5 files changed

Lines changed: 296 additions & 270 deletions

configs/trainer.yaml

Lines changed: 108 additions & 111 deletions
@@ -1,127 +1,124 @@
 # AlphaDiffract trainer configuration — ConvNeXt (paper-matching lightweight variant)
 # Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_convnext_paper.yaml

-# --- Data / Manifests ---
-manifest_dir: "../../../ad_data/manifests"
-dataset_root: "../../../ad_data/data/dataset"
-extra_val_file: "rruff.jsonl"
-auto_generate_manifests: true
-train_ratio: 0.8
-val_ratio: 0.1
-test_ratio: 0.1
-seed: 42
+data:
+  manifest_dir: "../../../ad_data/manifests"
+  dataset_root: "../../../ad_data/data/dataset"
+  extra_val_file: "rruff.jsonl"
+  auto_generate_manifests: true
+  train_ratio: 0.8
+  val_ratio: 0.1
+  test_ratio: 0.1
+  seed: 42

-# --- DataLoader ---
-batch_size: 64 # match OG run (64 per process)
-num_workers: 8
-pin_memory: true
-persistent_workers: true
+loader:
+  # --- DataLoader ---
+  batch_size: 64 # match OG run (64 per process)
+  num_workers: 8
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  train_file: "train.jsonl"
+  val_file: "val.jsonl"
+  test_file: "test.jsonl"

-# --- Dataset label extraction (embedded in .npy/.npz) ---
-validate_paths: false
-extract_labels: true
-allow_pickle: true
-labels_key_map:
-  x: "dp"
-  cs: "cs"
-  sg: "sg"
-  lattice_params: null
-  lp_a: "_cell_length_a"
-  lp_b: "_cell_length_b"
-  lp_c: "_cell_length_c"
-  lp_alpha: "_cell_angle_alpha"
-  lp_beta: "_cell_angle_beta"
-  lp_gamma: "_cell_angle_gamma"
-dtype: "float32"
-mmap_mode: null
-floor_at_zero: true
-normalize_log1p: False # paper used log1p preprocessing
-model_type: "multiscale"
+preprocessing:
+  validate_paths: false
+  extract_labels: true
+  allow_pickle: true
+  labels_key_map:
+    x: "dp"
+    cs: "cs"
+    sg: "sg"
+    lattice_params: null
+    lp_a: "_cell_length_a"
+    lp_b: "_cell_length_b"
+    lp_c: "_cell_length_c"
+    lp_alpha: "_cell_angle_alpha"
+    lp_beta: "_cell_angle_beta"
+    lp_gamma: "_cell_angle_gamma"
+  dtype: "float32"
+  mmap_mode: null
+  floor_at_zero: true
+  normalize_log1p: False # paper used log1p preprocessing

-# --- ConvNeXt (OG-equivalent configuration) ---
-# 3 stages; one block per stage; large kernels; stride-5 downsampling
-# Matches OG multiscale_cnn_cls_regr_convnextBlock_LeakyReLU.json exactly
-depths: [1, 1, 1]
-dims: [80, 80, 80]
-kernel_sizes: [100, 50, 25]
-strides: [5, 5, 5]
-dropout_rate: 0.3
-# OG uses layer_scale_init_value=0 (disabled)
-layer_scale_init_value: 0.0
-# OG uses constant drop_path_rate=0.3 (not ramped)
-drop_path_rate: 0.3
-ramped_dropout_rate: false
-block_type: "convnext"
-pooling_type: "average"
-final_pool: true
-use_batchnorm: false
-output_type: "flatten"
+augmentation:
+  noise_poisson_range: [1.0, 100.0]
+  noise_gaussian_range: [0.001, 0.1]
+  standardize_to: [0.0, 100.0]

-# Heads
-head_dropout: 0.5
-cs_hidden: [2300, 1150]
-sg_hidden: [2300, 1150]
-lp_hidden: [512, 256]
+model:
+  type: "multiscale"
+
+  backbone:
+    dim_in: 8192
+    dims: [80, 80, 80]
+    kernel_sizes: [100, 50, 25]
+    strides: [5, 5, 5]
+    dropout_rate: 0.3
+    layer_scale_init_value: 0.0
+    drop_path_rate: 0.3
+    ramped_dropout_rate: false
+    block_type: "convnext"
+    pooling_type: "average"
+    final_pool: true
+    use_batchnorm: false
+    activation: "leaky_relu"
+    output_type: "flatten"

-# Task sizes
-num_cs_classes: 7
-num_sg_classes: 230
-num_lp_outputs: 6
+  heads:
+    head_dropout: 0.5
+    cs_hidden: [2300, 1150]
+    sg_hidden: [2300, 1150]
+    lp_hidden: [512, 256]

-# LP output bounds
-lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
-lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0]
-bound_lp_with_sigmoid: true
+tasks:
+  num_cs_classes: 7
+  num_sg_classes: 230
+  num_lp_outputs: 6

-# Loss weights
-lambda_cs: 1.0
-lambda_sg: 1.0
-lambda_lp: 1.0
+  lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+  lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0]
+  bound_lp_with_sigmoid: true

-# Optional GEMD term on SG
-gemd_mu: 0.0
-gemd_distance_matrix_path: null
+loss:
+  lambda_cs: 1.0
+  lambda_sg: 1.0
+  lambda_lp: 1.0

-# Optimizer (paper): AdamW, lr=2e-4, wd=0.01
-lr: 0.0002
-weight_decay: 0.01
-use_adamw: true
-gradient_clip_val: 1.0
-gradient_clip_algorithm: "norm"
+  gemd_mu: 0.0
+  gemd_distance_matrix_path: null

-# --- Noise augmentation (training split only; matches paper) ---
-# If provided, noise is applied dynamically per-sample in the DataModule using the same
-# sequencing as the paper: Poisson -> normalize -> add Gaussian -> renormalize -> rescale.
-# Set ranges to None to disable.
-noise_poisson_range: [1.0, 100.0]
-noise_gaussian_range: [0.001, 0.1]
+optimizer:
+  lr: 0.0002
+  weight_decay: 0.01
+  use_adamw: true
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: "norm"

-# Standardize after noise to match OG CLI (--standardize-to 0 100)
-standardize_to: [0.0, 100.0]
-# --- Logging ---
-logger: "mlflow"
-csv_logger_name: "model_logs_convnext_paper"
-mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt"
-mlflow_tracking_uri: null
-mlflow_run_name: "ConvNeXt_Paper_Run"
+trainer:
+  default_root_dir: "outputs/convnext_paper"
+  max_epochs: 100
+  accumulate_grad_batches: 1
+  precision: "32" # match OG (AMP disabled)
+  accelerator: "gpu"
+  devices: 1
+  log_every_n_steps: 200
+  deterministic: false
+  benchmark: true

-# --- Trainer settings ---
-default_root_dir: "outputs/convnext_paper"
-max_epochs: 100
-accumulate_grad_batches: 1
-precision: "32" # match OG (AMP disabled)
-accelerator: "gpu"
-devices: 1
-log_every_n_steps: 200
-deterministic: false
-benchmark: true
+logging:
+  logger: "mlflow"
+  csv_logger_name: "model_logs_convnext_paper"
+  mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt"
+  mlflow_tracking_uri: null
+  mlflow_run_name: "ConvNeXt_Paper_Run"

-# --- Checkpointing ---
-monitor: "val/loss"
-mode: "min"
-save_top_k: 1
-every_n_epochs: 1
-
-# --- Evaluation ---
-resume_from:
-test_after_train: true
+checkpointing:
+  monitor: "val/loss"
+  mode: "min"
+  save_top_k: 1
+  every_n_epochs: 1
+
+resume_from: null
+test_after_train: true
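
Since the former top-level keys now live under named sections, downstream code indexes into a nested mapping. A minimal sketch of loading and reading the new layout (assumes PyYAML; the nesting depth shown is inferred from the diff above, and the indentation of subsections is a reconstruction):

    import yaml

    with open("configs/trainer.yaml") as f:
        cfg = yaml.safe_load(f)

    # Former flat keys are now addressed through their section:
    batch_size = cfg["loader"]["batch_size"]   # was: cfg["batch_size"]
    dims = cfg["model"]["backbone"]["dims"]    # was: cfg["dims"]
    lr = cfg["optimizer"]["lr"]                # was: cfg["lr"]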

src/trainer/dataset/datamodule.py

Lines changed: 9 additions & 10 deletions
@@ -130,12 +130,15 @@ class NpyDataModule(pl.LightningDataModule):

     def __init__(
         self,
-        manifest_dir: str = "data/manifests",
-        batch_size: int = 32,
-        num_workers: int = 4,
-        pin_memory: bool = True,
-        persistent_workers: bool = True,
-        prefetch_factor: Optional[int] = None,
+        manifest_dir: str,
+        batch_size: int,
+        num_workers: int,
+        pin_memory: bool,
+        persistent_workers: bool,
+        prefetch_factor: Optional[int],
+        train_file: str,
+        val_file: str,
+        test_file: str,
         collate_fn: Optional[Callable] = None,
         dataset_cls: type = NpyManifestDataset,
         dataset_kwargs: Optional[Dict[str, Any]] = None,
@@ -146,10 +149,6 @@ def __init__(
         val_ratio: float = 0.1,
         test_ratio: float = 0.1,
         seed: int = 42,
-        # Custom manifest filenames (within manifest_dir)
-        train_file: str = "train.jsonl",
-        val_file: str = "val.jsonl",
-        test_file: str = "test.jsonl",
         # Optional: add a second validation manifest file (e.g., "rruff.jsonl")
         extra_val_file: Optional[str] = None,
         # Optional noise augmentation for training split only
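
With the defaults removed, every loader argument must now come from the config. A hypothetical wiring from the nested sections above (NpyDataModule and the key names come from this commit; the call site itself is illustrative, not the repo's actual entry point):

    # `cfg` as loaded in the earlier sketch.
    dm = NpyDataModule(
        manifest_dir=cfg["data"]["manifest_dir"],
        batch_size=cfg["loader"]["batch_size"],
        num_workers=cfg["loader"]["num_workers"],
        pin_memory=cfg["loader"]["pin_memory"],
        persistent_workers=cfg["loader"]["persistent_workers"],
        prefetch_factor=cfg["loader"]["prefetch_factor"],
        train_file=cfg["loader"]["train_file"],
        val_file=cfg["loader"]["val_file"],
        test_file=cfg["loader"]["test_file"],
        extra_val_file=cfg["data"]["extra_val_file"],
    )
    # Remaining keyword arguments (collate_fn, dataset_cls, ratios, seed, ...)
    # keep the defaults shown in the signature above.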

src/trainer/dataset/dataset.py

Lines changed: 8 additions & 8 deletions
@@ -149,16 +149,16 @@ class NpyManifestDataset(Dataset):
     def __init__(
         self,
         manifest_path: str,
+        dtype: torch.dtype,
+        mmap_mode: Optional[str],
+        return_meta: bool,
+        validate_paths: bool,
+        extract_labels: bool,
+        allow_pickle: bool,
+        floor_at_zero: bool,
+        normalize_log1p: bool,
         transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-        dtype: torch.dtype = torch.float32,
-        mmap_mode: Optional[str] = "r",
-        return_meta: bool = True,
-        validate_paths: bool = False,
-        extract_labels: bool = False,
         labels_key_map: Optional[Dict[str, List[str]]] = None,
-        allow_pickle: bool = True,
-        floor_at_zero: bool = True,
-        normalize_log1p: bool = False,
     ) -> None:
         super().__init__()
         self.manifest_path = manifest_path
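
The dataset signature gets the same treatment: formerly-defaulted arguments become required and map onto the `preprocessing` section. A sketch under that assumption (the string-to-torch.dtype conversion is a guess at how "float32" is resolved; the repo may handle it elsewhere):

    import torch

    pp = cfg["preprocessing"]
    ds = NpyManifestDataset(
        manifest_path="../../../ad_data/manifests/train.jsonl",  # illustrative path
        dtype=getattr(torch, pp["dtype"]),  # "float32" -> torch.float32 (assumed)
        mmap_mode=pp["mmap_mode"],
        return_meta=True,                   # not in the config diff; assumed
        validate_paths=pp["validate_paths"],
        extract_labels=pp["extract_labels"],
        allow_pickle=pp["allow_pickle"],
        floor_at_zero=pp["floor_at_zero"],
        normalize_log1p=pp["normalize_log1p"],
        labels_key_map=pp["labels_key_map"],
    )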
