diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 9cb7f05fe3..390e5c1f72 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -108,6 +108,7 @@ retry_policies, ) from .run import ( + OnTurnEndCallback, ReasoningItemIdPolicy, RunConfig, Runner, @@ -115,6 +116,7 @@ ToolErrorFormatterArgs, ToolExecutionConfig, ToolNotFoundBehavior, + TurnEndData, ) from .run_context import AgentHookContext, RunContextWrapper, TContext from .run_error_handlers import ( @@ -447,8 +449,10 @@ def enable_verbose_stdout_logging(): "RunResultStreaming", "ResponsesWebSocketSession", "RunConfig", + "OnTurnEndCallback", "ReasoningItemIdPolicy", "ToolExecutionConfig", + "TurnEndData", "ToolErrorFormatter", "ToolErrorFormatterArgs", "ToolNotFoundBehavior", diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index c84ba44b09..7ed584b19e 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -67,6 +67,26 @@ async def on_handoff( """Called when a handoff occurs.""" pass + async def on_turn_end( + self, + context: RunContextWrapper[TContext], + agent: Agent[TContext], + current_turn: int, + ) -> None: + """Called at the end of each turn, before the next turn begins. + + This fires after all tool calls in the current turn have been executed + and their results processed, but before the next model call. It is + useful for logging, state tracking, dynamic instruction updates, and + context compaction between turns. + + Args: + context: The run context wrapper, including usage and approvals. + agent: The agent that was active during this turn. + current_turn: The turn number that just completed (1-based). + """ + pass + async def on_tool_start( self, context: RunContextWrapper[TContext], @@ -145,6 +165,26 @@ async def on_handoff( off to this agent.""" pass + async def on_turn_end( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + current_turn: int, + ) -> None: + """Called at the end of each turn for this agent, before the next turn begins. + + This fires after all tool calls in the current turn have been executed + and their results processed, but before the next model call. It is + useful for logging, state tracking, dynamic instruction updates, and + context compaction between turns. + + Args: + context: The run context wrapper, including usage and approvals. + agent: This agent instance. + current_turn: The turn number that just completed (1-based). + """ + pass + async def on_tool_start( self, context: RunContextWrapper[TContext], diff --git a/src/agents/run.py b/src/agents/run.py index 014271a5ea..43db492223 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -2,6 +2,7 @@ import asyncio import contextlib +import inspect import warnings from typing import cast @@ -35,6 +36,7 @@ CallModelData, CallModelInputFilter, ModelInputData, + OnTurnEndCallback, ReasoningItemIdPolicy, RunConfig, RunOptions, @@ -42,6 +44,7 @@ ToolErrorFormatterArgs, ToolExecutionConfig, ToolNotFoundBehavior, + TurnEndData, ) from .run_context import RunContextWrapper, TContext from .run_error_handlers import RunErrorHandlers @@ -122,7 +125,7 @@ from .tracing import Span, SpanError, agent_span, get_current_trace, task_span, turn_span from .tracing.context import TraceCtxManager, create_trace_for_run from .tracing.span_data import AgentSpanData, TaskSpanData -from .util import _error_tracing +from .util import _coro, _error_tracing DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore # the value is set at the end of the module @@ -137,17 +140,51 @@ "ModelInputData", "CallModelData", "CallModelInputFilter", + "OnTurnEndCallback", "ReasoningItemIdPolicy", "ToolExecutionConfig", "ToolErrorFormatter", "ToolErrorFormatterArgs", "ToolNotFoundBehavior", + "TurnEndData", "DEFAULT_MAX_TURNS", "set_default_agent_runner", "get_default_agent_runner", ] +async def _invoke_on_turn_end( + *, + hooks: RunHooks[Any] | None, + run_config: RunConfig | None, + agent: Agent[Any], + context_wrapper: RunContextWrapper[Any], + current_turn: int, +) -> None: + """Invoke on_turn_end callbacks from hooks and run_config. + + Called after each turn completes and before the next turn begins. This + fires for handoffs and run-again steps (not for interruptions or final + outputs, which end the run). + """ + if hooks is not None: + tasks = [] + tasks.append(hooks.on_turn_end(context_wrapper, agent, current_turn)) + if agent.hooks is not None: + tasks.append(agent.hooks.on_turn_end(context_wrapper, agent, current_turn)) + await asyncio.gather(*tasks) + + if run_config is not None and run_config.on_turn_end is not None: + turn_end_data = TurnEndData( + agent=agent, + context=context_wrapper, + current_turn=current_turn, + ) + result = run_config.on_turn_end(turn_end_data) + if inspect.isawaitable(result): + await result + + def set_default_agent_runner(runner: AgentRunner | None) -> None: """ WARNING: this class is experimental and not part of the public API @@ -946,6 +983,13 @@ def _finalize_result(result: RunResult) -> RunResult: return _finalize_result(result) if isinstance(turn_result.next_step, NextStepRunAgain): + await _invoke_on_turn_end( + hooks=hooks, + run_config=run_config, + agent=current_agent, + context_wrapper=context_wrapper, + current_turn=current_turn, + ) continue append_model_response_if_new( @@ -1020,6 +1064,13 @@ def _finalize_result(result: RunResult) -> RunResult: current_span.finish(reset_current=True) current_span = None should_run_agent_start_hooks = True + await _invoke_on_turn_end( + hooks=hooks, + run_config=run_config, + agent=current_agent, + context_wrapper=context_wrapper, + current_turn=current_turn, + ) continue continue @@ -1480,6 +1531,13 @@ def _finalize_result(result: RunResult) -> RunResult: current_span.finish(reset_current=True) current_span = None should_run_agent_start_hooks = True + await _invoke_on_turn_end( + hooks=hooks, + run_config=run_config, + agent=current_agent, + context_wrapper=context_wrapper, + current_turn=current_turn, + ) elif isinstance(turn_result.next_step, NextStepRunAgain): await save_turn_items_if_needed( session=session, @@ -1490,6 +1548,13 @@ def _finalize_result(result: RunResult) -> RunResult: response_id=turn_result.model_response.response_id, store=store_setting, ) + await _invoke_on_turn_end( + hooks=hooks, + run_config=run_config, + agent=current_agent, + context_wrapper=context_wrapper, + current_turn=current_turn, + ) continue else: raise AgentsException( diff --git a/src/agents/run_config.py b/src/agents/run_config.py index fcc9b01315..2bb8c6d1ef 100644 --- a/src/agents/run_config.py +++ b/src/agents/run_config.py @@ -62,6 +62,24 @@ class CallModelData(Generic[TContext]): CallModelInputFilter = Callable[[CallModelData[Any]], MaybeAwaitable[ModelInputData]] + + +@dataclass +class TurnEndData(Generic[TContext]): + """Data passed to ``RunConfig.on_turn_end`` after each turn completes.""" + + agent: Agent[TContext] + """The agent that was active during the just-completed turn.""" + + context: RunContextWrapper[TContext] + """The run context wrapper for the current execution.""" + + current_turn: int + """The turn number that just completed (1-based).""" + + +OnTurnEndCallback = Callable[[TurnEndData[Any]], MaybeAwaitable[None]] + ReasoningItemIdPolicy = Literal["preserve", "omit"] ToolNotFoundBehavior = Literal["raise_error", "return_error_to_model"] @@ -297,6 +315,17 @@ class RunConfig: For example, you can use this to add a system prompt to the input. """ + on_turn_end: OnTurnEndCallback | None = None + """ + Optional callback that is invoked after each turn completes and all tool calls in that turn + have been executed. It fires before the next turn begins, giving applications a hook to + inspect or react to state changes between turns. + + This is useful for logging, state tracking, dynamic instruction updates, and context + compaction in long-running workflows. If a handoff occurred during the turn, the callback + receives the *new* agent (the target of the handoff). + """ + tool_error_formatter: ToolErrorFormatter | None = None """Optional callback that formats tool error messages returned to the model. @@ -366,9 +395,11 @@ class RunOptions(TypedDict, Generic[TContext]): "CallModelData", "CallModelInputFilter", "ModelInputData", + "OnTurnEndCallback", "ReasoningItemIdPolicy", "RunConfig", "RunOptions", + "TurnEndData", "SandboxArchiveLimits", "SandboxConcurrencyLimits", "SandboxRunConfig", diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 45f09c0fa0..97771ddd0b 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -7,6 +7,7 @@ import asyncio import dataclasses as _dc +import inspect import json from collections.abc import Awaitable, Callable, Mapping from typing import Any, TypeVar, cast @@ -63,7 +64,7 @@ response_terminal_failure_error, ) from ..result import RunResultStreaming -from ..run_config import ReasoningItemIdPolicy, RunConfig +from ..run_config import ReasoningItemIdPolicy, RunConfig, TurnEndData from ..run_context import AgentHookContext, RunContextWrapper, TContext from ..run_error_handlers import RunErrorHandlers from ..run_state import RunState @@ -434,6 +435,37 @@ async def _finalize_streamed_interruption( ) +async def _invoke_on_turn_end( + *, + hooks: RunHooks[Any] | None, + run_config: RunConfig | None, + agent: Agent[Any], + context_wrapper: RunContextWrapper[Any], + current_turn: int, +) -> None: + """Invoke on_turn_end callbacks from hooks and run_config. + + Called after each turn completes in the streaming loop and before the next + turn begins. Fires for handoffs and run-again steps. + """ + if hooks is not None: + tasks = [] + tasks.append(hooks.on_turn_end(context_wrapper, agent, current_turn)) + if agent.hooks is not None: + tasks.append(agent.hooks.on_turn_end(context_wrapper, agent, current_turn)) + await asyncio.gather(*tasks) + + if run_config is not None and run_config.on_turn_end is not None: + turn_end_data = TurnEndData( + agent=agent, + context=context_wrapper, + current_turn=current_turn, + ) + result = run_config.on_turn_end(turn_end_data) + if inspect.isawaitable(result): + await result + + T = TypeVar("T") @@ -812,6 +844,13 @@ async def _save_stream_items_without_count( AgentUpdatedStreamEvent(new_agent=current_agent) ) run_state._current_step = NextStepRunAgain() # type: ignore[assignment] + await _invoke_on_turn_end( + hooks=hooks, + run_config=run_config, + agent=current_agent, + context_wrapper=context_wrapper, + current_turn=current_turn, + ) continue if isinstance(turn_result.next_step, NextStepFinalOutput): @@ -835,6 +874,13 @@ async def _save_stream_items_without_count( store_setting, ) run_state._current_step = NextStepRunAgain() # type: ignore[assignment] + await _invoke_on_turn_end( + hooks=hooks, + run_config=run_config, + agent=current_agent, + context_wrapper=context_wrapper, + current_turn=current_turn, + ) continue run_state._current_step = None @@ -1106,6 +1152,14 @@ async def _save_stream_items_without_count( if streamed_result._state is not None: streamed_result._state._current_step = NextStepRunAgain() + await _invoke_on_turn_end( + hooks=hooks, + run_config=run_config, + agent=current_agent, + context_wrapper=context_wrapper, + current_turn=current_turn, + ) + if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) @@ -1158,6 +1212,14 @@ async def _save_stream_items_without_count( store_setting, ) + await _invoke_on_turn_end( + hooks=hooks, + run_config=run_config, + agent=current_agent, + context_wrapper=context_wrapper, + current_turn=current_turn, + ) + if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel())