import logging
import asyncio
from fastapi import HTTPException, status
from typing import Dict, Any, List

from ..models import (
    LlamaCliInitRequest,
    LlamaCliChatRequest,
    BatchLlamaCliInitRequest,
    BatchLlamaCliRemoveRequest,
    BatchLlamaCliChatRequest
)

# Import the process management functions for persistent sessions
from .process_management import (
    start_cli_chat_process,
    send_to_cli_chat_session,
    terminate_cli_chat_session,
    cli_chat_sessions  # Direct access for status checks
)

logger = logging.getLogger(__name__)

# This model path is used for all CLI sessions.
STATIC_MODEL_PATH = "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf"
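# A minimal configuration sketch (commented out, not active code): if the model
# path needs to vary per deployment, it could be read from an environment
# variable. The variable name "LLAMA_MODEL_PATH" is an assumption and would
# require `import os` at the top of this module:
#   STATIC_MODEL_PATH = os.environ.get("LLAMA_MODEL_PATH", STATIC_MODEL_PATH)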

async def initialize_llama_cli_session(request: LlamaCliInitRequest) -> Dict[str, Any]:
    """
    Starts a persistent llama-cli process in conversational mode.
    The cli_alias from the request is used as the unique session_id.
    """
    session_id = request.cli_alias
    if session_id in cli_chat_sessions:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"A CLI chat session with alias '{session_id}' is already running."
        )

    try:
        # Start the persistent process
        session_data = await start_cli_chat_process(
            session_id=session_id,
            model_path=STATIC_MODEL_PATH,
            threads=request.threads,
            ctx_size=request.ctx_size,
            n_predict=request.n_predict,
            temperature=request.temperature,
            repeat_penalty=request.repeat_penalty,
            top_k=request.top_k,
            top_p=request.top_p,
            system_prompt=request.system_prompt,
        )
        logger.info(f"Successfully started persistent llama-cli session '{session_id}' (PID: {session_data['pid']}).")
        return {
            "cli_alias": session_id,
            "status": "running",
            "pid": session_data["pid"],
            "message": "CLI process started successfully in conversational mode."
        }
    except FileNotFoundError:
        logger.error(f"Failed to start CLI session '{session_id}': llama-cli executable not found.")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Llama-cli executable not found. Please ensure it's in your PATH or the LLAMA_CLI_PATH environment variable is set correctly."
        )
    except RuntimeError as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"Failed to start persistent CLI session '{session_id}': {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An unexpected error occurred while starting the CLI process: {str(e)}"
        )
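
# Example request body for the init endpoint (values are illustrative; field
# names follow LlamaCliInitRequest as used above, exact types live in ..models):
#   {
#     "cli_alias": "assistant-1",
#     "threads": 4,
#     "ctx_size": 2048,
#     "n_predict": 256,
#     "temperature": 0.8,
#     "repeat_penalty": 1.1,
#     "top_k": 40,
#     "top_p": 0.95,
#     "system_prompt": "You are a helpful assistant."
#   }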

async def chat_with_llama_cli_session(chat_request: LlamaCliChatRequest) -> Dict[str, Any]:
    """
    Sends a prompt to a running persistent llama-cli session.
    """
    session_id = chat_request.cli_alias
    prompt = chat_request.prompt

    try:
        response_text = await send_to_cli_chat_session(session_id, prompt)
        return {
            "cli_alias": session_id,
            "prompt": prompt,
            "response": response_text
        }
    except LookupError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
    except (IOError, TimeoutError) as e:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error during chat with session '{session_id}': {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An unexpected error occurred: {str(e)}")
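
# Example usage (an illustrative sketch; in practice the FastAPI route that
# wraps this handler would construct the request model from the HTTP body):
#   reply = await chat_with_llama_cli_session(
#       LlamaCliChatRequest(cli_alias="assistant-1", prompt="Hello!")
#   )
#   # reply -> {"cli_alias": "assistant-1", "prompt": "Hello!", "response": "..."}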

async def shutdown_llama_cli_session(cli_alias: str) -> Dict[str, str]:
    """
    Terminates a persistent llama-cli process.
    """
    try:
        message = await terminate_cli_chat_session(cli_alias)
        logger.info(f"Termination command for session '{cli_alias}' processed. Result: {message}")
        return {"cli_alias": cli_alias, "status": "terminated", "message": message}
    except Exception as e:
        logger.error(f"Failed to terminate CLI session '{cli_alias}': {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An unexpected error occurred during termination: {str(e)}"
        )

async def get_llama_cli_session_status(cli_alias: str) -> Dict[str, Any]:
    """
    Retrieves the status of a specific persistent llama-cli session.
    """
    session_info = cli_chat_sessions.get(cli_alias)
    if not session_info:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No active CLI chat session found with alias '{cli_alias}'."
        )

    process = session_info.get("process")
    # Use a distinct name so the imported fastapi `status` module is not shadowed
    # (a local variable named `status` would break the HTTP_404 lookup above).
    session_status = "stopped"
    if process and process.returncode is None:
        session_status = "running"

    # Return a safe subset of the session data
    return {
        "cli_alias": cli_alias,
        "status": session_status,
        "pid": session_info.get("pid"),
        "model_path": session_info.get("model_path"),
        "start_time": session_info.get("start_time"),
        "last_interaction_time": session_info.get("last_interaction_time"),
        "command": " ".join(session_info.get("command", []))
    }
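
# Example of the payload returned above (values are illustrative):
#   {
#     "cli_alias": "assistant-1",
#     "status": "running",
#     "pid": 12345,
#     "model_path": "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf",
#     "start_time": ...,
#     "last_interaction_time": ...,
#     "command": "..."
#   }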

# --- Batch Operations ---
async def handle_initialize_batch_llama_cli_configs(batch_request: BatchLlamaCliInitRequest) -> List[Dict[str, Any]]:
    """
    Processes a batch request to start multiple persistent llama-cli sessions.
    """
    aliases = [req.cli_alias for req in batch_request.requests]
    if len(aliases) != len(set(aliases)):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Duplicate cli_alias values found in the batch request."
        )

    async def process_request(req: LlamaCliInitRequest):
        try:
            result = await initialize_llama_cli_session(req)
            return {"cli_alias": req.cli_alias, "status": "success", "data": result}
        except HTTPException as e:
            return {"cli_alias": req.cli_alias, "status": "error", "detail": e.detail, "status_code": e.status_code}
        except Exception as e:
            logger.error(f"Unexpected error processing batch init for alias {req.cli_alias}: {str(e)}", exc_info=True)
            return {"cli_alias": req.cli_alias, "status": "error", "detail": "An unexpected server error occurred.", "status_code": 500}

    results = await asyncio.gather(*(process_request(req) for req in batch_request.requests))
    return results
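
# The gather above launches every llama-cli session start concurrently. A hedged
# sketch of how concurrency could be capped if spawning many processes at once
# proves too heavy (the limit of 2 is an arbitrary example, not used by the
# handlers in this module):
#   sem = asyncio.Semaphore(2)
#   async def limited(req: LlamaCliInitRequest):
#       async with sem:
#           return await process_request(req)
#   results = await asyncio.gather(*(limited(req) for req in batch_request.requests))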

async def handle_remove_batch_llama_cli_configs(batch_request: BatchLlamaCliRemoveRequest) -> List[Dict[str, Any]]:
    """
    Processes a batch request to terminate multiple persistent llama-cli sessions.
    """
    async def process_request(alias: str):
        try:
            result = await shutdown_llama_cli_session(alias)
            return {"cli_alias": alias, "status": "success", "data": result}
        except HTTPException as e:
            return {"cli_alias": alias, "status": "error", "detail": e.detail, "status_code": e.status_code}
        except Exception as e:
            logger.error(f"Unexpected error processing batch removal for alias {alias}: {str(e)}", exc_info=True)
            return {"cli_alias": alias, "status": "error", "detail": "An unexpected server error occurred.", "status_code": 500}

    results = await asyncio.gather(*(process_request(alias) for alias in batch_request.aliases))
    return results

async def handle_batch_chat_with_llama_cli(batch_request: BatchLlamaCliChatRequest) -> List[Dict[str, Any]]:
    """
    Processes a batch of chat requests with their respective llama-cli sessions.
    """
    async def process_request(req: LlamaCliChatRequest):
        try:
            # Reuse the single chat handler logic
            result = await chat_with_llama_cli_session(req)
            return {"cli_alias": req.cli_alias, "status": "success", "data": result}
        except HTTPException as e:
            return {"cli_alias": req.cli_alias, "status": "error", "detail": e.detail, "status_code": e.status_code}
        except Exception as e:
            logger.error(f"Unexpected error processing batch chat for alias {req.cli_alias}: {str(e)}", exc_info=True)
            return {"cli_alias": req.cli_alias, "status": "error", "detail": "An unexpected server error occurred.", "status_code": 500}

    results = await asyncio.gather(*(process_request(req) for req in batch_request.requests))
    return results
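
# Illustrative wiring sketch (the router module, prefix, and paths below are
# assumptions, not part of this file):
#   from fastapi import APIRouter
#   router = APIRouter(prefix="/llama-cli")
#
#   @router.post("/init")
#   async def init_session(request: LlamaCliInitRequest):
#       return await initialize_llama_cli_session(request)
#
#   @router.post("/chat")
#   async def chat(request: LlamaCliChatRequest):
#       return await chat_with_llama_cli_session(request)
#
#   @router.delete("/{cli_alias}")
#   async def remove_session(cli_alias: str):
#       return await shutdown_llama_cli_session(cli_alias)
#
#   @router.get("/{cli_alias}/status")
#   async def session_status(cli_alias: str):
#       return await get_llama_cli_session_status(cli_alias)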