Skip to content

Commit b0b3128

Browse files
authored
abort requests (PaddlePaddle#6992)
1 parent b9951a7 commit b0b3128

13 files changed

Lines changed: 670 additions & 3 deletions

File tree

docs/online_serving/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,3 +577,4 @@ DeltaFunctionCall:
577577
- `/v1/pause` - Pause generation (causes denial of service). Inflight requests are aborted and cache is reset.
578578
- `/v1/resume` - Resume generation.
579579
- `/v1/is_paused` - Check if generation is paused.
580+
- `/v1/abort_requests` - Abort inference requests to release GPU memory (KV Cache blocks) and compute resources. Accepts `req_ids` (list of request IDs) or `abort_all=true` (abort all requests). Returns the list of aborted requests with their generated token counts.

docs/online_serving/router.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ The Router exposes a set of HTTP services to provide unified request scheduling,
151151
|----------|------|------|
152152
| POST | `/v1/chat/completions` | Provide scheduling services for inference requests based on the Chat Completions API |
153153
| POST | `/v1/completions` | Provide scheduling services for general text completion inference requests |
154+
| POST | `/v1/abort_requests` | Abort inference requests to release GPU memory and compute resources. Accepts `req_ids` or `abort_all=true`. Returns aborted requests with their generated token counts |
154155
| POST | `/register` | Allow inference instances to register their metadata with the Router for scheduling |
155156
| GET | `/registered` | Query the list of currently registered inference instances |
156157
| GET | `/registered_number` | Query the number of currently registered inference instances |

docs/zh/online_serving/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,4 @@ DeltaFunctionCall:
563563
/v1/pause - 暂停推理生成(会导致服务拒绝推理请求)。正在进行中的请求会被中止,缓存会被重置。
564564
/v1/resume - 恢复推理生成。
565565
/v1/is_paused - 检查推理生成是否已暂停。
566+
/v1/abort_requests - 中断推理请求,释放 GPU 显存(KV Cache blocks)和计算资源。支持传入 `req_ids`(请求 ID 列表)或 `abort_all=true`(中断所有请求)。返回已中断请求列表及其已生成的 token 数。

docs/zh/online_serving/router.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ Router 通过 HTTP 接口对外提供统一的调度服务,同时支持运行
152152
|----------|------|------|
153153
| POST | `/v1/chat/completions` | 对外提供基于 Chat 接口的推理请求调度服务 |
154154
| POST | `/v1/completions` | 对外提供通用文本补全请求的调度服务 |
155+
| POST | `/v1/abort_requests` | 中断推理请求,释放 GPU 显存和计算资源。支持传入 `req_ids` 或 `abort_all=true`,返回已中断请求列表及其已生成的 token 数 |
155156
| POST | `/register` | 推理实例向 Router 注册自身信息,用于参与调度 |
156157
| GET | `/registered` | 查询当前已注册的推理实例列表 |
157158
| GET | `/registered_number` | 查询当前已注册的推理实例数量 |

fastdeploy/engine/common_engine.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@
4343
from fastdeploy.config import FDConfig
4444
from fastdeploy.engine.register_manager import RegisterManager
4545
from fastdeploy.engine.request import (
46+
CompletionOutput,
4647
ControlRequest,
4748
ControlResponse,
4849
Request,
50+
RequestMetrics,
4951
RequestOutput,
5052
RequestStatus,
5153
RequestType,
@@ -1500,6 +1502,139 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d
15001502

15011503
return responses
15021504

1505+
def _control_abort_requests(self, control_req: ControlRequest):
1506+
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
1507+
raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER")
1508+
args = control_req.get_args()
1509+
abort_all = args.get("abort_all", False)
1510+
req_ids = args.get("req_ids", [])
1511+
matched_input_ids = set()
1512+
now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()))
1513+
1514+
# Step 1: Determine target request list
1515+
if abort_all:
1516+
# all requests in running + waiting
1517+
target_req_ids = now_reqs
1518+
else:
1519+
# filter out requests that actually exist
1520+
target_req_ids = []
1521+
for rid in req_ids:
1522+
if rid in now_reqs:
1523+
target_req_ids.append(rid)
1524+
matched_input_ids.add(rid)
1525+
elif f"{rid}_0" in now_reqs:
1526+
target_req_ids.append(f"{rid}_0")
1527+
matched_input_ids.add(rid)
1528+
1529+
if not target_req_ids:
1530+
return {"aborted": [], "not_found": req_ids if not abort_all else []}
1531+
1532+
# Step 2: Collect partial results
1533+
aborted_info = []
1534+
results = []
1535+
for req_id in target_req_ids:
1536+
request = self.resource_manager.requests.get(req_id)
1537+
if request is None:
1538+
scheduled_req = self.scheduler.requests.get(req_id)
1539+
if scheduled_req is None:
1540+
continue
1541+
request = scheduled_req.raw
1542+
1543+
partial_token_ids = list(request.output_token_ids)
1544+
1545+
# Construct finished response with partial results
1546+
now = time.time()
1547+
abort_metrics = RequestMetrics(
1548+
arrival_time=request.metrics.arrival_time if request.metrics else now,
1549+
inference_start_time=request.metrics.inference_start_time if request.metrics else now,
1550+
engine_recv_latest_token_time=now,
1551+
engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now,
1552+
request_start_time=request.metrics.arrival_time if request.metrics else now,
1553+
)
1554+
result = RequestOutput(
1555+
request_id=req_id,
1556+
finished=True,
1557+
outputs=CompletionOutput(
1558+
index=0,
1559+
send_idx=len(partial_token_ids),
1560+
token_ids=[self.data_processor.eos_token_ids[0]],
1561+
),
1562+
metrics=abort_metrics,
1563+
error_code=200,
1564+
error_msg="Aborted",
1565+
)
1566+
results.append(result)
1567+
aborted_info.append(
1568+
{
1569+
"request_id": req_id,
1570+
"output_token_count": len(partial_token_ids),
1571+
}
1572+
)
1573+
1574+
# Step 3: Execute abort — add all requests to waiting_abort_req_id_set
1575+
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
1576+
for req_id in target_req_ids:
1577+
self.resource_manager.add_abort_req_ids(req_id)
1578+
time.sleep(0.0001)
1579+
if self.cfg.scheduler_config.splitwise_role != "prefill":
1580+
self._wait_abort_complete(target_req_ids)
1581+
1582+
# Add results to scheduler, engine will have a thread calling get_results,
1583+
# then cleanup and call send_response to send to client.
1584+
# When client disconnects, send_response will automatically ignore
1585+
if self.cfg.scheduler_config.splitwise_role != "prefill":
1586+
try:
1587+
# self.send_response_server.send_response(req_id, [result])
1588+
self.scheduler.put_results(results)
1589+
except Exception:
1590+
pass # client may have disconnected
1591+
1592+
not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else []
1593+
1594+
return {"aborted": aborted_info, "not_found": not_found}
1595+
1596+
def _wait_abort_complete(self, target_req_ids, stall_timeout=1):
1597+
"""
1598+
Wait for all abort requests to complete.
1599+
- Keep monitoring as long as remaining is not empty, which means cleanup is not done yet
1600+
- If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set,
1601+
reset progress state if any, then continue monitoring
1602+
"""
1603+
target_set = set(target_req_ids)
1604+
prev_remaining_count = len(target_set)
1605+
last_progress_time = time.time()
1606+
remaining = target_set & self.resource_manager.get_reqs_in_aborting()
1607+
while remaining:
1608+
remaining = target_set & self.resource_manager.get_reqs_in_aborting()
1609+
if not remaining:
1610+
self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned")
1611+
return
1612+
1613+
current_count = len(remaining)
1614+
if current_count < prev_remaining_count:
1615+
# progress made: recycle_abort_task was called
1616+
self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}")
1617+
last_progress_time = time.time()
1618+
prev_remaining_count = current_count
1619+
1620+
if time.time() - last_progress_time > stall_timeout:
1621+
# no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9)
1622+
stuck = remaining & self.resource_manager.to_be_aborted_req_id_set
1623+
if stuck:
1624+
self.llm_logger.warning(
1625+
f"no abort progress for {stall_timeout}s, "
1626+
f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)"
1627+
)
1628+
for req_id in list(stuck):
1629+
self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}")
1630+
self.resource_manager.recycle_abort_task(req_id)
1631+
# reset progress state
1632+
last_progress_time = time.time()
1633+
prev_remaining_count = current_count - len(stuck)
1634+
# else: remaining are all in waiting_abort_req_id_set, waiting for natural flow
1635+
1636+
time.sleep(0.005)
1637+
15031638
def _parse_tags(self, control_request: ControlRequest):
15041639
"""
15051640
Parse tags from control request.

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def recycle_abort_task(self, request_id):
282282
del self.requests[request_id]
283283
del self.req_dict[request_id]
284284
self.to_be_aborted_req_id_set.remove(request_id)
285+
self.update_metrics()
285286

286287
def _trigger_abort(self, request_id, scheduled_reqs):
287288
if request_id in self.requests:
@@ -1207,6 +1208,9 @@ def download_bos_features(bos_client, features_urls):
12071208
return None
12081209
inputs["audio_features"] = result
12091210

1211+
def get_reqs_in_aborting(self):
1212+
return self.waiting_abort_req_id_set | self.to_be_aborted_req_id_set
1213+
12101214
def get_available_position(self) -> int:
12111215
position = 0
12121216
while position < self.max_num_seqs:

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,25 @@ async def update_weights(request: Request) -> Response:
475475
return control_response.to_api_json_response()
476476

477477

478+
@app.post("/v1/abort_requests")
479+
async def abort_requests(request: Request):
480+
body = await request.json()
481+
abort_all = body.get("abort_all", False)
482+
req_ids = body.get("req_ids", None)
483+
484+
# 参数校验
485+
if not abort_all and not req_ids:
486+
return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"})
487+
488+
control_request = ControlRequest(
489+
request_id=f"control-{uuid.uuid4()}",
490+
method="abort_requests",
491+
args={"abort_all": abort_all, "req_ids": req_ids or []},
492+
)
493+
control_response = await app.state.engine_client.run_control_method(control_request)
494+
return control_response.to_api_json_response()
495+
496+
478497
def wrap_streaming_generator(original_generator: AsyncGenerator):
479498
"""
480499
Wrap an async generator to release the connection semaphore when the generator is finished.

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,9 @@ async def chat_completion_stream_generator(
469469
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
470470
choice.finish_reason = "recover_stop"
471471

472+
if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
473+
choice.finish_reason = "abort"
474+
472475
inference_start_time[idx] = 0
473476

474477
if request.collect_metrics:
@@ -802,6 +805,8 @@ async def _create_chat_completion_choice(
802805
if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]:
803806
finish_reason = "recover_stop"
804807

808+
if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]:
809+
finish_reason = "abort"
805810
return ChatCompletionResponseChoice(
806811
index=idx,
807812
message=message,

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,8 @@ async def completion_stream_generator(
586586
output,
587587
tool_called[idx],
588588
)
589+
if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
590+
choices[-1].finish_reason = "abort"
589591
inference_start_time[idx] = 0
590592

591593
send_idx = output.get("send_idx")
@@ -726,6 +728,8 @@ def request_output_to_completion_response(
726728
output,
727729
False,
728730
)
731+
if final_res.get("error_msg", None) is not None and "Aborted" in final_res["error_msg"]:
732+
finish_reason = "abort"
729733

730734
choice_data = CompletionResponseChoice(
731735
token_ids=token_ids,

fastdeploy/router/router.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
import aiohttp
1919
import uvicorn
20-
from fastapi import FastAPI, HTTPException
21-
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
20+
from fastapi import FastAPI, HTTPException, Request
21+
from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse
2222

2323
from fastdeploy.router.utils import (
2424
InstanceInfo,
@@ -503,6 +503,48 @@ async def health_generate():
503503
return Response(status_code=200)
504504

505505

506+
@app.post("/v1/abort_requests")
507+
async def abort_requests(request: Request):
508+
body = await request.json()
509+
prefill_servers = app.state.router.prefill_servers
510+
decode_servers = app.state.router.decode_servers
511+
all_servers = prefill_servers + decode_servers
512+
513+
async with aiohttp.ClientSession() as session:
514+
tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers]
515+
responses = await asyncio.gather(*tasks, return_exceptions=True)
516+
517+
# Aggregate results from Node D only
518+
all_aborted = []
519+
all_not_found = []
520+
errors = []
521+
decode_start = len(prefill_servers)
522+
for i, (server, resp) in enumerate(zip(all_servers, responses)):
523+
if i < decode_start:
524+
continue
525+
if isinstance(resp, Exception):
526+
errors.append({"server": server.url(), "error": str(resp)})
527+
elif resp.status == 200:
528+
data = await resp.json()
529+
result = data.get("result") or {}
530+
all_aborted.extend(result.get("aborted", []))
531+
all_not_found.extend(result.get("not_found", []))
532+
else:
533+
errors.append({"server": server.url(), "status": resp.status})
534+
535+
return JSONResponse(
536+
content={
537+
"request_id": f"router-{uuid4()}",
538+
"status": "success" if not errors else "error",
539+
"error_message": None if not errors else str(errors),
540+
"result": {
541+
"aborted": all_aborted,
542+
"not_found": list(set(all_not_found)),
543+
},
544+
}
545+
)
546+
547+
506548
def launch_router(router_args: RouterArgs):
507549
app.state.router_args = router_args
508550
print(f"Starting router with args: {router_args}")

0 commit comments

Comments
 (0)