Skip to content

Commit 0f59c1f

Browse files
committed
fix: Remove demo testing calls from inference script
1 parent f5ed98c commit 0f59c1f

2 files changed

Lines changed: 28 additions & 22 deletions

File tree

src/ui/app/model_inference.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def load_model(self):
2626
try:
2727
if not self.model_path.exists():
2828
print(f"Warning: Model file not found at {self.model_path}")
29-
print("Running in demo mode with dummy predictions")
3029
return
3130

3231
# Import the model class from copied model files
@@ -50,7 +49,6 @@ def load_model(self):
5049

5150
except Exception as e:
5251
print(f"Error loading model: {e}")
53-
print("Running in demo mode with dummy predictions")
5452
self.model = None
5553

5654
def preprocess_data(self, x: List[float], y: List[float]) -> torch.Tensor:
@@ -101,8 +99,11 @@ def predict(self, x: List[float], y: List[float]) -> Dict:
10199
Dictionary with predictions and confidence scores
102100
"""
103101
if self.model is None:
104-
# Return dummy predictions for demo mode
105-
return self._dummy_predictions()
102+
return {
103+
"status": "error",
104+
"error": f"Model not loaded. Expected checkpoint at {self.model_path}",
105+
"http_status": 503
106+
}
106107

107108
try:
108109
# Preprocess data
@@ -114,24 +115,26 @@ def predict(self, x: List[float], y: List[float]) -> Dict:
114115

115116
# Process output (adjust based on your model's output format)
116117
processed = self._process_model_output(output)
118+
overall_confidence = self._compute_overall_confidence(processed)
117119

118120
predictions = {
119121
"status": "success",
120122
"predictions": processed,
121-
"confidence": 0.85, # Placeholder - calculate from model output
122123
"model_info": {
123124
"type": "AlphaDiffract",
124125
"device": str(self.device)
125126
}
126127
}
128+
if overall_confidence is not None:
129+
predictions["confidence"] = overall_confidence
127130

128131
return predictions
129132

130133
except Exception as e:
131134
return {
132135
"status": "error",
133136
"error": str(e),
134-
"predictions": self._dummy_predictions()
137+
"http_status": 500
135138
}
136139

137140
def _process_model_output(self, output) -> Dict:
@@ -224,13 +227,19 @@ def _process_model_output(self, output) -> Dict:
224227
# Handle tensor output
225228
elif isinstance(output, torch.Tensor):
226229
probs = output.cpu().numpy()
230+
confidence = None
231+
if output.ndim >= 1 and output.shape[-1] > 1:
232+
prob_tensor = torch.softmax(output, dim=-1)
233+
confidence = float(prob_tensor.max().item())
234+
227235
predictions = [
228236
{
229237
"phase": f"Predicted Phase",
230-
"confidence": 0.82,
231238
"details": f"Output shape: {probs.shape}"
232239
}
233240
]
241+
if confidence is not None:
242+
predictions[0]["confidence"] = confidence
234243

235244
return {
236245
"phase_predictions": predictions,
@@ -257,18 +266,14 @@ def _get_space_group_symbol(self, sg_number: int) -> str:
257266
return f"SG{sg_number}"
258267
except Exception:
259268
return f"SG{sg_number}"
260-
261-
def _dummy_predictions(self) -> Dict:
262-
"""Return dummy predictions when model is not available"""
263-
return {
264-
"status": "demo",
265-
"message": "Running in demo mode - model not loaded",
266-
"predictions": {
267-
"phase_predictions": [
268-
{"phase": "Demo Phase 1", "confidence": 0.75},
269-
{"phase": "Demo Phase 2", "confidence": 0.45},
270-
],
271-
"intensity_profile": []
272-
},
273-
"confidence": 0.60
274-
}
269+
270+
def _compute_overall_confidence(self, processed: Dict) -> Optional[float]:
271+
"""Compute overall confidence from available per-phase confidences."""
272+
phase_predictions = processed.get("phase_predictions", []) if isinstance(processed, dict) else []
273+
confidences = [
274+
float(p["confidence"]) for p in phase_predictions
275+
if isinstance(p, dict) and "confidence" in p and p["confidence"] is not None
276+
]
277+
if not confidences:
278+
return None
279+
return float(np.mean(confidences))

src/ui/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ numpy==1.24.3
55
torch==2.1.0
66
pytorch-lightning==2.5.5
77
pydantic==2.5.0
8+
spglib==2.4.0

0 commit comments

Comments
 (0)