@@ -26,7 +26,6 @@ def load_model(self):
2626 try :
2727 if not self .model_path .exists ():
2828 print (f"Warning: Model file not found at { self .model_path } " )
29- print ("Running in demo mode with dummy predictions" )
3029 return
3130
3231 # Import the model class from copied model files
@@ -50,7 +49,6 @@ def load_model(self):
5049
5150 except Exception as e :
5251 print (f"Error loading model: { e } " )
53- print ("Running in demo mode with dummy predictions" )
5452 self .model = None
5553
5654 def preprocess_data (self , x : List [float ], y : List [float ]) -> torch .Tensor :
@@ -101,8 +99,11 @@ def predict(self, x: List[float], y: List[float]) -> Dict:
10199 Dictionary with predictions and confidence scores
102100 """
103101 if self .model is None :
104- # Return dummy predictions for demo mode
105- return self ._dummy_predictions ()
102+ return {
103+ "status" : "error" ,
104+ "error" : f"Model not loaded. Expected checkpoint at { self .model_path } " ,
105+ "http_status" : 503
106+ }
106107
107108 try :
108109 # Preprocess data
@@ -114,24 +115,26 @@ def predict(self, x: List[float], y: List[float]) -> Dict:
114115
115116 # Process output (adjust based on your model's output format)
116117 processed = self ._process_model_output (output )
118+ overall_confidence = self ._compute_overall_confidence (processed )
117119
118120 predictions = {
119121 "status" : "success" ,
120122 "predictions" : processed ,
121- "confidence" : 0.85 , # Placeholder - calculate from model output
122123 "model_info" : {
123124 "type" : "AlphaDiffract" ,
124125 "device" : str (self .device )
125126 }
126127 }
128+ if overall_confidence is not None :
129+ predictions ["confidence" ] = overall_confidence
127130
128131 return predictions
129132
130133 except Exception as e :
131134 return {
132135 "status" : "error" ,
133136 "error" : str (e ),
134- "predictions " : self . _dummy_predictions ()
137+ "http_status " : 500
135138 }
136139
137140 def _process_model_output (self , output ) -> Dict :
@@ -224,13 +227,19 @@ def _process_model_output(self, output) -> Dict:
224227 # Handle tensor output
225228 elif isinstance (output , torch .Tensor ):
226229 probs = output .cpu ().numpy ()
230+ confidence = None
231+ if output .ndim >= 1 and output .shape [- 1 ] > 1 :
232+ prob_tensor = torch .softmax (output , dim = - 1 )
233+ confidence = float (prob_tensor .max ().item ())
234+
227235 predictions = [
228236 {
229237 "phase" : f"Predicted Phase" ,
230- "confidence" : 0.82 ,
231238 "details" : f"Output shape: { probs .shape } "
232239 }
233240 ]
241+ if confidence is not None :
242+ predictions [0 ]["confidence" ] = confidence
234243
235244 return {
236245 "phase_predictions" : predictions ,
@@ -257,18 +266,14 @@ def _get_space_group_symbol(self, sg_number: int) -> str:
257266 return f"SG{ sg_number } "
258267 except Exception :
259268 return f"SG{ sg_number } "
260-
261- def _dummy_predictions (self ) -> Dict :
262- """Return dummy predictions when model is not available"""
263- return {
264- "status" : "demo" ,
265- "message" : "Running in demo mode - model not loaded" ,
266- "predictions" : {
267- "phase_predictions" : [
268- {"phase" : "Demo Phase 1" , "confidence" : 0.75 },
269- {"phase" : "Demo Phase 2" , "confidence" : 0.45 },
270- ],
271- "intensity_profile" : []
272- },
273- "confidence" : 0.60
274- }
269+
270+ def _compute_overall_confidence (self , processed : Dict ) -> Optional [float ]:
271+ """Compute overall confidence from available per-phase confidences."""
272+ phase_predictions = processed .get ("phase_predictions" , []) if isinstance (processed , dict ) else []
273+ confidences = [
274+ float (p ["confidence" ]) for p in phase_predictions
275+ if isinstance (p , dict ) and "confidence" in p and p ["confidence" ] is not None
276+ ]
277+ if not confidences :
278+ return None
279+ return float (np .mean (confidences ))
0 commit comments