Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from copy import deepcopy
from functools import partial
from inspect import signature
from itertools import chain
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -40,7 +41,14 @@
SessionContext,
is_local_history_conversation_id,
)
from ._tools import FunctionInvocationLayer, FunctionTool, ToolTypes, normalize_tools
from ._tools import (
FunctionInvocationConfiguration,
FunctionInvocationLayer,
FunctionTool,
ToolTypes,
normalize_function_invocation_configuration,
normalize_tools,
)
from ._types import (
AgentResponse,
AgentResponseUpdate,
Expand Down Expand Up @@ -163,6 +171,14 @@ def _sanitize_agent_name(agent_name: str | None) -> str | None:
return sanitized


def _accepts_function_invocation_configuration(client: SupportsChatGetResponse[Any]) -> bool:
    """Return whether the client's get_response accepts per-call function invocation config.

    The check inspects the signature of ``client.get_response`` and looks for a
    parameter named ``function_invocation_configuration`` that can be supplied
    by keyword. A generic ``**kwargs`` catch-all is not treated as acceptance;
    only an explicitly named parameter is.

    Args:
        client: The chat client whose ``get_response`` callable is inspected.

    Returns:
        ``True`` when ``get_response`` declares a keyword-bindable
        ``function_invocation_configuration`` parameter, ``False`` otherwise,
        including when the signature cannot be determined.
    """
    try:
        parameters = signature(client.get_response).parameters
    except (TypeError, ValueError):
        # ``signature`` raises TypeError for unsupported objects and ValueError
        # when no signature can be provided; treat both as "not accepted".
        return False

    parameter = parameters.get("function_invocation_configuration")
    if parameter is None:
        return False

    # The value is forwarded by keyword, so a positional-only or variadic
    # parameter that merely shares the name must not count as acceptance.
    # The kind constants are class attributes reachable from the instance.
    return parameter.kind in (parameter.POSITIONAL_OR_KEYWORD, parameter.KEYWORD_ONLY)
Comment on lines +174 to +179


class _RunContext(TypedDict):
session: AgentSession | None
session_context: SessionContext
Expand All @@ -174,6 +190,7 @@ class _RunContext(TypedDict):
compaction_strategy: CompactionStrategy | None
tokenizer: TokenizerProtocol | None
client_kwargs: Mapping[str, Any]
function_invocation_configuration: FunctionInvocationConfiguration
function_invocation_kwargs: Mapping[str, Any]


Expand Down Expand Up @@ -669,6 +686,7 @@ def __init__(
context_providers: Sequence[ContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
require_per_service_call_history_persistence: bool = False,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: MutableMapping[str, Any] | None = None,
Expand All @@ -691,6 +709,7 @@ def __init__(
is not already storing history. If service-side storage is active for
the run, the agent skips local history providers and relies on the
service-managed conversation instead.
function_invocation_configuration: Optional agent-level function invocation configuration.
default_options: A TypedDict containing chat options. When using a typed agent like
``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for
provider-specific options including temperature, max_tokens, model,
Expand Down Expand Up @@ -724,6 +743,9 @@ def __init__(
self.compaction_strategy = compaction_strategy
self.require_per_service_call_history_persistence = require_per_service_call_history_persistence
self.tokenizer = tokenizer
self.function_invocation_configuration = normalize_function_invocation_configuration(
function_invocation_configuration
)
Comment on lines +746 to +748

# Get tools from options or named parameter (named param takes precedence)
tools_ = tools if tools is not None else opts.pop("tools", None)
Expand Down Expand Up @@ -989,25 +1011,33 @@ def _call_chat_client(
stream: bool,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Invoke the downstream chat client for a prepared run context."""
function_invocation_configuration_kwargs: dict[str, Any] = {}
if isinstance(self.client, FunctionInvocationLayer) and _accepts_function_invocation_configuration(self.client):
Comment on lines +1014 to +1015
function_invocation_configuration_kwargs["function_invocation_configuration"] = context[
"function_invocation_configuration"
]

if stream:
return self.client.get_response( # type: ignore[call-overload, no-any-return]
return cast(Any, self.client).get_response(
messages=context["session_messages"],
stream=True,
options=context["chat_options"], # type: ignore[reportArgumentType]
compaction_strategy=context["compaction_strategy"],
tokenizer=context["tokenizer"],
function_invocation_kwargs=context["function_invocation_kwargs"],
client_kwargs=context["client_kwargs"],
**function_invocation_configuration_kwargs,
)

return self.client.get_response( # type: ignore[call-overload, no-any-return]
return cast(Any, self.client).get_response(
messages=context["session_messages"],
stream=False,
options=context["chat_options"], # type: ignore[reportArgumentType]
compaction_strategy=context["compaction_strategy"],
tokenizer=context["tokenizer"],
function_invocation_kwargs=context["function_invocation_kwargs"],
client_kwargs=context["client_kwargs"],
**function_invocation_configuration_kwargs,
)

async def _parse_non_streaming_response(
Expand Down Expand Up @@ -1324,6 +1354,7 @@ async def _prepare_run_context(
"compaction_strategy": compaction_strategy or self.compaction_strategy,
"tokenizer": tokenizer or self.tokenizer,
"client_kwargs": effective_client_kwargs,
"function_invocation_configuration": self.function_invocation_configuration,
"function_invocation_kwargs": additional_function_arguments,
}

Expand Down Expand Up @@ -1689,6 +1720,7 @@ def __init__(
context_providers: Sequence[ContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
require_per_service_call_history_persistence: bool = False,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: MutableMapping[str, Any] | None = None,
Expand All @@ -1705,6 +1737,7 @@ def __init__(
context_providers=context_providers,
middleware=middleware,
require_per_service_call_history_persistence=require_per_service_call_history_persistence,
function_invocation_configuration=function_invocation_configuration,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
Expand Down
37 changes: 24 additions & 13 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,6 +2294,7 @@ def get_response(
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...

Expand All @@ -2308,6 +2309,7 @@ def get_response(
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[Any]]: ...

Expand All @@ -2322,6 +2324,7 @@ def get_response(
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...

Expand All @@ -2335,6 +2338,7 @@ def get_response(
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
from ._middleware import categorize_middleware
Expand All @@ -2358,11 +2362,16 @@ def get_response(
*middleware,
]
runtime_middleware = categorize_middleware(effective_client_kwargs.pop("middleware", []))
effective_function_invocation_configuration = (
normalize_function_invocation_configuration(function_invocation_configuration)
if function_invocation_configuration is not None
else self.function_invocation_configuration
)

function_middleware_pipeline = self._get_function_middleware_pipeline(runtime_middleware["function"])
if runtime_middleware["chat"]:
effective_client_kwargs["middleware"] = runtime_middleware["chat"]
max_errors = self.function_invocation_configuration.get(
max_errors = effective_function_invocation_configuration.get(
"max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST
)
additional_function_arguments = (
Expand All @@ -2377,7 +2386,7 @@ def get_response(
execute_function_calls = partial(
_execute_function_calls,
custom_args=additional_function_arguments,
config=self.function_invocation_configuration,
config=effective_function_invocation_configuration,
invocation_session=invocation_session,
middleware_pipeline=function_middleware_pipeline,
)
Expand All @@ -2388,7 +2397,7 @@ def get_response(
# Remove additional_function_arguments from options passed to underlying chat client
# It's for tool invocation only and not recognized by chat service APIs
mutable_options.pop("additional_function_arguments", None)
if not self.function_invocation_configuration.get("enabled", True):
if not effective_function_invocation_configuration.get("enabled", True):
return super_get_response( # type: ignore[no-any-return]
messages=messages,
stream=stream,
Expand All @@ -2412,14 +2421,16 @@ async def _get_response() -> ChatResponse[Any]:
nonlocal filtered_kwargs
errors_in_a_row: int = 0
total_function_calls: int = 0
max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls")
max_function_calls: int | None = effective_function_invocation_configuration.get("max_function_calls")
prepped_messages = list(messages)
fcc_messages: list[Message] = []
response: ChatResponse[Any] | None = None
aggregated_usage: UsageDetails | None = None

loop_enabled = self.function_invocation_configuration.get("enabled", True)
max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS)
loop_enabled = effective_function_invocation_configuration.get("enabled", True)
max_iterations = effective_function_invocation_configuration.get(
"max_iterations", DEFAULT_MAX_ITERATIONS
)
for attempt_idx in range(max_iterations if loop_enabled else 0):
approval_result = await _process_function_requests(
response=None,
Expand Down Expand Up @@ -2509,10 +2520,10 @@ async def _get_response() -> ChatResponse[Any]:
# Make a final model call with tool_choice="none" so the model
# produces a plain text answer instead of leaving orphaned
# function_call items without matching results.
if response is not None and self.function_invocation_configuration.get("enabled", True):
if response is not None and effective_function_invocation_configuration.get("enabled", True):
logger.info(
"Maximum iterations reached (%d). Requesting final response without tools.",
self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS),
effective_function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS),
)
mutable_options["tool_choice"] = "none"
response = cast(
Expand Down Expand Up @@ -2550,13 +2561,13 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
nonlocal stream_result_hooks
errors_in_a_row: int = 0
total_function_calls: int = 0
max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls")
max_function_calls: int | None = effective_function_invocation_configuration.get("max_function_calls")
prepped_messages = list(messages)
fcc_messages: list[Message] = []
response: ChatResponse[Any] | None = None

loop_enabled = self.function_invocation_configuration.get("enabled", True)
max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS)
loop_enabled = effective_function_invocation_configuration.get("enabled", True)
max_iterations = effective_function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS)
for attempt_idx in range(max_iterations if loop_enabled else 0):
approval_result = await _process_function_requests(
response=None,
Expand Down Expand Up @@ -2667,10 +2678,10 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
# Make a final model call with tool_choice="none" so the model
# produces a plain text answer instead of leaving orphaned
# function_call items without matching results.
if response is not None and self.function_invocation_configuration.get("enabled", True):
if response is not None and effective_function_invocation_configuration.get("enabled", True):
logger.info(
"Maximum iterations reached (%d). Requesting final response without tools.",
self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS),
effective_function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS),
)
mutable_options["tool_choice"] = "none"
final_inner_stream = cast(
Expand Down
25 changes: 25 additions & 0 deletions python/packages/core/tests/core/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,31 @@ async def fake_inner_get_response(**kwargs):
mock_inner_get_response.assert_called_once()


async def test_base_client_as_agent_forwards_function_invocation_configuration(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """Check that an agent built via ``as_agent`` forwards its function invocation configuration.

    ``get_response`` on the client is replaced with a stub that records the
    ``function_invocation_configuration`` keyword it receives. The test then
    asserts on the recorded values and on the client's own configuration.
    """
    seen_config: dict[str, Any] = {}

    async def fake_get_response(
        *,
        function_invocation_configuration: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ChatResponse:
        # The stub fails the test outright if the agent did not pass a config.
        assert function_invocation_configuration is not None
        seen_config.update(function_invocation_configuration)
        return ChatResponse(messages=[Message(role="assistant", contents=["ok"])])

    chat_client_base.get_response = fake_get_response  # type: ignore[method-assign,attr-defined]

    agent = chat_client_base.as_agent(function_invocation_configuration={"include_detailed_errors": True})
    await agent.run("hello")

    # The agent-level override must arrive at the client call ...
    assert seen_config["include_detailed_errors"] is True
    # ... together with an ``enabled`` entry, although only one key was supplied above.
    assert seen_config["enabled"] is True
    # The client's own configuration is expected to be left untouched.
    assert chat_client_base.function_invocation_configuration["include_detailed_errors"] is False  # type: ignore[attr-defined]


async def test_base_client_get_response(chat_client_base: SupportsChatGetResponse):
response = await chat_client_base.get_response([Message(role="user", contents=["Hello"])])
assert response.messages[0].role == "assistant"
Expand Down