 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np

 import pytorch_lightning as pl

@@ -183,7 +182,6 @@ def make_mlp(


 # -----------------------------
-# OG-style Multiscale CNN Backbone (1D) with ConvNeXt-like blocks
 # Mirrors alphadiffract.model.MultiscaleCNNBackbone behavior:
 # - sequential conv stages with specified kernel_sizes and strides
 # - optional average/max pooling between stages and at the end
@@ -212,13 +210,11 @@ def __init__(
         self.dim_in = dim_in
         self.output_type = output_type

-        # Build per-stage dropout schedule
         if ramped_dropout_rate:
             dropout_per_stage = torch.linspace(0.0, dropout_rate, steps=len(channels)).tolist()
         else:
             dropout_per_stage = [dropout_rate] * len(channels)

-        # Select pooling module
         if pooling_type == "average":
             pool_cls = nn.AvgPool1d
             pool_kwargs = {"kernel_size": 3, "stride": 2}
@@ -231,7 +227,6 @@ def __init__(
         layers: List[nn.Module] = []
         in_ch = 1
         for i, (out_ch, k, s) in enumerate(zip(channels, kernel_sizes, strides)):
-            # Build stage block matching OG ConvNextBlock1DAdaptorForMultiscaleCNN
             stage_block = ConvNextBlock1DAdaptor(
                 in_channels=in_ch,
                 out_channels=out_ch,
@@ -284,7 +279,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class AlphaDiffractMultiscaleLightning(pl.LightningModule):
     def __init__(
         self,
-        # Backbone params (OG-style)
         dim_in: int,
         channels: Tuple[int, ...],
         kernel_sizes: Tuple[int, ...],
@@ -300,36 +294,31 @@ def __init__(
         layer_scale_init_value: float,
         drop_path_rate: float,

-        # Heads
         head_dropout: float,
         cs_hidden: Optional[Tuple[int, ...]],
         sg_hidden: Optional[Tuple[int, ...]],
         lp_hidden: Optional[Tuple[int, ...]],

-        # Task sizes
         num_cs_classes: int,
         num_sg_classes: int,
         num_lp_outputs: int,

-        # LP bounding
         lp_bounds_min: Tuple[float, float, float, float, float, float],
         lp_bounds_max: Tuple[float, float, float, float, float, float],
         bound_lp_with_sigmoid: bool,

-        # Loss weights
         lambda_cs: float,
         lambda_sg: float,
         lambda_lp: float,

-        # Optimizer
         lr: float,
         weight_decay: float,
         use_adamw: bool,
+        steps_per_epoch: int,
     ):
         super().__init__()
         self.save_hyperparameters()

-        # Backbone
         self.backbone = MultiscaleCNNBackbone1D(
             dim_in=dim_in,
             channels=channels,
@@ -348,7 +337,6 @@ def __init__(
         )
         feat_dim = self.backbone.dim_output

-        # Heads
         self.cs_head = make_mlp(
             input_dim=feat_dim,
             hidden_dims=cs_hidden,
@@ -371,22 +359,20 @@ def __init__(
             output_activation=None,
         )

-        # Losses and bounds
         self.ce = nn.CrossEntropyLoss()
         self.mse = nn.MSELoss()
         self.register_buffer("lp_min", torch.tensor(lp_bounds_min, dtype=torch.float32))
         self.register_buffer("lp_max", torch.tensor(lp_bounds_max, dtype=torch.float32))
         self.bound_lp_with_sigmoid = bound_lp_with_sigmoid

-        # weights and optim config
         self.lambda_cs = lambda_cs
         self.lambda_sg = lambda_sg
         self.lambda_lp = lambda_lp
         self.lr = lr
         self.weight_decay = weight_decay
         self.use_adamw = use_adamw
+        self.steps_per_epoch = steps_per_epoch

-        # Task sizes
         self.num_cs_classes = num_cs_classes
         self.num_sg_classes = num_sg_classes
         self.num_lp_outputs = num_lp_outputs
@@ -523,23 +509,7 @@ def configure_optimizers(self):
         else:
             optimizer = torch.optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay)

-        # Compute steps per epoch to match OG scheduler semantics:
-        #   step_size_up = 6 * iterations_per_epoch
-        steps_per_epoch = None
-        try:
-            if hasattr(self, "trainer") and self.trainer is not None:
-                total_steps = getattr(self.trainer, "estimated_stepping_batches", None)
-                max_epochs = getattr(self.trainer, "max_epochs", None)
-                if total_steps is not None and max_epochs is not None and max_epochs > 0:
-                    steps_per_epoch = max(1, total_steps // max_epochs)
-        except Exception:
-            pass
-
-        if steps_per_epoch is None:
-            # Fallback if trainer hooks are unavailable; use a conservative default
-            steps_per_epoch = 100
-
-        step_size_up = int(6 * steps_per_epoch)
+        step_size_up = int(6 * self.steps_per_epoch)

         scheduler = torch.optim.lr_scheduler.CyclicLR(
             optimizer,
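
Note on the scheduler change: configure_optimizers no longer tries to infer the per-epoch iteration count from the trainer (the removed try/except block above); the caller now supplies steps_per_epoch directly. A minimal sketch of how a caller might derive it, assuming one optimizer step per batch; the dataset shapes and batch size below are made-up placeholders, not values from this commit:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Hypothetical training data; only the loader length matters here.
    train_dataset = TensorDataset(
        torch.randn(1024, 1, 512),             # placeholder 1D inputs
        torch.zeros(1024, dtype=torch.long),   # placeholder labels
    )
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    steps_per_epoch = len(train_loader)  # 1024 / 64 = 16 optimizer steps per epoch

The result is then passed as steps_per_epoch=len(train_loader) when constructing AlphaDiffractMultiscaleLightning. Assuming CyclicLR is stepped once per optimizer step, step_size_up = 6 * steps_per_epoch preserves the original semantics: the increasing half of each learning-rate cycle lasts six epochs.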