@@ -206,6 +206,8 @@ def e_step(
206206 pred_valid : torch .Tensor ,
207207 obs_rings : Optional [torch .Tensor ] = None ,
208208 pred_rings : Optional [torch .Tensor ] = None ,
209+ obs_scan_nrs : Optional [torch .Tensor ] = None ,
210+ grain_scan_mask : Optional [torch .Tensor ] = None ,
209211 ) -> Tuple [torch .Tensor , torch .Tensor ]:
210212 """E-step: compute soft ownership probabilities.
211213
@@ -214,6 +216,10 @@ def e_step(
214216 2. |omega_obs - omega_pred| < tol_omega
215217 3. |eta_obs - eta_pred| < tol_eta
216218
219+ In pf-HEDM mode, additionally filters by scan position: grain g
220+ can only own spot s if grain g is present at the scan position
221+ where spot s was observed (grain_scan_mask[g, scanNr_of_s] == True).
222+
217223 Within the box, the closest match (smallest omega+eta distance)
218224 is used for the Gaussian ownership kernel.
219225
@@ -224,6 +230,9 @@ def e_step(
224230 pred_valid : (N, K) validity mask
225231 obs_rings : (S,) int ring indices for observed spots
226232 pred_rings : (N, K) int ring indices for predicted spots
233+ obs_scan_nrs : (S,) int scan number for each observed spot (pf-HEDM)
234+ grain_scan_mask : (N, n_scans) bool, grain_scan_mask[g, s] = True if
235+ grain g is present at scan s. Derived from the C code's sinograms.
227236
228237 Returns
229238 -------
@@ -239,13 +248,26 @@ def e_step(
239248 sigma2 = self .sigma ** 2
240249 TWO_PI = 2.0 * math .pi
241250 use_rings = obs_rings is not None and pred_rings is not None
251+ use_scan_filter = (obs_scan_nrs is not None and
252+ grain_scan_mask is not None )
242253
243254 ownership = torch .zeros (S , N , dtype = obs_spots .dtype , device = obs_spots .device )
244255 hkl_assignments = torch .zeros (S , N , dtype = torch .long , device = obs_spots .device )
245256
246257 obs_eta = obs_spots [:, 1 ] # (S,)
247258 obs_omega = obs_spots [:, 2 ] # (S,)
248259
260+ # Pre-compute per-spot scan eligibility for each grain
261+ # scan_ok[v] is (S,) bool: True if grain v is present at that spot's scan
262+ scan_ok_per_grain = []
263+ if use_scan_filter :
264+ for v in range (N ):
265+ # grain_scan_mask[v] is (n_scans,) bool
266+ # obs_scan_nrs is (S,) int
267+ # Index into grain_scan_mask for each spot's scan number
268+ scan_ok = grain_scan_mask [v ][obs_scan_nrs ] # (S,) bool
269+ scan_ok_per_grain .append (scan_ok )
270+
249271 for v in range (N ):
250272 valid_mask = pred_valid [v ] > 0.5
251273 if not valid_mask .any ():
@@ -278,6 +300,10 @@ def e_step(
278300
279301 in_box = ring_match & (d_omega .abs () < self .tol_omega ) & (d_eta .abs () < self .tol_eta )
280302
303+ # Scan position filter: grain v must be present at this spot's scan
304+ if use_scan_filter :
305+ in_box = in_box & scan_ok_per_grain [v ]
306+
281307 # Distance for Gaussian kernel (only eta + omega, 2theta is handled by ring)
282308 dist = torch .sqrt (d_omega ** 2 + d_eta ** 2 )
283309
@@ -460,6 +486,8 @@ def fit_from_orient(
460486 orient_matrices : torch .Tensor ,
461487 positions : torch .Tensor ,
462488 obs_rings : Optional [torch .Tensor ] = None ,
489+ obs_scan_nrs : Optional [torch .Tensor ] = None ,
490+ grain_scan_mask : Optional [torch .Tensor ] = None ,
463491 n_iter : int = 10 ,
464492 verbose : bool = True ,
465493 ) -> EMResult :
@@ -475,6 +503,12 @@ def fit_from_orient(
475503 orient_matrices : (N, 3, 3) orientation matrices
476504 positions : (N, 3) grain positions (micrometers)
477505 obs_rings : (S,) int ring indices for observed spots
506+ obs_scan_nrs : (S,) int scan number per observed spot (pf-HEDM).
507+ If provided along with grain_scan_mask, the E-step filters
508+ ownership by scan position.
509+ grain_scan_mask : (N, n_scans) bool tensor. True if grain g is
510+ present at scan s. Derived from the C code's sinograms:
511+ grain_scan_mask[g, s] = any(sinos[g, :, s] > 0).
478512 n_iter : int
479513 Number of EM iterations (E-step + sigma annealing).
480514 verbose : bool
@@ -494,6 +528,8 @@ def fit_from_orient(
494528 ownership , hkl_assignments = self .e_step (
495529 obs_spots , pred_coords , pred_valid ,
496530 obs_rings = obs_rings , pred_rings = pred_rings ,
531+ obs_scan_nrs = obs_scan_nrs ,
532+ grain_scan_mask = grain_scan_mask ,
497533 )
498534
499535 with torch .no_grad ():
@@ -512,6 +548,8 @@ def fit_from_orient(
512548 ownership , hkl_assignments = self .e_step (
513549 obs_spots , pred_coords , pred_valid ,
514550 obs_rings = obs_rings , pred_rings = pred_rings ,
551+ obs_scan_nrs = obs_scan_nrs ,
552+ grain_scan_mask = grain_scan_mask ,
515553 )
516554
517555 # Return orient_matrices reshaped as (N, 9) to fit EMResult euler_angles slot
0 commit comments