66from typing import Dict , List , Tuple , Optional
77import torch
88import numpy as np
9+ import spglib
910
1011
1112class XRDModelInference :
@@ -32,7 +33,7 @@ def load_model(self):
3233 from .model import AlphaDiffractMultiscaleLightning
3334
3435 # Load checkpoint
35- checkpoint = torch .load (self .model_path , map_location = self .device )
36+ checkpoint = torch .load (self .model_path , map_location = self .device , weights_only = False )
3637
3738 # Initialize model (you may need to adjust hyperparameters based on your checkpoint)
3839 # For now, using placeholder values - these should match your trained model
@@ -243,20 +244,19 @@ def _process_model_output(self, output) -> Dict:
243244 }
244245
245246 def _get_space_group_symbol (self , sg_number : int ) -> str :
246- """Get space group symbol from number (simplified mapping)"""
247- # Common space group symbols - this is a simplified mapping
248- # In production, you'd want a complete lookup table
249- symbols = {
250- 1 : "P1" , 2 : "P-1" , 3 : "P2" , 4 : "P21" , 5 : "C2" ,
251- 10 : "P2/m" , 15 : "C2/c" , 16 : "P222" , 19 : "P212121" ,
252- 38 : "Amm2" , 47 : "Pmmm" , 62 : "Pnma" , 63 : "Cmcm" ,
253- 71 : "Immm" , 74 : "Imma" , 82 : "I-4" , 87 : "I4/m" ,
254- 123 : "P4/mmm" , 129 : "P4/nmm" , 139 : "I4/mmm" , 148 : "R-3" ,
255- 160 : "R3m" , 162 : "P-31m" , 164 : "P-3m1" , 166 : "R-3m" ,
256- 167 : "R-3c" , 186 : "P63mc" , 194 : "P63/mmc" , 221 : "Pm-3m" ,
257- 225 : "Fm-3m" , 227 : "Fd-3m" , 229 : "Im-3m" , 230 : "Ia-3d"
258- }
259- return symbols .get (sg_number , f"SG{ sg_number } " )
247+ """Get space group symbol from number using spglib"""
248+ if sg_number < 1 or sg_number > 230 :
249+ return f"SG{ sg_number } "
250+
251+ try :
252+ # Get space group type information from spglib
253+ sg_type = spglib .get_spacegroup_type (sg_number )
254+ if sg_type is not None :
255+ # Use the international short symbol (Hermann-Mauguin notation)
256+ return sg_type ['international_short' ]
257+ return f"SG{ sg_number } "
258+ except Exception :
259+ return f"SG{ sg_number } "
260260
261261 def _dummy_predictions (self ) -> Dict :
262262 """Return dummy predictions when model is not available"""
0 commit comments