
Commit 255dfa7

Refactor: Continued cleanup
1 parent 7d15173 commit 255dfa7

5 files changed: 21 additions & 96 deletions


configs/trainer.yaml

Lines changed: 2 additions & 4 deletions
@@ -1,6 +1,3 @@
-# AlphaDiffract trainer configuration — ConvNeXt (paper-matching lightweight variant)
-# Use with: PYTHONPATH=src python -m trainer.train_paper configs/trainer_convnext_paper.yaml
-
 data:
   manifest_dir: "../../../ad_data/manifests"
   dataset_root: "../../../ad_data/data/dataset"
@@ -40,7 +37,8 @@ data:
   dtype: "float32"
   mmap_mode: null
   floor_at_zero: true
-  normalize_log1p: False # paper used log1p preprocessing
+  normalize_log1p: False
+  shift_labels: true
 
   augmentation:
     noise_poisson_range: [1.0, 100.0]
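
Note: the new `shift_labels: true` key replaces the collate-level label shift removed from datamodule.py below; it assumes cs/sg labels arrive 1-based from the source data. A minimal sketch of the intended effect (the names here are illustrative, not from this commit):

    import torch

    shift_labels = True
    y_sg = torch.tensor(225)   # 1-based space group label (1..230)
    if shift_labels:
        y_sg = y_sg - 1        # -> tensor(224), 0-based as nn.CrossEntropyLoss expects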

src/trainer/dataset/datamodule.py

Lines changed: 1 addition & 53 deletions
@@ -4,28 +4,6 @@
 This module wires NpyManifestDataset to PyTorch Lightning and can optionally
 auto-generate manifests from a dataset root if they are missing.
 
-Typical usage:
-    from pytorch_lightning import Trainer
-    from dataset.datamodule import NpyDataModule
-
-    dm = NpyDataModule(
-        manifest_dir="data/manifests",
-        batch_size=64,
-        num_workers=8,
-        pin_memory=True,
-        persistent_workers=True,
-        # Optional: auto-generate manifests if missing
-        dataset_root="data/dataset",
-        auto_generate_manifests=True,
-        train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42,
-        # Dataset-specific kwargs
-        dataset_kwargs={"dtype": torch.float32, "mmap_mode": "r", "return_meta": True},
-    )
-
-    trainer = Trainer(max_epochs=10, accelerator="auto", devices="auto")
-    trainer.fit(model, dm)
-    trainer.test(model, dm)
-
 Notes:
 - Splitting is performed by material ID when generating manifests (never per-file).
 - Manifests avoid scanning the entire dataset during training.
@@ -80,34 +58,6 @@ def _transform(x: torch.Tensor) -> torch.Tensor:
     return _transform
 
 
-
-def _shift_one_based_collate(batch):
-    """
-    Collate function that uses PyTorch's default_collate, then unconditionally shifts
-    cs and sg labels by -1 (assumes 1-based input). Performed under torch.no_grad to avoid
-    constructing any graphs.
-    """
-    collated = default_collate(batch)
-    with torch.no_grad():
-        def _shift(t):
-            return t - 1 if torch.is_tensor(t) else t
-
-        if isinstance(collated, dict):
-            if "cs" in collated:
-                collated["cs"] = _shift(collated["cs"])
-            if "sg" in collated:
-                collated["sg"] = _shift(collated["sg"])
-        elif isinstance(collated, (list, tuple)):
-            # Tuple-based batches: (x, cs, sg, [lp])
-            lst = list(collated)
-            if len(lst) >= 2:
-                lst[1] = _shift(lst[1])
-            if len(lst) >= 3:
-                lst[2] = _shift(lst[2])
-            collated = type(collated)(lst)
-    return collated
-
-
 class NpyDataModule(pl.LightningDataModule):
     """
     LightningDataModule that reads train/val/test JSONL manifests and constructs DataLoaders.
@@ -165,9 +115,7 @@ def __init__(
         self.persistent_workers = persistent_workers
         self.prefetch_factor = prefetch_factor
         self.collate_fn = collate_fn
-        if self.collate_fn is None:
-            self.collate_fn = _shift_one_based_collate
-
+
         self.dataset_cls = dataset_cls
         self.dataset_kwargs = dataset_kwargs or {}
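
Note: with the `_shift_one_based_collate` fallback gone, `collate_fn=None` now means PyTorch's `default_collate`, and the -1 shift is opt-in through the dataset. A hedged construction sketch, reusing argument names from the removed docstring (the exact signature may have drifted):

    dm = NpyDataModule(
        manifest_dir="data/manifests",
        batch_size=64,
        collate_fn=None,  # default_collate; no implicit -1 shift here anymore
        dataset_kwargs={"shift_labels": True},  # shift happens per sample instead
    )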

src/trainer/dataset/dataset.py

Lines changed: 6 additions & 0 deletions
@@ -157,6 +157,7 @@ def __init__(
         allow_pickle: bool,
         floor_at_zero: bool,
         normalize_log1p: bool,
+        shift_labels: bool,
         transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
         labels_key_map: Optional[Dict[str, List[str]]] = None,
     ) -> None:
@@ -170,6 +171,7 @@
         self.allow_pickle = allow_pickle
         self.floor_at_zero = floor_at_zero
         self.normalize_log1p = normalize_log1p
+        self.shift_labels = shift_labels
         # Default key mapping for extracting fields from embedded containers
         # Simplified: single string keys, no search lists
         self.labels_key_map = labels_key_map or {
@@ -311,8 +313,12 @@ def _get_exact(container, key: str):
         # Attach labels if present/extracted
         if self.extract_labels:
             if y_cs_t is not None:
+                if self.shift_labels:
+                    y_cs_t = y_cs_t - 1
                 sample["cs"] = y_cs_t
             if y_sg_t is not None:
+                if self.shift_labels:
+                    y_sg_t = y_sg_t - 1
                 sample["sg"] = y_sg_t
             if y_lp_t is not None:
                 sample["lattice_params"] = y_lp_t
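
Note: the shift now happens per sample at label-extraction time rather than per batch at collation. A standalone sketch of the added logic (the helper name is hypothetical; the body mirrors the diff above):

    import torch

    def attach_labels(sample, y_cs_t, y_sg_t, shift_labels=True):
        # Shift 1-based cs/sg labels to 0-based before attaching them.
        if y_cs_t is not None:
            if shift_labels:
                y_cs_t = y_cs_t - 1
            sample["cs"] = y_cs_t
        if y_sg_t is not None:
            if shift_labels:
                y_sg_t = y_sg_t - 1
            sample["sg"] = y_sg_t
        return sample

    print(attach_labels({}, torch.tensor(7), torch.tensor(225)))
    # {'cs': tensor(6), 'sg': tensor(224)}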

src/trainer/model/model.py

Lines changed: 3 additions & 33 deletions
@@ -3,7 +3,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np
 
 import pytorch_lightning as pl
 
@@ -183,7 +182,6 @@ def make_mlp(
 
 
 # -----------------------------
-# OG-style Multiscale CNN Backbone (1D) with ConvNeXt-like blocks
 # Mirrors alphadiffract.model.MultiscaleCNNBackbone behavior:
 # - sequential conv stages with specified kernel_sizes and strides
 # - optional average/max pooling between stages and at the end
@@ -212,13 +210,11 @@
         self.dim_in = dim_in
         self.output_type = output_type
 
-        # Build per-stage dropout schedule
         if ramped_dropout_rate:
             dropout_per_stage = torch.linspace(0.0, dropout_rate, steps=len(channels)).tolist()
         else:
             dropout_per_stage = [dropout_rate] * len(channels)
 
-        # Select pooling module
         if pooling_type == "average":
             pool_cls = nn.AvgPool1d
             pool_kwargs = {"kernel_size": 3, "stride": 2}
@@ -231,7 +227,6 @@
         layers: List[nn.Module] = []
         in_ch = 1
         for i, (out_ch, k, s) in enumerate(zip(channels, kernel_sizes, strides)):
-            # Build stage block matching OG ConvNextBlock1DAdaptorForMultiscaleCNN
             stage_block = ConvNextBlock1DAdaptor(
                 in_channels=in_ch,
                 out_channels=out_ch,
@@ -284,7 +279,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class AlphaDiffractMultiscaleLightning(pl.LightningModule):
     def __init__(
         self,
-        # Backbone params (OG-style)
         dim_in: int,
         channels: Tuple[int, ...],
         kernel_sizes: Tuple[int, ...],
@@ -300,36 +294,31 @@ def __init__(
         layer_scale_init_value: float,
         drop_path_rate: float,
 
-        # Heads
         head_dropout: float,
         cs_hidden: Optional[Tuple[int, ...]],
         sg_hidden: Optional[Tuple[int, ...]],
         lp_hidden: Optional[Tuple[int, ...]],
 
-        # Task sizes
         num_cs_classes: int,
         num_sg_classes: int,
         num_lp_outputs: int,
 
-        # LP bounding
         lp_bounds_min: Tuple[float, float, float, float, float, float],
         lp_bounds_max: Tuple[float, float, float, float, float, float],
         bound_lp_with_sigmoid: bool,
 
-        # Loss weights
         lambda_cs: float,
        lambda_sg: float,
         lambda_lp: float,
 
-        # Optimizer
         lr: float,
         weight_decay: float,
         use_adamw: bool,
+        steps_per_epoch: int,
     ):
         super().__init__()
         self.save_hyperparameters()
 
-        # Backbone
         self.backbone = MultiscaleCNNBackbone1D(
             dim_in=dim_in,
             channels=channels,
@@ -348,7 +337,6 @@ def __init__(
         )
         feat_dim = self.backbone.dim_output
 
-        # Heads
         self.cs_head = make_mlp(
             input_dim=feat_dim,
             hidden_dims=cs_hidden,
@@ -371,22 +359,20 @@ def __init__(
             output_activation=None,
         )
 
-        # Losses and bounds
         self.ce = nn.CrossEntropyLoss()
         self.mse = nn.MSELoss()
         self.register_buffer("lp_min", torch.tensor(lp_bounds_min, dtype=torch.float32))
         self.register_buffer("lp_max", torch.tensor(lp_bounds_max, dtype=torch.float32))
         self.bound_lp_with_sigmoid = bound_lp_with_sigmoid
 
-        # weights and optim config
         self.lambda_cs = lambda_cs
         self.lambda_sg = lambda_sg
         self.lambda_lp = lambda_lp
         self.lr = lr
         self.weight_decay = weight_decay
         self.use_adamw = use_adamw
+        self.steps_per_epoch = steps_per_epoch
 
-        # Task sizes
         self.num_cs_classes = num_cs_classes
         self.num_sg_classes = num_sg_classes
         self.num_lp_outputs = num_lp_outputs
@@ -523,23 +509,7 @@ def configure_optimizers(self):
         else:
             optimizer = torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay)
 
-        # Compute steps per epoch to match OG scheduler semantics:
-        # step_size_up = 6 * iterations_per_epoch
-        steps_per_epoch = None
-        try:
-            if hasattr(self, "trainer") and self.trainer is not None:
-                total_steps = getattr(self.trainer, "estimated_stepping_batches", None)
-                max_epochs = getattr(self.trainer, "max_epochs", None)
-                if total_steps is not None and max_epochs is not None and max_epochs > 0:
-                    steps_per_epoch = max(1, total_steps // max_epochs)
-        except Exception:
-            pass
-
-        if steps_per_epoch is None:
-            # Fallback if trainer hooks are unavailable; use a conservative default
-            steps_per_epoch = 100
-
-        step_size_up = int(6 * steps_per_epoch)
+        step_size_up = int(6 * self.steps_per_epoch)
 
         scheduler = torch.optim.lr_scheduler.CyclicLR(
             optimizer,
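
Note: `configure_optimizers` no longer probes the trainer for `estimated_stepping_batches`; the caller injects `steps_per_epoch` through the constructor. A sketch of the resulting scheduler wiring (the `base_lr`/`max_lr` values are placeholders, not from this commit; `cycle_momentum=False` is assumed because Adam exposes no momentum parameter for CyclicLR to cycle):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=1e-3)
    steps_per_epoch = 500                      # now injected explicitly
    step_size_up = int(6 * steps_per_epoch)    # OG semantics: ramp up over 6 epochs
    scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer,
        base_lr=1e-4,
        max_lr=1e-3,
        step_size_up=step_size_up,
        cycle_momentum=False,
    )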

src/trainer/train.py

Lines changed: 9 additions & 6 deletions
@@ -6,21 +6,18 @@
 import yaml
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback
-import signal
 from pytorch_lightning.loggers import CSVLogger
 try:
     from pytorch_lightning.loggers import MLFlowLogger
 except Exception:
     MLFlowLogger = None
 
-# Project imports (expect PYTHONPATH=src or run via `python -m trainer.train_paper`)
 from dataset import NpyDataModule
 from model.model import AlphaDiffractMultiscaleLightning
 
 
 def parse_args() -> argparse.Namespace:
     p = argparse.ArgumentParser(description="Train AlphaDiffract paper model (config-required)")
-    # Require a config file path with no script-side defaults
     p.add_argument("config", type=str, help="Path to trainer config YAML (e.g., configs/trainer.yaml)")
     return p.parse_args()
 
@@ -63,6 +60,7 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule:
         "allow_pickle": prep_cfg["allow_pickle"],
         "floor_at_zero": prep_cfg["floor_at_zero"],
         "normalize_log1p": prep_cfg["normalize_log1p"],
+        "shift_labels": prep_cfg["shift_labels"],
     }
     labels_key_map = prep_cfg["labels_key_map"]
     if labels_key_map is not None:
@@ -95,7 +93,7 @@ def build_datamodule_from_cfg(cfg: Dict[str, Any]) -> NpyDataModule:
     return dm
 
 
-def build_model_from_cfg(cfg: Dict[str, Any]):
+def build_model_from_cfg(cfg: Dict[str, Any], steps_per_epoch: int):
     model_cfg = cfg["model"]
     backbone_cfg = model_cfg["backbone"]
     heads_cfg = model_cfg["heads"]
@@ -149,10 +147,11 @@
             lambda_cs=loss_cfg["lambda_cs"],
             lambda_sg=loss_cfg["lambda_sg"],
             lambda_lp=loss_cfg["lambda_lp"],
-
+
             lr=optim_cfg["lr"],
             weight_decay=optim_cfg["weight_decay"],
             use_adamw=optim_cfg["use_adamw"],
+            steps_per_epoch=steps_per_epoch,
         )
     else:
         raise ValueError(f"Unsupported model_type '{model_type}'. Expected 'multiscale'.")
@@ -292,7 +291,11 @@ def main():
     torch.set_float32_matmul_precision('high')
 
     dm = build_datamodule_from_cfg(cfg)
-    model = build_model_from_cfg(cfg)
+    # Explicitly setup datamodule to calculate steps_per_epoch
+    dm.setup("fit")
+    steps_per_epoch = len(dm.train_dataloader())
+
+    model = build_model_from_cfg(cfg, steps_per_epoch=steps_per_epoch)
     trainer = build_trainer_from_cfg(cfg, raw_config_path=args.config)
 
     # Train
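
Note: `len(dm.train_dataloader())` gives the number of batches per epoch for a map-style dataset, which is exactly what the scheduler needs. A quick standalone check of that assumption:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.zeros(1000, 8))
    print(len(DataLoader(ds, batch_size=64)))                  # 16 = ceil(1000/64)
    print(len(DataLoader(ds, batch_size=64, drop_last=True)))  # 15

One tradeoff worth noting: calling `dm.setup("fit")` materializes the train split eagerly, before `Trainer.fit`, whereas the old in-model fallback deferred the computation and fell back to a guessed default of 100 when trainer hooks were unavailable.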
