|
43 | 43 | from fastdeploy.config import FDConfig |
44 | 44 | from fastdeploy.engine.register_manager import RegisterManager |
45 | 45 | from fastdeploy.engine.request import ( |
| 46 | + CompletionOutput, |
46 | 47 | ControlRequest, |
47 | 48 | ControlResponse, |
48 | 49 | Request, |
| 50 | + RequestMetrics, |
49 | 51 | RequestOutput, |
50 | 52 | RequestStatus, |
51 | 53 | RequestType, |
@@ -1500,6 +1502,139 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d |
1500 | 1502 |
|
1501 | 1503 | return responses |
1502 | 1504 |
|
| 1505 | + def _control_abort_requests(self, control_req: ControlRequest): |
| 1506 | + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: |
| 1507 | + raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER") |
| 1508 | + args = control_req.get_args() |
| 1509 | + abort_all = args.get("abort_all", False) |
| 1510 | + req_ids = args.get("req_ids", []) |
| 1511 | + matched_input_ids = set() |
| 1512 | + now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) |
| 1513 | + |
| 1514 | + # Step 1: Determine target request list |
| 1515 | + if abort_all: |
| 1516 | + # all requests in running + waiting |
| 1517 | + target_req_ids = now_reqs |
| 1518 | + else: |
| 1519 | + # filter out requests that actually exist |
| 1520 | + target_req_ids = [] |
| 1521 | + for rid in req_ids: |
| 1522 | + if rid in now_reqs: |
| 1523 | + target_req_ids.append(rid) |
| 1524 | + matched_input_ids.add(rid) |
| 1525 | + elif f"{rid}_0" in now_reqs: |
| 1526 | + target_req_ids.append(f"{rid}_0") |
| 1527 | + matched_input_ids.add(rid) |
| 1528 | + |
| 1529 | + if not target_req_ids: |
| 1530 | + return {"aborted": [], "not_found": req_ids if not abort_all else []} |
| 1531 | + |
| 1532 | + # Step 2: Collect partial results |
| 1533 | + aborted_info = [] |
| 1534 | + results = [] |
| 1535 | + for req_id in target_req_ids: |
| 1536 | + request = self.resource_manager.requests.get(req_id) |
| 1537 | + if request is None: |
| 1538 | + scheduled_req = self.scheduler.requests.get(req_id) |
| 1539 | + if scheduled_req is None: |
| 1540 | + continue |
| 1541 | + request = scheduled_req.raw |
| 1542 | + |
| 1543 | + partial_token_ids = list(request.output_token_ids) |
| 1544 | + |
| 1545 | + # Construct finished response with partial results |
| 1546 | + now = time.time() |
| 1547 | + abort_metrics = RequestMetrics( |
| 1548 | + arrival_time=request.metrics.arrival_time if request.metrics else now, |
| 1549 | + inference_start_time=request.metrics.inference_start_time if request.metrics else now, |
| 1550 | + engine_recv_latest_token_time=now, |
| 1551 | + engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now, |
| 1552 | + request_start_time=request.metrics.arrival_time if request.metrics else now, |
| 1553 | + ) |
| 1554 | + result = RequestOutput( |
| 1555 | + request_id=req_id, |
| 1556 | + finished=True, |
| 1557 | + outputs=CompletionOutput( |
| 1558 | + index=0, |
| 1559 | + send_idx=len(partial_token_ids), |
| 1560 | + token_ids=[self.data_processor.eos_token_ids[0]], |
| 1561 | + ), |
| 1562 | + metrics=abort_metrics, |
| 1563 | + error_code=200, |
| 1564 | + error_msg="Aborted", |
| 1565 | + ) |
| 1566 | + results.append(result) |
| 1567 | + aborted_info.append( |
| 1568 | + { |
| 1569 | + "request_id": req_id, |
| 1570 | + "output_token_count": len(partial_token_ids), |
| 1571 | + } |
| 1572 | + ) |
| 1573 | + |
| 1574 | + # Step 3: Execute abort — add all requests to waiting_abort_req_id_set |
| 1575 | + if envs.ENABLE_V1_KVCACHE_SCHEDULER: |
| 1576 | + for req_id in target_req_ids: |
| 1577 | + self.resource_manager.add_abort_req_ids(req_id) |
| 1578 | + time.sleep(0.0001) |
| 1579 | + if self.cfg.scheduler_config.splitwise_role != "prefill": |
| 1580 | + self._wait_abort_complete(target_req_ids) |
| 1581 | + |
| 1582 | + # Add results to scheduler, engine will have a thread calling get_results, |
| 1583 | + # then cleanup and call send_response to send to client. |
| 1584 | + # When client disconnects, send_response will automatically ignore |
| 1585 | + if self.cfg.scheduler_config.splitwise_role != "prefill": |
| 1586 | + try: |
| 1587 | + # self.send_response_server.send_response(req_id, [result]) |
| 1588 | + self.scheduler.put_results(results) |
| 1589 | + except Exception: |
| 1590 | + pass # client may have disconnected |
| 1591 | + |
| 1592 | + not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else [] |
| 1593 | + |
| 1594 | + return {"aborted": aborted_info, "not_found": not_found} |
| 1595 | + |
| 1596 | + def _wait_abort_complete(self, target_req_ids, stall_timeout=1): |
| 1597 | + """ |
| 1598 | + Wait for all abort requests to complete. |
| 1599 | + - Keep monitoring as long as remaining is not empty, which means cleanup is not done yet |
| 1600 | + - If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set, |
| 1601 | + reset progress state if any, then continue monitoring |
| 1602 | + """ |
| 1603 | + target_set = set(target_req_ids) |
| 1604 | + prev_remaining_count = len(target_set) |
| 1605 | + last_progress_time = time.time() |
| 1606 | + remaining = target_set & self.resource_manager.get_reqs_in_aborting() |
| 1607 | + while remaining: |
| 1608 | + remaining = target_set & self.resource_manager.get_reqs_in_aborting() |
| 1609 | + if not remaining: |
| 1610 | + self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned") |
| 1611 | + return |
| 1612 | + |
| 1613 | + current_count = len(remaining) |
| 1614 | + if current_count < prev_remaining_count: |
| 1615 | + # progress made: recycle_abort_task was called |
| 1616 | + self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}") |
| 1617 | + last_progress_time = time.time() |
| 1618 | + prev_remaining_count = current_count |
| 1619 | + |
| 1620 | + if time.time() - last_progress_time > stall_timeout: |
| 1621 | + # no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9) |
| 1622 | + stuck = remaining & self.resource_manager.to_be_aborted_req_id_set |
| 1623 | + if stuck: |
| 1624 | + self.llm_logger.warning( |
| 1625 | + f"no abort progress for {stall_timeout}s, " |
| 1626 | + f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)" |
| 1627 | + ) |
| 1628 | + for req_id in list(stuck): |
| 1629 | + self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}") |
| 1630 | + self.resource_manager.recycle_abort_task(req_id) |
| 1631 | + # reset progress state |
| 1632 | + last_progress_time = time.time() |
| 1633 | + prev_remaining_count = current_count - len(stuck) |
| 1634 | + # else: remaining are all in waiting_abort_req_id_set, waiting for natural flow |
| 1635 | + |
| 1636 | + time.sleep(0.005) |
| 1637 | + |
1503 | 1638 | def _parse_tags(self, control_request: ControlRequest): |
1504 | 1639 | """ |
1505 | 1640 | Parse tags from control request. |
|
0 commit comments