Commit b1a6aae

cli reintroduction
1 parent 1dc1012

7 files changed

Lines changed: 1235 additions & 326 deletions

File tree

app/lib/endpoints/chat_endpoints.py

Lines changed: 55 additions & 52 deletions
@@ -2,21 +2,40 @@
 import httpx
 import asyncio
 import logging
-from .process_management import get_server_processes, get_server_configs
-from pydantic import BaseModel
+import os
+from pydantic import BaseModel, Field
 from typing import List
 
+# --- KEPT server-related imports ---
+from .process_management import (
+    get_server_processes,
+    get_server_configs,
+)
+
+
 logger = logging.getLogger(__name__)
 
+
 class ChatRequest(BaseModel):
     message: str
     port: int = 8081
-    threads: int = 1
-    ctx_size: int = 2048
-    n_predict: int = 256
-    temperature: float = 0.8
+    threads: int = Field(default_factory=lambda: os.cpu_count() or 1, gt=0)
+    ctx_size: int = Field(default=2048, gt=0)
+    n_predict: int = Field(default=256, gt=0)
+    temperature: float = Field(default=0.8, gt=0.0, le=2.0)
+
+
+class MultiChatRequestItem(ChatRequest):
+    pass
+
+
+class MultiChatRequest(BaseModel):
+    requests: List[MultiChatRequestItem]
+
 
-async def chat_with_bitnet(chat: ChatRequest):
+# --- Endpoint logic functions ---
+
+async def handle_chat_with_bitnet_server(chat: ChatRequest):
     host = "127.0.0.1"
     key = (host, chat.port)
     proc_entry = get_server_processes().get(key)
@@ -27,52 +46,36 @@ async def chat_with_bitnet(chat: ChatRequest):
     server_url = f"http://{host}:{chat.port}/completion"
     payload = {
         "prompt": chat.message,
-        "threads": chat.threads,
-        "ctx_size": chat.ctx_size,
         "n_predict": chat.n_predict,
-        "temperature": chat.temperature
+        "temperature": chat.temperature,
     }
-    async def _chat():
+
+    try:
         async with httpx.AsyncClient() as client:
-            try:
-                logger.info(f"Forwarding chat message to BitNet server on port {chat.port}.")
-                response = await client.post(server_url, json=payload, timeout=300.0)
-                response.raise_for_status()
-                return response.json()
-            except httpx.ReadTimeout:
-                logger.error(f"ReadTimeout when communicating with BitNet server on port {chat.port}.")
-                raise HTTPException(status_code=504, detail=f"Request to BitNet server on port {chat.port} timed out.")
-            except httpx.ConnectError:
-                logger.error(f"ConnectError when communicating with BitNet server on port {chat.port}.")
-                raise HTTPException(status_code=503, detail=f"Could not connect to BitNet server on port {chat.port}.")
-            except httpx.HTTPStatusError as e:
-                logger.error(f"HTTPStatusError from BitNet server on port {chat.port}: {e.response.status_code} - {e.response.text}", exc_info=True)
-                raise HTTPException(status_code=e.response.status_code, detail=f"BitNet server error: {e.response.text}")
-            except Exception as e:
-                logger.error(f"Unexpected error during chat with BitNet server on port {chat.port}: {str(e)}", exc_info=True)
-                error_detail = f"An unexpected error occurred while communicating with BitNet server on port {chat.port}: {str(e)}"
-                raise HTTPException(status_code=500, detail=error_detail)
-    return await _chat()
+            response = await client.post(server_url, json=payload, timeout=60.0)
+            response.raise_for_status()
+            response_data = response.json()
+            # Ensure the key "content" exists before accessing it
+            return {"response": response_data.get("content", ""), "port": chat.port}
+    except httpx.RequestError as e:
+        logger.error(f"HTTP request error to server {host}:{chat.port}: {e}")
+        raise HTTPException(status_code=503, detail=f"Error communicating with BitNet server on port {chat.port}: {e}")
+    except httpx.HTTPStatusError as e:
+        logger.error(f"HTTP status error from server {host}:{chat.port}: {e.response.status_code} - {e.response.text}")
+        raise HTTPException(status_code=e.response.status_code, detail=f"Error from BitNet server on port {chat.port}: {e.response.text}")
+    except Exception as e:
+        logger.error(f"Unexpected error during chat with server {host}:{chat.port}: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
 
-class MultiChatRequest(BaseModel):
-    requests: List[ChatRequest]
-
-async def multichat_with_bitnet(multichat: MultiChatRequest):
-    logger.info(f"Multichat request received for {len(multichat.requests)} chats.")
-    async def run_chat(chat_req: ChatRequest):
-        chat_fn = chat_with_bitnet(chat_req)
-        return await chat_fn
-    results = await asyncio.gather(*(run_chat(req) for req in multichat.requests), return_exceptions=True)
-    formatted = []
-    for i, res in enumerate(results):
-        if isinstance(res, Exception):
-            if isinstance(res, HTTPException):
-                formatted.append({"error": res.detail})
-            else:
-                formatted.append({"error": str(res)})
-        elif isinstance(res, dict) and "content" in res:
-            formatted.append(res["content"])
-        else:
-            formatted.append(res)
-    logger.info("Multichat processing completed.")
-    return {"results": formatted}
+
+async def handle_multichat_with_bitnet_server(data: MultiChatRequest):
+    async def single_chat_wrapper(chat_request: MultiChatRequestItem):
+        try:
+            return await handle_chat_with_bitnet_server(chat_request)
+        except HTTPException as e:
+            return {"port": chat_request.port, "error": e.detail, "status_code": e.status_code}
+        except Exception as e:
+            return {"port": chat_request.port, "error": str(e), "status_code": 500}
+
+    results = await asyncio.gather(*[single_chat_wrapper(req) for req in data.requests])
+    return {"results": results}
Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
+import os
+import logging
+import asyncio
+import shlex
+from fastapi import HTTPException, status
+from typing import Dict, Any, List
+
+from ..models import (
+    LlamaCliInitRequest,
+    LlamaCliChatRequest,
+    BatchLlamaCliInitRequest,
+    BatchLlamaCliRemoveRequest,
+    BatchLlamaCliChatRequest
+)
+
+# Import the process management functions for persistent sessions
+from .process_management import (
+    start_cli_chat_process,
+    send_to_cli_chat_session,
+    terminate_cli_chat_session,
+    cli_chat_sessions  # Direct access for status checks
+)
+
+logger = logging.getLogger(__name__)
+
+# This model path is used for all CLI sessions.
+STATIC_MODEL_PATH = "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf"
+
+async def initialize_llama_cli_session(request: LlamaCliInitRequest) -> Dict[str, Any]:
+    """
+    Starts a persistent llama-cli process in conversational mode.
+    The cli_alias from the request is used as the unique session_id.
+    """
+    session_id = request.cli_alias
+    if session_id in cli_chat_sessions:
+        raise HTTPException(
+            status_code=status.HTTP_409_CONFLICT,
+            detail=f"A CLI chat session with alias '{session_id}' is already running."
+        )
+
+    try:
+        # Start the persistent process
+        session_data = await start_cli_chat_process(
+            session_id=session_id,
+            model_path=STATIC_MODEL_PATH,
+            threads=request.threads,
+            ctx_size=request.ctx_size,
+            n_predict=request.n_predict,
+            temperature=request.temperature,
+            repeat_penalty=request.repeat_penalty,
+            top_k=request.top_k,
+            top_p=request.top_p,
+            system_prompt=request.system_prompt,
+        )
+        logger.info(f"Successfully started persistent llama-cli session '{session_id}' (PID: {session_data['pid']}).")
+        return {
+            "cli_alias": session_id,
+            "status": "running",
+            "pid": session_data["pid"],
+            "message": "CLI process started successfully in conversational mode."
+        }
+    except FileNotFoundError:
+        logger.error(f"Failed to start CLI session '{session_id}': llama-cli executable not found.")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Llama-cli executable not found. Please ensure it's in your PATH or the LLAMA_CLI_PATH environment variable is set correctly."
+        )
+    except RuntimeError as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=str(e)
+        )
+    except Exception as e:
+        logger.error(f"Failed to start persistent CLI session '{session_id}': {str(e)}", exc_info=True)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"An unexpected error occurred while starting the CLI process: {str(e)}"
+        )
+
+async def chat_with_llama_cli_session(chat_request: LlamaCliChatRequest) -> Dict[str, Any]:
+    """
+    Sends a prompt to a running persistent llama-cli session.
+    """
+    session_id = chat_request.cli_alias
+    prompt = chat_request.prompt
+
+    try:
+        response_text = await send_to_cli_chat_session(session_id, prompt)
+        return {
+            "cli_alias": session_id,
+            "prompt": prompt,
+            "response": response_text
+        }
+    except LookupError as e:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
+    except (IOError, TimeoutError) as e:
+        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(e))
+    except Exception as e:
+        logger.error(f"Unexpected error during chat with session '{session_id}': {str(e)}", exc_info=True)
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An unexpected error occurred: {str(e)}")
+
+async def shutdown_llama_cli_session(cli_alias: str) -> Dict[str, str]:
+    """
+    Terminates a persistent llama-cli process.
+    """
+    try:
+        message = await terminate_cli_chat_session(cli_alias)
+        logger.info(f"Termination command for session '{cli_alias}' processed. Result: {message}")
+        return {"cli_alias": cli_alias, "status": "terminated", "message": message}
+    except Exception as e:
+        logger.error(f"Failed to terminate CLI session '{cli_alias}': {str(e)}", exc_info=True)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"An unexpected error occurred during termination: {str(e)}"
+        )
+
+async def get_llama_cli_session_status(cli_alias: str) -> Dict[str, Any]:
+    """
+    Retrieves the status of a specific persistent llama-cli session.
+    """
+    session_info = cli_chat_sessions.get(cli_alias)
+    if not session_info:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"No active CLI chat session found with alias '{cli_alias}'."
+        )
+
+    process = session_info.get("process")
+    session_status = "stopped"  # local name avoids shadowing the imported fastapi `status` used above
+    if process and process.returncode is None:
+        session_status = "running"
+
+    # Return a safe subset of the session data
+    return {
+        "cli_alias": cli_alias,
+        "status": session_status,
+        "pid": session_info.get("pid"),
+        "model_path": session_info.get("model_path"),
+        "start_time": session_info.get("start_time"),
+        "last_interaction_time": session_info.get("last_interaction_time"),
+        "command": " ".join(session_info.get("command", []))
+    }
+
+# --- Batch Operations ---
+async def handle_initialize_batch_llama_cli_configs(batch_request: BatchLlamaCliInitRequest) -> List[Dict[str, Any]]:
+    """
+    Processes a batch request to start multiple persistent llama-cli sessions.
+    """
+    aliases = [req.cli_alias for req in batch_request.requests]
+    if len(aliases) != len(set(aliases)):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Duplicate cli_alias values found in the batch request."
+        )
+
+    async def process_request(req: LlamaCliInitRequest):
+        try:
+            result = await initialize_llama_cli_session(req)
+            return {"cli_alias": req.cli_alias, "status": "success", "data": result}
+        except HTTPException as e:
+            return {"cli_alias": req.cli_alias, "status": "error", "detail": e.detail, "status_code": e.status_code}
+        except Exception as e:
+            logger.error(f"Unexpected error processing batch init for alias {req.cli_alias}: {str(e)}", exc_info=True)
+            return {"cli_alias": req.cli_alias, "status": "error", "detail": "An unexpected server error occurred.", "status_code": 500}
+
+    results = await asyncio.gather(*(process_request(req) for req in batch_request.requests))
+    return results
+
+async def handle_remove_batch_llama_cli_configs(batch_request: BatchLlamaCliRemoveRequest) -> List[Dict[str, Any]]:
+    """
+    Processes a batch request to terminate multiple persistent llama-cli sessions.
+    """
+    async def process_request(alias: str):
+        try:
+            result = await shutdown_llama_cli_session(alias)
+            return {"cli_alias": alias, "status": "success", "data": result}
+        except HTTPException as e:
+            return {"cli_alias": alias, "status": "error", "detail": e.detail, "status_code": e.status_code}
+        except Exception as e:
+            logger.error(f"Unexpected error processing batch removal for alias {alias}: {str(e)}", exc_info=True)
+            return {"cli_alias": alias, "status": "error", "detail": "An unexpected server error occurred.", "status_code": 500}
+
+    results = await asyncio.gather(*(process_request(alias) for alias in batch_request.aliases))
+    return results
+
+async def handle_batch_chat_with_llama_cli(batch_request: BatchLlamaCliChatRequest) -> List[Dict[str, Any]]:
+    """
+    Processes a batch of chat requests with their respective llama-cli sessions.
+    """
+    async def process_request(req: LlamaCliChatRequest):
+        try:
+            # Reuse the single chat handler logic
+            result = await chat_with_llama_cli_session(req)
+            return {"cli_alias": req.cli_alias, "status": "success", "data": result}
+        except HTTPException as e:
+            return {"cli_alias": req.cli_alias, "status": "error", "detail": e.detail, "status_code": e.status_code}
+        except Exception as e:
+            logger.error(f"Unexpected error processing batch chat for alias {req.cli_alias}: {str(e)}", exc_info=True)
+            return {"cli_alias": req.cli_alias, "status": "error", "detail": "An unexpected server error occurred.", "status_code": 500}
+
+    results = await asyncio.gather(*(process_request(req) for req in batch_request.requests))
+    return results
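Taken together, the batch handlers give a full session lifecycle: initialize, chat, terminate. The sketch below (an editor's illustration, not part of the commit) drives that lifecycle by calling the handlers directly rather than through HTTP routes, which live in files not shown here. The module path app.lib.endpoints.cli_endpoints is an assumption, since this diff does not show the new file's name, and constructing LlamaCliInitRequest with only cli_alias assumes the sampling fields (threads, ctx_size, temperature, and so on) have defaults in app/lib/models.

    # Editor's sketch of the batch lifecycle against the handlers above.
    # Assumed (not shown in this diff): the new module's path, and defaults
    # on LlamaCliInitRequest's sampling fields.
    import asyncio

    from app.lib.models import (  # mirrors the relative ..models import above
        LlamaCliInitRequest,
        LlamaCliChatRequest,
        BatchLlamaCliInitRequest,
        BatchLlamaCliChatRequest,
        BatchLlamaCliRemoveRequest,
    )
    from app.lib.endpoints.cli_endpoints import (  # assumed file name
        handle_initialize_batch_llama_cli_configs,
        handle_batch_chat_with_llama_cli,
        handle_remove_batch_llama_cli_configs,
    )

    async def demo() -> None:
        # Start two persistent llama-cli sessions in one call. A duplicate
        # alias in the batch is rejected with HTTP 400 before anything starts.
        init_results = await handle_initialize_batch_llama_cli_configs(
            BatchLlamaCliInitRequest(requests=[
                LlamaCliInitRequest(cli_alias="session-a"),
                LlamaCliInitRequest(cli_alias="session-b"),
            ])
        )
        print(init_results)

        # Fan a prompt out to each session. Per-item failures come back as
        # {"status": "error", ...} entries instead of aborting the batch.
        chat_results = await handle_batch_chat_with_llama_cli(
            BatchLlamaCliChatRequest(requests=[
                LlamaCliChatRequest(cli_alias="session-a", prompt="Hello!"),
                LlamaCliChatRequest(cli_alias="session-b", prompt="Hi there!"),
            ])
        )
        print(chat_results)

        # Terminate both sessions.
        await handle_remove_batch_llama_cli_configs(
            BatchLlamaCliRemoveRequest(aliases=["session-a", "session-b"])
        )

    if __name__ == "__main__":
        asyncio.run(demo())

Note how each batch helper wraps the corresponding single-session function and converts HTTPException into a per-item result dict, so one bad alias degrades only its own entry in the returned list.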
