From 41012e2d5ebc82e4c62f7a379590f6cbfa05d002 Mon Sep 17 00:00:00 2001 From: kaikai-macbook <872735722@qq.com> Date: Mon, 1 Jun 2026 16:44:49 +0800 Subject: [PATCH] Support Qwen chat as optimizer backend --- README.md | 15 ++ docs/reference/config.md | 13 ++ scripts/train.py | 30 ++++ skillopt/config.py | 12 ++ skillopt/engine/trainer.py | 28 ++- skillopt/model/__init__.py | 49 ++++++ skillopt/model/azure_openai.py | 5 +- skillopt/model/backend_config.py | 6 +- skillopt/model/qwen_backend.py | 282 +++++++++++++++++++++++++------ 9 files changed, 375 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index 54f9a75..ce631dc 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,21 @@ export QWEN_CHAT_BASE_URL="http://localhost:8000/v1" export QWEN_CHAT_MODEL="Qwen/Qwen3.5-4B" ``` +`qwen_chat` can also be used as the optimizer backend. When optimizer and +target should point to different local vLLM services, use the role-specific +settings: + +```bash +python scripts/train.py \ + --config configs/searchqa/default.yaml \ + --optimizer_backend qwen_chat \ + --target_backend qwen_chat \ + --optimizer_model Qwen/Qwen3.5-4B \ + --target_model Qwen/Qwen3.5-4B \ + --optimizer_qwen_chat_base_url http://localhost:8001/v1 \ + --target_qwen_chat_base_url http://localhost:8000/v1 +``` + #### MiniMax ```bash diff --git a/docs/reference/config.md b/docs/reference/config.md index eec0472..0b39bd0 100644 --- a/docs/reference/config.md +++ b/docs/reference/config.md @@ -10,6 +10,12 @@ Complete reference for all SkillOpt configuration parameters. | `model.optimizer` | str | `gpt-5.5` | Optimizer model (for reflection & slow update) | | `model.target` | str | `gpt-5.5` | Target model (for rollout execution) | | `model.reasoning_effort` | str | `medium` | Reasoning effort level | +| `model.optimizer_backend` | str | `openai_chat` | Optimizer backend: `openai_chat` / `claude_chat` / `qwen_chat` / `minimax_chat` | +| `model.target_backend` | str | `openai_chat` | Target backend: chat backends plus execution harnesses | +| `model.qwen_chat_base_url` | str | `http://localhost:8000/v1` | Shared Qwen/vLLM OpenAI-compatible endpoint | +| `model.qwen_chat_enable_thinking` | bool | `false` | Shared Qwen thinking flag | +| `model.optimizer_qwen_chat_base_url` | str | — | Optimizer-specific Qwen/vLLM endpoint; overrides shared `qwen_chat_base_url` | +| `model.target_qwen_chat_base_url` | str | — | Target-specific Qwen/vLLM endpoint; overrides shared `qwen_chat_base_url` | ## Training (`train`) @@ -70,3 +76,10 @@ Complete reference for all SkillOpt configuration parameters. | `AZURE_OPENAI_API_KEY` / `model.azure_openai_api_key` | Azure API key | | `OPENAI_API_KEY` | OpenAI API key (for `openai_chat` backend) | | `ANTHROPIC_API_KEY` | Anthropic API key (for `claude_code_exec` backend) | +| `QWEN_CHAT_BASE_URL` | Shared local vLLM endpoint for `qwen_chat` | +| `QWEN_CHAT_MODEL` | Shared served model name for `qwen_chat` | +| `QWEN_CHAT_API_KEY` | Optional API key for the shared Qwen endpoint | +| `OPTIMIZER_QWEN_CHAT_BASE_URL` | Optimizer-specific local vLLM endpoint | +| `OPTIMIZER_QWEN_CHAT_MODEL` | Optimizer-specific served model name | +| `TARGET_QWEN_CHAT_BASE_URL` | Target-specific local vLLM endpoint | +| `TARGET_QWEN_CHAT_MODEL` | Target-specific served model name | diff --git a/scripts/train.py b/scripts/train.py index d4acce6..c16474b 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -173,6 +173,18 @@ def parse_args() -> argparse.Namespace: p.add_argument("--qwen_chat_timeout_seconds", type=float) p.add_argument("--qwen_chat_max_tokens", type=int) p.add_argument("--qwen_chat_enable_thinking", type=_BOOL) + p.add_argument("--optimizer_qwen_chat_base_url", type=str) + p.add_argument("--optimizer_qwen_chat_api_key", type=str) + p.add_argument("--optimizer_qwen_chat_temperature", type=float) + p.add_argument("--optimizer_qwen_chat_timeout_seconds", type=float) + p.add_argument("--optimizer_qwen_chat_max_tokens", type=int) + p.add_argument("--optimizer_qwen_chat_enable_thinking", type=_BOOL) + p.add_argument("--target_qwen_chat_base_url", type=str) + p.add_argument("--target_qwen_chat_api_key", type=str) + p.add_argument("--target_qwen_chat_temperature", type=float) + p.add_argument("--target_qwen_chat_timeout_seconds", type=float) + p.add_argument("--target_qwen_chat_max_tokens", type=int) + p.add_argument("--target_qwen_chat_enable_thinking", type=_BOOL) p.add_argument("--minimax_base_url", type=str) p.add_argument("--minimax_api_key", type=str) p.add_argument("--minimax_model", type=str) @@ -295,6 +307,18 @@ def parse_args() -> argparse.Namespace: "qwen_chat_timeout_seconds": "model.qwen_chat_timeout_seconds", "qwen_chat_max_tokens": "model.qwen_chat_max_tokens", "qwen_chat_enable_thinking": "model.qwen_chat_enable_thinking", + "optimizer_qwen_chat_base_url": "model.optimizer_qwen_chat_base_url", + "optimizer_qwen_chat_api_key": "model.optimizer_qwen_chat_api_key", + "optimizer_qwen_chat_temperature": "model.optimizer_qwen_chat_temperature", + "optimizer_qwen_chat_timeout_seconds": "model.optimizer_qwen_chat_timeout_seconds", + "optimizer_qwen_chat_max_tokens": "model.optimizer_qwen_chat_max_tokens", + "optimizer_qwen_chat_enable_thinking": "model.optimizer_qwen_chat_enable_thinking", + "target_qwen_chat_base_url": "model.target_qwen_chat_base_url", + "target_qwen_chat_api_key": "model.target_qwen_chat_api_key", + "target_qwen_chat_temperature": "model.target_qwen_chat_temperature", + "target_qwen_chat_timeout_seconds": "model.target_qwen_chat_timeout_seconds", + "target_qwen_chat_max_tokens": "model.target_qwen_chat_max_tokens", + "target_qwen_chat_enable_thinking": "model.target_qwen_chat_enable_thinking", "minimax_base_url": "model.minimax_base_url", "minimax_api_key": "model.minimax_api_key", "minimax_model": "model.minimax_model", @@ -431,6 +455,12 @@ def _has_model_override(dotted_key: str, legacy_key: str) -> bool: and not _has_model_override("model.optimizer", "optimizer_model") ): flat["optimizer_model"] = default_model_for_backend("claude_chat") + if flat.get("optimizer_backend") == "qwen_chat": + if ( + str(flat.get("optimizer_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS + and not _has_model_override("model.optimizer", "optimizer_model") + ): + flat["optimizer_model"] = default_model_for_backend("qwen_chat") if flat.get("target_backend") == "claude_chat": if ( str(flat.get("target_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS diff --git a/skillopt/config.py b/skillopt/config.py index 211d020..5962a05 100644 --- a/skillopt/config.py +++ b/skillopt/config.py @@ -79,6 +79,18 @@ "model.qwen_chat_timeout_seconds": "qwen_chat_timeout_seconds", "model.qwen_chat_max_tokens": "qwen_chat_max_tokens", "model.qwen_chat_enable_thinking": "qwen_chat_enable_thinking", + "model.optimizer_qwen_chat_base_url": "optimizer_qwen_chat_base_url", + "model.optimizer_qwen_chat_api_key": "optimizer_qwen_chat_api_key", + "model.optimizer_qwen_chat_temperature": "optimizer_qwen_chat_temperature", + "model.optimizer_qwen_chat_timeout_seconds": "optimizer_qwen_chat_timeout_seconds", + "model.optimizer_qwen_chat_max_tokens": "optimizer_qwen_chat_max_tokens", + "model.optimizer_qwen_chat_enable_thinking": "optimizer_qwen_chat_enable_thinking", + "model.target_qwen_chat_base_url": "target_qwen_chat_base_url", + "model.target_qwen_chat_api_key": "target_qwen_chat_api_key", + "model.target_qwen_chat_temperature": "target_qwen_chat_temperature", + "model.target_qwen_chat_timeout_seconds": "target_qwen_chat_timeout_seconds", + "model.target_qwen_chat_max_tokens": "target_qwen_chat_max_tokens", + "model.target_qwen_chat_enable_thinking": "target_qwen_chat_enable_thinking", "model.minimax_base_url": "minimax_base_url", "model.minimax_api_key": "minimax_api_key", "model.minimax_model": "minimax_model", diff --git a/skillopt/engine/trainer.py b/skillopt/engine/trainer.py index 8c887e6..9559acb 100644 --- a/skillopt/engine/trainer.py +++ b/skillopt/engine/trainer.py @@ -629,14 +629,26 @@ def _build_eval_env(split: str, env_num: int, seed: int): effort=cfg.get("claude_code_exec_effort", cfg.get("reasoning_effort", "medium")), max_thinking_tokens=cfg.get("claude_code_exec_max_thinking_tokens", 16384), ) - configure_qwen_chat( - base_url=cfg.get("qwen_chat_base_url") or None, - api_key=cfg.get("qwen_chat_api_key") or None, - temperature=cfg.get("qwen_chat_temperature"), - timeout_seconds=cfg.get("qwen_chat_timeout_seconds"), - max_tokens=cfg.get("qwen_chat_max_tokens"), - enable_thinking=cfg.get("qwen_chat_enable_thinking"), - ) + configure_qwen_chat( + base_url=cfg.get("qwen_chat_base_url") or None, + api_key=cfg.get("qwen_chat_api_key") or None, + temperature=cfg.get("qwen_chat_temperature"), + timeout_seconds=cfg.get("qwen_chat_timeout_seconds"), + max_tokens=cfg.get("qwen_chat_max_tokens"), + enable_thinking=cfg.get("qwen_chat_enable_thinking"), + optimizer_base_url=cfg.get("optimizer_qwen_chat_base_url") or None, + optimizer_api_key=cfg.get("optimizer_qwen_chat_api_key") or None, + optimizer_temperature=cfg.get("optimizer_qwen_chat_temperature"), + optimizer_timeout_seconds=cfg.get("optimizer_qwen_chat_timeout_seconds"), + optimizer_max_tokens=cfg.get("optimizer_qwen_chat_max_tokens"), + optimizer_enable_thinking=cfg.get("optimizer_qwen_chat_enable_thinking"), + target_base_url=cfg.get("target_qwen_chat_base_url") or None, + target_api_key=cfg.get("target_qwen_chat_api_key") or None, + target_temperature=cfg.get("target_qwen_chat_temperature"), + target_timeout_seconds=cfg.get("target_qwen_chat_timeout_seconds"), + target_max_tokens=cfg.get("target_qwen_chat_max_tokens"), + target_enable_thinking=cfg.get("target_qwen_chat_enable_thinking"), + ) configure_minimax_chat( base_url=cfg.get("minimax_base_url") or None, api_key=cfg.get("minimax_api_key") or None, diff --git a/skillopt/model/__init__.py b/skillopt/model/__init__.py index cbd5358..6730ab3 100644 --- a/skillopt/model/__init__.py +++ b/skillopt/model/__init__.py @@ -64,6 +64,8 @@ def get_backend_name() -> str: target = get_target_backend() if optimizer == "claude_chat" and target == "claude_chat": return "claude_chat" + if optimizer == "qwen_chat" and target == "qwen_chat": + return "qwen_chat" if optimizer == "openai_chat" and target == "openai_chat": return "azure_openai" if optimizer == "openai_chat" and target == "codex_exec": @@ -93,6 +95,16 @@ def chat_optimizer( stage=stage, timeout=timeout, ) + if get_optimizer_backend() == "qwen_chat": + return _qwen.chat_optimizer( + system=system, + user=user, + max_completion_tokens=max_completion_tokens, + retries=retries, + stage=stage, + reasoning_effort=reasoning_effort, + timeout=timeout, + ) return _openai.chat_optimizer( system=system, user=user, @@ -179,6 +191,18 @@ def chat_optimizer_messages( return_message=return_message, timeout=timeout, ) + if get_optimizer_backend() == "qwen_chat": + return _qwen.chat_optimizer_messages( + messages=messages, + max_completion_tokens=max_completion_tokens, + retries=retries, + stage=stage, + reasoning_effort=reasoning_effort, + tools=tools, + tool_choice=tool_choice, + return_message=return_message, + timeout=timeout, + ) return _openai.chat_optimizer_messages( messages=messages, max_completion_tokens=max_completion_tokens, @@ -414,6 +438,18 @@ def configure_qwen_chat( timeout_seconds: float | str | None = None, max_tokens: int | str | None = None, enable_thinking: bool | str | None = None, + optimizer_base_url: str | None = None, + optimizer_api_key: str | None = None, + optimizer_temperature: float | str | None = None, + optimizer_timeout_seconds: float | str | None = None, + optimizer_max_tokens: int | str | None = None, + optimizer_enable_thinking: bool | str | None = None, + target_base_url: str | None = None, + target_api_key: str | None = None, + target_temperature: float | str | None = None, + target_timeout_seconds: float | str | None = None, + target_max_tokens: int | str | None = None, + target_enable_thinking: bool | str | None = None, ) -> None: _qwen.configure_qwen_chat( base_url=base_url, @@ -422,6 +458,18 @@ def configure_qwen_chat( timeout_seconds=timeout_seconds, max_tokens=max_tokens, enable_thinking=enable_thinking, + optimizer_base_url=optimizer_base_url, + optimizer_api_key=optimizer_api_key, + optimizer_temperature=optimizer_temperature, + optimizer_timeout_seconds=optimizer_timeout_seconds, + optimizer_max_tokens=optimizer_max_tokens, + optimizer_enable_thinking=optimizer_enable_thinking, + target_base_url=target_base_url, + target_api_key=target_api_key, + target_temperature=target_temperature, + target_timeout_seconds=target_timeout_seconds, + target_max_tokens=target_max_tokens, + target_enable_thinking=target_enable_thinking, ) @@ -461,3 +509,4 @@ def set_target_deployment(deployment: str) -> None: def set_optimizer_deployment(deployment: str) -> None: _openai.set_optimizer_deployment(deployment) _claude.set_optimizer_deployment(deployment) + _qwen.set_optimizer_deployment(deployment) diff --git a/skillopt/model/azure_openai.py b/skillopt/model/azure_openai.py index 247e7dd..e7c139c 100644 --- a/skillopt/model/azure_openai.py +++ b/skillopt/model/azure_openai.py @@ -336,9 +336,10 @@ def get_target_client() -> AzureOpenAI | OpenAI: from skillopt.model.backend_config import get_target_backend if get_target_backend() == "qwen_chat": from skillopt.model import qwen_backend as _qwen + target_config = _qwen.TARGET_CONFIG _target_client = OpenAI( - base_url=_qwen.BASE_URL, - api_key=_qwen.API_KEY or "dummy", + base_url=target_config.base_url, + api_key=target_config.api_key or "dummy", ) else: _target_client = _make_client("target") diff --git a/skillopt/model/backend_config.py b/skillopt/model/backend_config.py index 2cdc8c3..f23725c 100644 --- a/skillopt/model/backend_config.py +++ b/skillopt/model/backend_config.py @@ -49,10 +49,10 @@ def _parse_int(value: str | None, default: int) -> int: def set_optimizer_backend(backend: str) -> None: global OPTIMIZER_BACKEND OPTIMIZER_BACKEND = normalize_backend_name(backend or "openai_chat") - if OPTIMIZER_BACKEND not in {"openai_chat", "claude_chat", "minimax_chat"}: + if OPTIMIZER_BACKEND not in {"openai_chat", "claude_chat", "qwen_chat", "minimax_chat"}: raise ValueError( f"Unsupported optimizer backend: {OPTIMIZER_BACKEND!r}. " - "Supported values are 'openai_chat', 'claude_chat', and 'minimax_chat'." + "Supported values are 'openai_chat', 'claude_chat', 'qwen_chat', and 'minimax_chat'." ) os.environ["OPTIMIZER_BACKEND"] = OPTIMIZER_BACKEND @@ -81,7 +81,7 @@ def is_target_exec_backend() -> bool: def is_optimizer_chat_backend() -> bool: - return OPTIMIZER_BACKEND in {"openai_chat", "claude_chat", "minimax_chat"} + return OPTIMIZER_BACKEND in {"openai_chat", "claude_chat", "qwen_chat", "minimax_chat"} def is_target_chat_backend() -> bool: diff --git a/skillopt/model/qwen_backend.py b/skillopt/model/qwen_backend.py index 6184196..be193d4 100644 --- a/skillopt/model/qwen_backend.py +++ b/skillopt/model/qwen_backend.py @@ -1,6 +1,7 @@ -"""OpenAI-compatible Qwen chat backend for the target path.""" +"""OpenAI-compatible Qwen chat backend for optimizer and target paths.""" from __future__ import annotations +from dataclasses import dataclass import json import os import threading @@ -17,32 +18,72 @@ default_model_for_backend, ) -BASE_URL = os.environ.get("QWEN_CHAT_BASE_URL", "http://localhost:8000/v1") -API_KEY = os.environ.get("QWEN_CHAT_API_KEY", "") -TIMEOUT_SECONDS = float(os.environ.get("QWEN_CHAT_TIMEOUT_SECONDS", "300") or 300) -MAX_TOKENS = int(os.environ.get("QWEN_CHAT_MAX_TOKENS", "8000") or 8000) -TEMPERATURE: float | None = None -_raw_temperature = os.environ.get("QWEN_CHAT_TEMPERATURE", "0.7").strip() -if _raw_temperature: - TEMPERATURE = float(_raw_temperature) -ENABLE_THINKING = os.environ.get("QWEN_CHAT_ENABLE_THINKING", "false").strip().lower() in { - "1", - "true", - "yes", - "on", -} - -TARGET_DEPLOYMENT = os.environ.get( - "TARGET_DEPLOYMENT", - default_model_for_backend("qwen_chat"), -) + +@dataclass +class QwenChatConfig: + base_url: str + api_key: str + timeout_seconds: float + max_tokens: int + temperature: float | None + enable_thinking: bool + deployment: str + + +def _parse_bool(value: Any, default: bool = False) -> bool: + if value is None: + return default + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + +def _parse_optional_float(value: Any) -> float | None: + if value is None: + return None + raw = str(value).strip() + return float(raw) if raw else None + + +def _parse_int(value: Any, default: int) -> int: + if value is None: + return default + raw = str(value).strip() + return int(raw) if raw else default + + +def _role_env(role: str, key: str, default: str) -> str: + role_key = f"{role.upper()}_QWEN_CHAT_{key}" + generic_key = f"QWEN_CHAT_{key}" + return os.environ.get(role_key) or os.environ.get(generic_key) or default + + +def _initial_config(role: str) -> QwenChatConfig: + role_upper = role.upper() + deployment_env = "OPTIMIZER_DEPLOYMENT" if role == "optimizer" else "TARGET_DEPLOYMENT" + return QwenChatConfig( + base_url=_role_env(role, "BASE_URL", "http://localhost:8000/v1"), + api_key=_role_env(role, "API_KEY", ""), + timeout_seconds=float(_role_env(role, "TIMEOUT_SECONDS", "300") or 300), + max_tokens=_parse_int(_role_env(role, "MAX_TOKENS", "8000"), 8000), + temperature=_parse_optional_float(_role_env(role, "TEMPERATURE", "0.7")), + enable_thinking=_parse_bool(_role_env(role, "ENABLE_THINKING", "false")), + deployment=( + os.environ.get(f"{role_upper}_QWEN_CHAT_MODEL") + or os.environ.get("QWEN_CHAT_MODEL") + or os.environ.get(deployment_env) + or default_model_for_backend("qwen_chat") + ), + ) + + +OPTIMIZER_CONFIG = _initial_config("optimizer") +TARGET_CONFIG = _initial_config("target") _config_lock = threading.Lock() tracker = TokenTracker() -def _chat_url() -> str: - base = BASE_URL.rstrip("/") +def _chat_url(config: QwenChatConfig) -> str: + base = config.base_url.rstrip("/") if base.endswith("/chat/completions"): return base return f"{base}/chat/completions" @@ -103,18 +144,22 @@ def _compat_message_from_payload(message: dict[str, Any], choice: dict[str, Any] ) -def _post_chat_completion(payload: dict[str, Any], timeout: float | None) -> dict[str, Any]: +def _post_chat_completion( + payload: dict[str, Any], + timeout: float | None, + config: QwenChatConfig, +) -> dict[str, Any]: headers = {"Content-Type": "application/json"} - if API_KEY: - headers["Authorization"] = f"Bearer {API_KEY}" + if config.api_key: + headers["Authorization"] = f"Bearer {config.api_key}" req = urllib.request.Request( - _chat_url(), + _chat_url(config), data=json.dumps(payload, ensure_ascii=False).encode("utf-8"), headers=headers, method="POST", ) try: - with urllib.request.urlopen(req, timeout=timeout or TIMEOUT_SECONDS) as resp: + with urllib.request.urlopen(req, timeout=timeout or config.timeout_seconds) as resp: raw = resp.read().decode("utf-8") except urllib.error.HTTPError as e: body = e.read().decode("utf-8", errors="replace") @@ -133,20 +178,22 @@ def _chat_messages_impl( retries: int, stage: str, *, + role: str, tools: list[dict[str, Any]] | None = None, tool_choice: str | dict[str, Any] | None = None, return_message: bool = False, deployment: str | None = None, timeout: float | None = None, ) -> tuple[Any, dict[str, int]]: + config = OPTIMIZER_CONFIG if role == "optimizer" else TARGET_CONFIG payload: dict[str, Any] = { - "model": deployment or TARGET_DEPLOYMENT, + "model": deployment or config.deployment, "messages": _json_safe(messages), - "max_tokens": min(max_completion_tokens, MAX_TOKENS), + "max_tokens": min(max_completion_tokens, config.max_tokens), } - payload["chat_template_kwargs"] = {"enable_thinking": ENABLE_THINKING} - if TEMPERATURE is not None: - payload["temperature"] = TEMPERATURE + payload["chat_template_kwargs"] = {"enable_thinking": config.enable_thinking} + if config.temperature is not None: + payload["temperature"] = config.temperature if tools: payload["tools"] = _json_safe(tools) if tool_choice is not None: @@ -155,7 +202,7 @@ def _chat_messages_impl( last_err: Exception | None = None for attempt in range(retries): try: - data = _post_chat_completion(payload, timeout) + data = _post_chat_completion(payload, timeout, config) choices = data.get("choices") or [] if not choices: raise RuntimeError(f"Qwen chat API returned no choices: {data}") @@ -183,35 +230,134 @@ def configure_qwen_chat( timeout_seconds: float | str | None = None, max_tokens: int | str | None = None, enable_thinking: bool | str | None = None, + optimizer_base_url: str | None = None, + optimizer_api_key: str | None = None, + optimizer_temperature: float | str | None = None, + optimizer_timeout_seconds: float | str | None = None, + optimizer_max_tokens: int | str | None = None, + optimizer_enable_thinking: bool | str | None = None, + target_base_url: str | None = None, + target_api_key: str | None = None, + target_temperature: float | str | None = None, + target_timeout_seconds: float | str | None = None, + target_max_tokens: int | str | None = None, + target_enable_thinking: bool | str | None = None, ) -> None: - global BASE_URL, API_KEY, TEMPERATURE, TIMEOUT_SECONDS, MAX_TOKENS, ENABLE_THINKING with _config_lock: if base_url is not None: - BASE_URL = str(base_url).strip() or BASE_URL - os.environ["QWEN_CHAT_BASE_URL"] = BASE_URL + os.environ["QWEN_CHAT_BASE_URL"] = str(base_url).strip() if api_key is not None: - API_KEY = str(api_key).strip() - os.environ["QWEN_CHAT_API_KEY"] = API_KEY + os.environ["QWEN_CHAT_API_KEY"] = str(api_key).strip() if temperature is not None: - raw = str(temperature).strip() - TEMPERATURE = float(raw) if raw else None - os.environ["QWEN_CHAT_TEMPERATURE"] = raw + os.environ["QWEN_CHAT_TEMPERATURE"] = str(temperature).strip() if timeout_seconds is not None: - TIMEOUT_SECONDS = float(timeout_seconds) os.environ["QWEN_CHAT_TIMEOUT_SECONDS"] = str(timeout_seconds) if max_tokens is not None: - MAX_TOKENS = int(max_tokens) os.environ["QWEN_CHAT_MAX_TOKENS"] = str(max_tokens) if enable_thinking is not None: - if isinstance(enable_thinking, str): - ENABLE_THINKING = enable_thinking.strip().lower() in {"1", "true", "yes", "on"} - else: - ENABLE_THINKING = bool(enable_thinking) - os.environ["QWEN_CHAT_ENABLE_THINKING"] = "true" if ENABLE_THINKING else "false" + os.environ["QWEN_CHAT_ENABLE_THINKING"] = ( + "true" if _parse_bool(enable_thinking) else "false" + ) + _update_config( + OPTIMIZER_CONFIG, + "optimizer", + base_url=optimizer_base_url if optimizer_base_url is not None else base_url, + api_key=optimizer_api_key if optimizer_api_key is not None else api_key, + temperature=( + optimizer_temperature + if optimizer_temperature is not None + else temperature + ), + timeout_seconds=( + optimizer_timeout_seconds + if optimizer_timeout_seconds is not None + else timeout_seconds + ), + max_tokens=optimizer_max_tokens if optimizer_max_tokens is not None else max_tokens, + enable_thinking=( + optimizer_enable_thinking + if optimizer_enable_thinking is not None + else enable_thinking + ), + ) + _update_config( + TARGET_CONFIG, + "target", + base_url=target_base_url if target_base_url is not None else base_url, + api_key=target_api_key if target_api_key is not None else api_key, + temperature=target_temperature if target_temperature is not None else temperature, + timeout_seconds=( + target_timeout_seconds + if target_timeout_seconds is not None + else timeout_seconds + ), + max_tokens=target_max_tokens if target_max_tokens is not None else max_tokens, + enable_thinking=( + target_enable_thinking + if target_enable_thinking is not None + else enable_thinking + ), + ) + + +def _update_config( + config: QwenChatConfig, + role: str, + *, + base_url: str | None = None, + api_key: str | None = None, + temperature: float | str | None = None, + timeout_seconds: float | str | None = None, + max_tokens: int | str | None = None, + enable_thinking: bool | str | None = None, +) -> None: + env_prefix = role.upper() + if base_url is not None: + config.base_url = str(base_url).strip() or config.base_url + os.environ[f"{env_prefix}_QWEN_CHAT_BASE_URL"] = config.base_url + if api_key is not None: + config.api_key = str(api_key).strip() + os.environ[f"{env_prefix}_QWEN_CHAT_API_KEY"] = config.api_key + if temperature is not None: + raw = str(temperature).strip() + config.temperature = float(raw) if raw else None + os.environ[f"{env_prefix}_QWEN_CHAT_TEMPERATURE"] = raw + if timeout_seconds is not None: + config.timeout_seconds = float(timeout_seconds) + os.environ[f"{env_prefix}_QWEN_CHAT_TIMEOUT_SECONDS"] = str(timeout_seconds) + if max_tokens is not None: + config.max_tokens = int(max_tokens) + os.environ[f"{env_prefix}_QWEN_CHAT_MAX_TOKENS"] = str(max_tokens) + if enable_thinking is not None: + config.enable_thinking = _parse_bool(enable_thinking) + os.environ[f"{env_prefix}_QWEN_CHAT_ENABLE_THINKING"] = ( + "true" if config.enable_thinking else "false" + ) def get_max_tokens() -> int: - return MAX_TOKENS + return TARGET_CONFIG.max_tokens + + +def chat_optimizer( + system: str, + user: str, + max_completion_tokens: int = 16384, + retries: int = 5, + stage: str = "optimizer", + reasoning_effort: str | None = None, + timeout: float | None = None, +) -> tuple[str, dict[str, int]]: + del reasoning_effort + messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + return _chat_messages_impl( + messages, + max_completion_tokens, + retries, + stage, + role="optimizer", + timeout=timeout, + ) def chat_target( @@ -230,6 +376,33 @@ def chat_target( max_completion_tokens, retries, stage, + role="target", + timeout=timeout, + ) + + +def chat_optimizer_messages( + messages: list[dict[str, Any]], + max_completion_tokens: int = 16384, + retries: int = 5, + stage: str = "optimizer", + reasoning_effort: str | None = None, + *, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + return_message: bool = False, + timeout: float | None = None, +) -> tuple[Any, dict[str, int]]: + del reasoning_effort + return _chat_messages_impl( + messages, + max_completion_tokens, + retries, + stage, + role="optimizer", + tools=tools, + tool_choice=tool_choice, + return_message=return_message, timeout=timeout, ) @@ -252,6 +425,7 @@ def chat_target_messages( max_completion_tokens, retries, stage, + role="target", tools=tools, tool_choice=tool_choice, return_message=return_message, @@ -272,6 +446,10 @@ def set_reasoning_effort(effort: str | None) -> None: def set_target_deployment(deployment: str) -> None: - global TARGET_DEPLOYMENT - TARGET_DEPLOYMENT = deployment or default_model_for_backend("qwen_chat") - os.environ["TARGET_DEPLOYMENT"] = TARGET_DEPLOYMENT + TARGET_CONFIG.deployment = deployment or default_model_for_backend("qwen_chat") + os.environ["TARGET_DEPLOYMENT"] = TARGET_CONFIG.deployment + + +def set_optimizer_deployment(deployment: str) -> None: + OPTIMIZER_CONFIG.deployment = deployment or default_model_for_backend("qwen_chat") + os.environ["OPTIMIZER_DEPLOYMENT"] = OPTIMIZER_CONFIG.deployment