diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 85adc81a1e..dcdfbfbd17 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator from typing import TYPE_CHECKING, Any, Literal, cast, overload -from openai import AsyncOpenAI, AsyncStream, Omit, omit +from openai import AsyncOpenAI, AsyncStream, NotGiven, Omit, omit from openai.types import ChatModel from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage from openai.types.chat.chat_completion import Choice @@ -423,8 +423,11 @@ async def _fetch_response( "extra_body": model_settings.extra_body, "metadata": self._non_null_or_omit(model_settings.metadata), } + extra_args = model_settings.extra_args or {} duplicate_extra_arg_keys = sorted( - set(create_kwargs).intersection(model_settings.extra_args or {}) + k + for k in extra_args + if k in create_kwargs and not isinstance(create_kwargs[k], Omit | NotGiven) ) if duplicate_extra_arg_keys: if len(duplicate_extra_arg_keys) == 1: @@ -436,7 +439,7 @@ async def _fetch_response( raise TypeError( f"chat.completions.create() got multiple values for keyword arguments {keys}" ) - create_kwargs.update(model_settings.extra_args or {}) + create_kwargs.update(extra_args) ret = await self._get_client().chat.completions.create(**create_kwargs) diff --git a/tests/models/test_openai_chatcompletions.py b/tests/models/test_openai_chatcompletions.py index b2f8affd60..52e88db811 100644 --- a/tests/models/test_openai_chatcompletions.py +++ b/tests/models/test_openai_chatcompletions.py @@ -426,6 +426,28 @@ def __init__(self, completions: DummyCompletions) -> None: assert kwargs["stream_options"] is omit +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_fetch_response_allows_extra_arg_when_explicit_arg_is_omitted() -> None: + """A user-supplied extra_arg whose first-class field is only present as the + `omit` sentinel should pass through, not raise a duplicate-argument error.""" + kwargs = await _run_chat_completions_model_with_custom_base_url( + model_settings=ModelSettings(extra_args={"temperature": 0.25}) + ) + assert kwargs["temperature"] == 0.25 + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_fetch_response_rejects_duplicate_extra_args_with_explicit_value() -> None: + """When the first-class field has an explicit (non-omit) value, a colliding + extra_arg key should still raise a TypeError to surface the conflict.""" + with pytest.raises(TypeError, match="multiple values.*temperature"): + await _run_chat_completions_model_with_custom_base_url( + model_settings=ModelSettings(temperature=0.5, extra_args={"temperature": 0.25}) + ) + + @pytest.mark.allow_call_model_methods @pytest.mark.asyncio async def test_custom_base_url_prompt_cache_key_uses_model_settings_only() -> None: