Skip to content

Commit 0d0cd66

Browse files
authored
Merge pull request #1 from AdvancedPhotonSource/replication_finalize
Add Training System
2 parents 13978a6 + 255dfa7 commit 0d0cd66

10 files changed

Lines changed: 2379 additions & 1 deletion

File tree

.gitignore

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
.env
2+
__pycache__
23

34
# Development
45
/sandbox
56
/staging
6-
/data
7+
/data
8+
/original
9+
og
10+
11+
# Non-Docker Training Outputs
12+
/src/trainer/outputs
13+
/src/trainer/mlruns
14+
15+
# Temporary uv setup files
16+
.python-version
17+
pyproject.toml
18+
uv.lock

configs/trainer.yaml

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
data:
2+
manifest_dir: "../../../ad_data/manifests"
3+
dataset_root: "../../../ad_data/data/dataset"
4+
extra_val_file: "rruff.jsonl"
5+
auto_generate_manifests: true
6+
train_ratio: 0.8
7+
val_ratio: 0.1
8+
test_ratio: 0.1
9+
seed: 42
10+
11+
loader:
12+
# --- DataLoader ---
13+
batch_size: 64 # match the original (OG) run: 64 per process
14+
num_workers: 8
15+
pin_memory: true
16+
persistent_workers: true
17+
prefetch_factor: 2
18+
train_file: "train.jsonl"
19+
val_file: "val.jsonl"
20+
test_file: "test.jsonl"
21+
22+
preprocessing:
23+
validate_paths: false
24+
extract_labels: true
25+
allow_pickle: true
26+
labels_key_map:
27+
x: "dp"
28+
cs: "cs"
29+
sg: "sg"
30+
lattice_params: null
31+
lp_a: "_cell_length_a"
32+
lp_b: "_cell_length_b"
33+
lp_c: "_cell_length_c"
34+
lp_alpha: "_cell_angle_alpha"
35+
lp_beta: "_cell_angle_beta"
36+
lp_gamma: "_cell_angle_gamma"
37+
dtype: "float32"
38+
mmap_mode: null
39+
floor_at_zero: true
40+
normalize_log1p: false
41+
shift_labels: true
42+
43+
augmentation:
44+
noise_poisson_range: [1.0, 100.0]
45+
noise_gaussian_range: [0.001, 0.1]
46+
standardize_to: [0.0, 100.0]
47+
48+
model:
49+
type: "multiscale"
50+
51+
backbone:
52+
dim_in: 8192
53+
dims: [80, 80, 80]
54+
kernel_sizes: [100, 50, 25]
55+
strides: [5, 5, 5]
56+
dropout_rate: 0.3
57+
layer_scale_init_value: 0.0
58+
drop_path_rate: 0.3
59+
ramped_dropout_rate: false
60+
block_type: "convnext"
61+
pooling_type: "average"
62+
final_pool: true
63+
use_batchnorm: false
64+
activation: "leaky_relu"
65+
output_type: "flatten"
66+
67+
heads:
68+
head_dropout: 0.5
69+
cs_hidden: [2300, 1150]
70+
sg_hidden: [2300, 1150]
71+
lp_hidden: [512, 256]
72+
73+
tasks:
74+
num_cs_classes: 7
75+
num_sg_classes: 230
76+
num_lp_outputs: 6
77+
78+
lp_bounds_min: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
79+
lp_bounds_max: [300.0, 300.0, 300.0, 180.0, 180.0, 180.0]
80+
bound_lp_with_sigmoid: true
81+
82+
loss:
83+
lambda_cs: 1.0
84+
lambda_sg: 1.0
85+
lambda_lp: 1.0
86+
87+
gemd_mu: 0.0
88+
gemd_distance_matrix_path: null
89+
90+
optimizer:
91+
lr: 0.0002
92+
weight_decay: 0.01
93+
use_adamw: true
94+
gradient_clip_val: 1.0
95+
gradient_clip_algorithm: "norm"
96+
97+
trainer:
98+
default_root_dir: "outputs/convnext_paper"
99+
max_epochs: 100
100+
accumulate_grad_batches: 1
101+
precision: "32" # match the original (OG) run (AMP disabled)
102+
accelerator: "gpu"
103+
devices: 1
104+
log_every_n_steps: 200
105+
deterministic: false
106+
benchmark: true
107+
108+
logging:
109+
logger: "mlflow"
110+
csv_logger_name: "model_logs_convnext_paper"
111+
mlflow_experiment_name: "AlphaDiffract_Paper_ConvNeXt"
112+
mlflow_tracking_uri: null
113+
mlflow_run_name: "ConvNeXt_Paper_Run"
114+
115+
checkpointing:
116+
monitor: "val/loss"
117+
mode: "min"
118+
save_top_k: 1
119+
every_n_epochs: 1
120+
121+
resume_from: null
122+
test_after_train: true

src/trainer/dataset/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
2+
Dataset package for training.
3+
4+
Assumptions:
5+
- Import paths use the local 'dataset' package.
6+
- Manifests are JSON Lines and may include an optional first-line meta header:
7+
{"__meta__": {"version": 1, "base_dir": "<path to dataset root>"}}
8+
When present, non-absolute file paths in records are resolved relative to base_dir.
9+
base_dir itself may be a relative path, in which case it is resolved against the manifest file's directory. This makes manifests
10+
independent of the current working directory.
11+
- Legacy manifests without the meta header remain supported; their file paths are used as-is.
12+
13+
Exports:
14+
- NpyManifestDataset: Map-style dataset loading .npy files listed in JSONL manifests.
15+
- NpyDataModule: LightningDataModule wiring datasets and DataLoaders.
16+
- generate_manifests: Utility to create train/val/test manifests split by material ID.
17+
- ManifestStats: Summary dataclass for manifest generation.
18+
"""
19+
20+
from .dataset import NpyManifestDataset, default_manifest_paths
21+
from .datamodule import NpyDataModule
22+
from .manifest_utils import (
23+
generate_manifests,
24+
ManifestStats,
25+
scan_dataset_root,
26+
split_materials,
27+
write_jsonl_manifest,
28+
)
29+
30+
__all__ = [
31+
"NpyManifestDataset",
32+
"default_manifest_paths",
33+
"NpyDataModule",
34+
"generate_manifests",
35+
"ManifestStats",
36+
"scan_dataset_root",
37+
"split_materials",
38+
"write_jsonl_manifest",
39+
]

0 commit comments

Comments
 (0)