|
| 1 | +""" |
| 2 | +FastAPI main application for XRD Analysis Tool. |
| 3 | +Serves both the API endpoints and the static React frontend. |
| 4 | +""" |
| 5 | +from fastapi import FastAPI, UploadFile, File |
| 6 | +from fastapi.staticfiles import StaticFiles |
| 7 | +from fastapi.responses import FileResponse, JSONResponse |
| 8 | +from fastapi.middleware.cors import CORSMiddleware |
| 9 | +from pathlib import Path |
| 10 | +from typing import Dict, List |
| 11 | +import torch |
| 12 | +import numpy as np |
| 13 | + |
| 14 | +from .model_inference import XRDModelInference |
| 15 | + |
| 16 | +# Initialize FastAPI app |
| 17 | +app = FastAPI( |
| 18 | + title="XRD Analysis API", |
| 19 | + description="API for analyzing Powder XRD data", |
| 20 | + version="1.0.0" |
| 21 | +) |
| 22 | + |
| 23 | +# Configure CORS for local development |
| 24 | +app.add_middleware( |
| 25 | + CORSMiddleware, |
| 26 | + allow_origins=["http://localhost:5173", "http://localhost:3000"], # Vite dev server |
| 27 | + allow_credentials=True, |
| 28 | + allow_methods=["*"], |
| 29 | + allow_headers=["*"], |
| 30 | +) |
| 31 | + |
| 32 | +# Initialize model inference |
| 33 | +model_inference = XRDModelInference() |
| 34 | + |
| 35 | + |
| 36 | +@app.on_event("startup") |
| 37 | +async def startup_event(): |
| 38 | + """Load model on startup""" |
| 39 | + model_inference.load_model() |
| 40 | + |
| 41 | + |
| 42 | +@app.get("/api/health") |
| 43 | +async def health_check(): |
| 44 | + """Health check endpoint""" |
| 45 | + return {"status": "healthy", "model_loaded": model_inference.is_loaded()} |
| 46 | + |
| 47 | + |
| 48 | +@app.post("/api/predict") |
| 49 | +async def predict(data: dict): |
| 50 | + """ |
| 51 | + Predict XRD analysis from preprocessed data. |
| 52 | + |
| 53 | + Expects JSON payload: {"x": [2theta values], "y": [intensity values], "metadata": {...}} |
| 54 | + Returns: {"predictions": [...], "confidence": float} |
| 55 | + """ |
| 56 | + import time |
| 57 | + request_start = time.time() |
| 58 | + |
| 59 | + try: |
| 60 | + # Extract metadata if present |
| 61 | + metadata = data.get("metadata", {}) |
| 62 | + request_id = metadata.get("timestamp", "unknown") |
| 63 | + filename = metadata.get("filename", "unknown") |
| 64 | + analysis_count = metadata.get("analysisCount", "unknown") |
| 65 | + |
| 66 | + x = data.get("x", []) |
| 67 | + y = data.get("y", []) |
| 68 | + |
| 69 | + if not x or not y: |
| 70 | + return JSONResponse( |
| 71 | + status_code=400, |
| 72 | + content={"error": "Missing x or y data"} |
| 73 | + ) |
| 74 | + |
| 75 | + if len(x) != len(y): |
| 76 | + return JSONResponse( |
| 77 | + status_code=400, |
| 78 | + content={"error": "x and y arrays must have the same length"} |
| 79 | + ) |
| 80 | + |
| 81 | + # Run inference |
| 82 | + results = model_inference.predict(x, y) |
| 83 | + |
| 84 | + request_time = (time.time() - request_start) * 1000 # Convert to ms |
| 85 | + |
| 86 | + # Add request tracking to response |
| 87 | + if isinstance(results, dict): |
| 88 | + results['request_metadata'] = { |
| 89 | + 'request_id': request_id, |
| 90 | + 'filename': filename, |
| 91 | + 'analysis_count': analysis_count, |
| 92 | + 'processing_time_ms': request_time |
| 93 | + } |
| 94 | + |
| 95 | + # Return with anti-caching headers |
| 96 | + return JSONResponse( |
| 97 | + content=results, |
| 98 | + headers={ |
| 99 | + 'Cache-Control': 'no-cache, no-store, must-revalidate, private', |
| 100 | + 'Pragma': 'no-cache', |
| 101 | + 'Expires': '0', |
| 102 | + 'X-Request-ID': str(request_id), |
| 103 | + } |
| 104 | + ) |
| 105 | + |
| 106 | + except Exception as e: |
| 107 | + return JSONResponse( |
| 108 | + status_code=500, |
| 109 | + content={"error": f"Prediction failed: {str(e)}"} |
| 110 | + ) |
| 111 | + |
| 112 | + |
| 113 | +# Static files and SPA support |
| 114 | +frontend_dist = Path(__file__).parent.parent / "frontend" / "dist" |
| 115 | + |
| 116 | +if frontend_dist.exists(): |
| 117 | + # Mount static assets |
| 118 | + app.mount("/assets", StaticFiles(directory=str(frontend_dist / "assets")), name="assets") |
| 119 | + |
| 120 | + # Catch-all route for React Router (SPA) |
| 121 | + @app.get("/{path:path}") |
| 122 | + async def serve_spa(path: str): |
| 123 | + """Serve React SPA""" |
| 124 | + # Check if file exists in dist |
| 125 | + file_path = frontend_dist / path |
| 126 | + if file_path.is_file(): |
| 127 | + return FileResponse(file_path) |
| 128 | + # Otherwise serve index.html for client-side routing |
| 129 | + return FileResponse(frontend_dist / "index.html") |
| 130 | +else: |
| 131 | + @app.get("/") |
| 132 | + async def root(): |
| 133 | + return {"message": "Frontend not built. Run 'npm run build' in frontend/"} |
0 commit comments