From 0d197ef559831e2cdac814b50eb41d41907a6cec Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Thu, 26 Mar 2026 18:21:36 +0000 Subject: [PATCH 1/2] feat(hooks): accept callable hook callbacks in Agent constructor hooks parameter Previously, the hooks parameter in Agent.__init__ only accepted HookProvider instances. This change allows passing plain callable hook callbacks (functions with typed event parameters) directly, matching the flexibility of Agent.add_hook(). The hooks param now accepts a mixed list of HookProvider instances and/or callable hook callbacks: def on_start(event: BeforeInvocationEvent) -> None: print('Starting!') agent = Agent(hooks=[on_start, MyHookProvider()]) This provides a more convenient API for simple hook use cases where creating a full HookProvider class is unnecessary. Changes: - Updated hooks param type: list[HookProvider | HookCallback] | None - Added isinstance check to dispatch HookProviders vs callables - Added ValueError for invalid hook types - Added comprehensive tests (12 test cases) --- src/strands/agent/agent.py | 14 +- .../agent/test_agent_hooks_callable.py | 192 ++++++++++++++++++ 2 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 tests/strands/agent/test_agent_hooks_callable.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a..a91372533 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -129,7 +129,7 @@ def __init__( description: str | None = None, state: AgentState | dict | None = None, plugins: list[Plugin] | None = None, - hooks: list[HookProvider] | None = None, + hooks: list[HookProvider | HookCallback] | None = None, session_manager: SessionManager | None = None, structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, @@ -183,7 +183,8 @@ def __init__( Plugins are initialized with the agent instance after construction and can register hooks, modify agent attributes, or perform other setup tasks. Defaults to None. - hooks: hooks to be added to the agent hook registry + hooks: Hooks to be added to the agent hook registry. Accepts HookProvider instances + or plain callable hook callbacks (functions with typed event parameters). Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. @@ -322,7 +323,14 @@ def __init__( if hooks: for hook in hooks: - self.hooks.add_hook(hook) + if isinstance(hook, HookProvider): + self.hooks.add_hook(hook) + elif callable(hook): + self.hooks.add_callback(None, hook) + else: + raise ValueError( + f"Invalid hook: {hook!r}. Must be a HookProvider instance or a callable hook callback." + ) if plugins: for plugin in plugins: diff --git a/tests/strands/agent/test_agent_hooks_callable.py b/tests/strands/agent/test_agent_hooks_callable.py new file mode 100644 index 000000000..eb48aa02b --- /dev/null +++ b/tests/strands/agent/test_agent_hooks_callable.py @@ -0,0 +1,192 @@ +"""Tests for accepting callable hook callbacks in Agent constructor's hooks parameter.""" + +from unittest.mock import MagicMock + +import pytest + +from strands import Agent +from strands.hooks import ( + AfterInvocationEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + HookProvider, + HookRegistry, +) +from tests.fixtures.mocked_model_provider import MockedModelProvider + + +class TestHooksParamAcceptsCallables: + """Test that the Agent constructor's hooks parameter accepts both HookProviders and callables.""" + + def test_hooks_param_accepts_callable(self): + """Verify that a plain callable can be passed via hooks parameter.""" + events_received = [] + + def my_callback(event: AgentInitializedEvent) -> None: + events_received.append(event) + + agent = Agent(hooks=[my_callback], callback_handler=None) + + assert len(events_received) == 1 + assert isinstance(events_received[0], AgentInitializedEvent) + assert events_received[0].agent is agent + + def test_hooks_param_accepts_hook_provider(self): + """Verify that HookProvider still works as before (backward compatibility).""" + + class MyProvider(HookProvider): + def __init__(self): + self.events = [] + + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(AgentInitializedEvent, self.on_init) + + def on_init(self, event: AgentInitializedEvent) -> None: + self.events.append(event) + + provider = MyProvider() + agent = Agent(hooks=[provider], callback_handler=None) + + assert len(provider.events) == 1 + assert isinstance(provider.events[0], AgentInitializedEvent) + + def test_hooks_param_accepts_mixed_list(self): + """Verify that a mix of HookProviders and callables can be passed.""" + callback_events = [] + provider_events = [] + + def my_callback(event: AgentInitializedEvent) -> None: + callback_events.append(event) + + class MyProvider(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(AgentInitializedEvent, lambda e: provider_events.append(e)) + + agent = Agent(hooks=[MyProvider(), my_callback], callback_handler=None) + + assert len(callback_events) == 1 + assert len(provider_events) == 1 + assert callback_events[0].agent is agent + assert provider_events[0].agent is agent + + def test_hooks_param_callable_invoked_during_agent_lifecycle(self): + """Verify that callable hooks registered via hooks param fire during agent lifecycle.""" + before_events = [] + after_events = [] + + def on_before(event: BeforeInvocationEvent) -> None: + before_events.append(event) + + def on_after(event: AfterInvocationEvent) -> None: + after_events.append(event) + + mock_model = MockedModelProvider( + [{"role": "assistant", "content": [{"text": "Hello!"}]}] + ) + + agent = Agent( + model=mock_model, + hooks=[on_before, on_after], + callback_handler=None, + ) + agent("test prompt") + + assert len(before_events) == 1 + assert len(after_events) == 1 + assert isinstance(before_events[0], BeforeInvocationEvent) + assert isinstance(after_events[0], AfterInvocationEvent) + + def test_hooks_param_invalid_hook_raises_error(self): + """Verify that passing an invalid hook (not HookProvider or callable) raises ValueError.""" + with pytest.raises(ValueError, match="Invalid hook"): + Agent(hooks=["not_a_hook"], callback_handler=None) # type: ignore + + def test_hooks_param_none_is_valid(self): + """Verify that passing None for hooks is still valid.""" + agent = Agent(hooks=None, callback_handler=None) + assert agent is not None + + def test_hooks_param_empty_list_is_valid(self): + """Verify that passing an empty list for hooks is still valid.""" + agent = Agent(hooks=[], callback_handler=None) + assert agent is not None + + def test_hooks_param_callable_with_explicit_type_hint(self): + """Verify that callables with typed event parameters work via hooks param.""" + model_call_events = [] + + def on_model_call(event: BeforeModelCallEvent) -> None: + model_call_events.append(event) + + mock_model = MockedModelProvider( + [{"role": "assistant", "content": [{"text": "result"}]}] + ) + + agent = Agent( + model=mock_model, + hooks=[on_model_call], + callback_handler=None, + ) + agent("prompt") + + assert len(model_call_events) >= 1 + assert isinstance(model_call_events[0], BeforeModelCallEvent) + + def test_hooks_param_lambda_without_type_hint_raises_error(self): + """Verify that lambda functions without type hints raise ValueError.""" + with pytest.raises(ValueError, match="cannot infer event type"): + Agent( + hooks=[lambda event: None], # type: ignore + callback_handler=None, + ) + + def test_hooks_param_multiple_callables(self): + """Verify that multiple callables can be registered.""" + events_a = [] + events_b = [] + + def callback_a(event: AgentInitializedEvent) -> None: + events_a.append(event) + + def callback_b(event: AgentInitializedEvent) -> None: + events_b.append(event) + + agent = Agent(hooks=[callback_a, callback_b], callback_handler=None) + + assert len(events_a) == 1 + assert len(events_b) == 1 + + +class TestHooksParamAsyncCallables: + """Test that the Agent constructor's hooks parameter accepts async callables.""" + + def test_hooks_param_accepts_async_before_invocation_callback(self): + """Verify that async callable hooks can be registered for non-init events.""" + events_received = [] + + async def my_async_callback(event: BeforeInvocationEvent) -> None: + events_received.append(event) + + mock_model = MockedModelProvider( + [{"role": "assistant", "content": [{"text": "Hello!"}]}] + ) + + agent = Agent( + model=mock_model, + hooks=[my_async_callback], + callback_handler=None, + ) + agent("test") + + assert len(events_received) == 1 + assert isinstance(events_received[0], BeforeInvocationEvent) + + def test_hooks_param_rejects_async_agent_initialized_callback(self): + """Verify that async callbacks for AgentInitializedEvent raise ValueError.""" + + async def my_async_callback(event: AgentInitializedEvent) -> None: + pass + + with pytest.raises(ValueError, match="AgentInitializedEvent can only be registered with a synchronous callback"): + Agent(hooks=[my_async_callback], callback_handler=None) From e3f948106904d7339928f156e7bee371131d3d0f Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Fri, 27 Mar 2026 14:31:40 +0000 Subject: [PATCH 2/2] fix: address hatch run prepare lint and formatting issues - Remove unused variable assignments (F841 lint errors) - Apply auto-formatting from hatch run format --- .../agent/test_agent_hooks_callable.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/tests/strands/agent/test_agent_hooks_callable.py b/tests/strands/agent/test_agent_hooks_callable.py index eb48aa02b..d42017505 100644 --- a/tests/strands/agent/test_agent_hooks_callable.py +++ b/tests/strands/agent/test_agent_hooks_callable.py @@ -1,7 +1,5 @@ """Tests for accepting callable hook callbacks in Agent constructor's hooks parameter.""" -from unittest.mock import MagicMock - import pytest from strands import Agent @@ -46,7 +44,7 @@ def on_init(self, event: AgentInitializedEvent) -> None: self.events.append(event) provider = MyProvider() - agent = Agent(hooks=[provider], callback_handler=None) + Agent(hooks=[provider], callback_handler=None) assert len(provider.events) == 1 assert isinstance(provider.events[0], AgentInitializedEvent) @@ -81,9 +79,7 @@ def on_before(event: BeforeInvocationEvent) -> None: def on_after(event: AfterInvocationEvent) -> None: after_events.append(event) - mock_model = MockedModelProvider( - [{"role": "assistant", "content": [{"text": "Hello!"}]}] - ) + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}]) agent = Agent( model=mock_model, @@ -119,9 +115,7 @@ def test_hooks_param_callable_with_explicit_type_hint(self): def on_model_call(event: BeforeModelCallEvent) -> None: model_call_events.append(event) - mock_model = MockedModelProvider( - [{"role": "assistant", "content": [{"text": "result"}]}] - ) + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "result"}]}]) agent = Agent( model=mock_model, @@ -152,7 +146,7 @@ def callback_a(event: AgentInitializedEvent) -> None: def callback_b(event: AgentInitializedEvent) -> None: events_b.append(event) - agent = Agent(hooks=[callback_a, callback_b], callback_handler=None) + Agent(hooks=[callback_a, callback_b], callback_handler=None) assert len(events_a) == 1 assert len(events_b) == 1 @@ -168,9 +162,7 @@ def test_hooks_param_accepts_async_before_invocation_callback(self): async def my_async_callback(event: BeforeInvocationEvent) -> None: events_received.append(event) - mock_model = MockedModelProvider( - [{"role": "assistant", "content": [{"text": "Hello!"}]}] - ) + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}]) agent = Agent( model=mock_model, @@ -188,5 +180,7 @@ def test_hooks_param_rejects_async_agent_initialized_callback(self): async def my_async_callback(event: AgentInitializedEvent) -> None: pass - with pytest.raises(ValueError, match="AgentInitializedEvent can only be registered with a synchronous callback"): + with pytest.raises( + ValueError, match="AgentInitializedEvent can only be registered with a synchronous callback" + ): Agent(hooks=[my_async_callback], callback_handler=None)