diff --git a/.sampo/changesets/ardent-knight-vainamoinen.md b/.sampo/changesets/ardent-knight-vainamoinen.md new file mode 100644 index 00000000..3ef2874f --- /dev/null +++ b/.sampo/changesets/ardent-knight-vainamoinen.md @@ -0,0 +1,5 @@ +--- +pypi/posthog: patch +--- + +feat: add Celery integration and improve PostHog client fork safety diff --git a/examples/celery_integration.py b/examples/celery_integration.py new file mode 100644 index 00000000..a2adc93f --- /dev/null +++ b/examples/celery_integration.py @@ -0,0 +1,191 @@ +""" +Celery integration example for the PostHog Python SDK. + +Demonstrates how to use ``PosthogCeleryIntegration`` with: +- producer-side instrumentation (publishing events and context propagation) +- worker-side instrumentation via ``worker_process_init`` (prefork-safe) +- context propagation (distinct ID, session ID, tags) from producer to worker +- task lifecycle events (published, started, success, failure, retry) +- exception capture from failed tasks +- ``task_filter`` customization hook + +Setup: + 1. Update POSTHOG_PROJECT_API_KEY and POSTHOG_HOST here with your credentials + (environment variables won't work as it's better if Celery forks worker into + separate process for the example to prove context propagation) + 2. Install dependencies: pip install posthog celery redis + 3. Start Redis: redis-server + 4. Start the worker: celery -A examples.celery_integration worker --loglevel=info + 5. Run the producer: python -m examples.celery_integration +""" + +import time +from typing import Any, Optional + +from celery import Celery +from celery.signals import worker_process_init, worker_process_shutdown + +import posthog +from posthog.client import Client +from posthog.integrations.celery import PosthogCeleryIntegration + + +# --- Configuration --- + +POSTHOG_PROJECT_API_KEY = "phc_..." +POSTHOG_HOST = "http://localhost:8000" + +app = Celery( + "examples.celery_integration", + broker="redis://localhost:6379/0", +) + + +# --- Integration wiring --- + +def create_client() -> Client: + return Client( + project_api_key=POSTHOG_PROJECT_API_KEY, + host=POSTHOG_HOST + ) + + +def task_filter(task_name: Optional[str], task_properties: dict[str, Any]) -> bool: + if task_name is not None and task_name.endswith(".health_check"): + return False + return True + + +def create_integration(client: Client) -> PosthogCeleryIntegration: + return PosthogCeleryIntegration( + client=client, + capture_exceptions=True, + capture_task_lifecycle_events=True, + propagate_context=True, + task_filter=task_filter, + ) + + +# Worker process setup. +# Celery's default prefork pool runs tasks in child processes, so initialize +# PostHog per child using worker_process_init. + + +@worker_process_init.connect +def on_worker_process_init(**kwargs) -> None: + worker_posthog_client = create_client() + worker_integration = create_integration(worker_posthog_client) + worker_integration.instrument() + + app._posthog_client = worker_posthog_client + app._posthog_integration = worker_integration + + +@worker_process_shutdown.connect +def on_worker_process_shutdown(**kwargs) -> None: + worker_integration = getattr(app, "_posthog_integration", None) + if worker_integration: + worker_integration.uninstrument() + + worker_posthog_client = getattr(app, "_posthog_client", None) + if worker_posthog_client: + worker_posthog_client.shutdown() + +# --- Example tasks --- + +@app.task +def health_check() -> dict[str, str]: + return {"status": "ok"} + + +@app.task(bind=True, max_retries=3) +def process_order(self, order_id: str) -> dict: + """A task that processes an order successfully.""" + + # simulate work + time.sleep(0.1) + + # Custom event inside the task - context tags propagated from the + # producer (e.g. "source", "release") should appear on this event + # and this should be attributed to the correct distinct ID and session. + app._posthog_client.capture( + "celery example order processed", + properties={"order_id": order_id, "amount": 99.99}, + ) + + return {"order_id": order_id, "status": "completed"} + + +@app.task(bind=True, max_retries=3) +def send_notification(self, user_id: str, message: str) -> None: + """A task that may fail and retry.""" + if self.request.retries < 2: + raise self.retry( + exc=ConnectionError("notification service unavailable"), + countdown=120, + ) + return None + + +@app.task +def failing_task() -> None: + """A task that always fails.""" + raise ValueError("something went wrong") + + +# --- Producer code --- + +if __name__ == "__main__": + posthog_client = create_client() + integration = create_integration(posthog_client) + integration.instrument() + + print("PostHog Celery Integration Example") + print("=" * 40) + print() + + # Set up PostHog context before dispatching tasks. + # The integration propagates this context to workers via task headers. + with posthog.new_context(fresh=True, client=posthog_client): + posthog.identify_context("user-123") + posthog.set_context_session("session-user-123-abc") + posthog.tag("source", "celery_integration_example_script") + posthog.tag("release", "v1.2.3") + + print("Dispatching tasks...") + + # This task is intentionally filtered and should not emit task events. + result = health_check.delay() + print(f" health_check dispatched (filtered): {result.id}") + + # This task will produce events: + # celery task published (sender side) + # celery task started (worker side) + # order processed (custom event, should carry propagated context tags) + # celery task success (worker side, includes duration) + result = process_order.delay("order-456") + print(f" process_order dispatched: {result.id}") + + # This task will produce events: + # celery task published + # celery task started + # celery task retry (with reason) + # celery task started (retry attempt) + # celery task success + result = send_notification.delay("user-123", "Hello!") + print(f" send_notification dispatched: {result.id}") + + # This task will produce events: + # celery task published + # celery task started + # celery task failure (with error_type and error_message) + result = failing_task.delay() + print(f" failing_task dispatched: {result.id}") + + print() + print("Tasks dispatched. Check your Celery worker logs and PostHog for events.") + print() + + posthog_client.flush() + integration.uninstrument() + posthog_client.shutdown() diff --git a/posthog/client.py b/posthog/client.py index 0157b1df..b936bb27 100644 --- a/posthog/client.py +++ b/posthog/client.py @@ -3,6 +3,7 @@ import os import sys import warnings +import weakref from datetime import datetime, timedelta from typing import Any, Dict, Optional, Union from uuid import uuid4 @@ -219,6 +220,7 @@ def __init__( Category: Initialization """ + self._max_queue_size = max_queue_size self.queue = queue.Queue(max_queue_size) # api_key: This should be the Team API Key (token), public @@ -332,6 +334,10 @@ def __init__( if send: consumer.start() + if hasattr(os, "register_at_fork"): + weak_self = weakref.ref(self) + os.register_at_fork(after_in_child=lambda: Client._reinit_after_fork_weak(weak_self)) + def new_context(self, fresh=False, capture_exceptions=True): """ Create a new context for managing shared state. Learn more about [contexts](/docs/libraries/python#contexts). @@ -1080,6 +1086,55 @@ def capture_exception( except Exception as e: self.log.exception(f"Failed to capture exception: {e}") + @staticmethod + def _reinit_after_fork_weak(weak_self): + """ + Reinitialize the client after a fork. + Garbage collected if the client is deleted. + """ + self = weak_self() + if self is None: + return + self._reinit_after_fork() + + def _reinit_after_fork(self): + """Reinitialize queue and consumer threads in a forked child process. + + Registered via os.register_at_fork(after_in_child=...) so it runs + exactly once in each child, before any user code, covering all code + paths (capture, flush, join, etc.). + + Python threads do not survive fork() and queue.Queue internal locks + may be in an inconsistent state, so both are replaced. + Inherited queue items are intentionally discarded as they'll be + handled by the parent process's consumers. + """ + if self.consumers is None: + return + + self.queue = queue.Queue(self._max_queue_size) + + new_consumers = [] + for old in self.consumers: + consumer = Consumer( + self.queue, + old.api_key, + flush_at=old.flush_at, + host=old.host, + on_error=old.on_error, + flush_interval=old.flush_interval, + gzip=old.gzip, + retries=old.retries, + timeout=old.timeout, + historical_migration=old.historical_migration, + ) + new_consumers.append(consumer) + + if self.send: + consumer.start() + + self.consumers = new_consumers + def _enqueue(self, msg, disable_geoip): # type: (...) -> Optional[str] """Push a new `msg` onto the queue, return `(success, msg)`""" diff --git a/posthog/integrations/celery.py b/posthog/integrations/celery.py new file mode 100644 index 00000000..cb7f2d3b --- /dev/null +++ b/posthog/integrations/celery.py @@ -0,0 +1,401 @@ +""" +Integration for `celery`_ to capture task lifecycle events and exceptions with PostHog. + +.. _celery: https://pypi.org/project/celery/ + +Features: +- Hooks into Celery signals to automatically capture task lifecycle events + (started, success, failure, retry, published) and exceptions. +- Lifecycle events include Celery-specific properties such as task ID, task name, + queue, retry count, duration, Celery version etc. +- Any custom events captured inside a task (via ``client.capture``) are automatically + enriched with the same Celery-specific properties via context tags. +- Propagates PostHog context (distinct ID, session ID, tags) from the producer + process to the worker process. + +Supports Celery 4.0+ (Message Protocol Version 2). + +Usage +----- + +.. code-block:: python + + from posthog import Posthog + from posthog.integrations.celery import PosthogCeleryIntegration + + posthog = Posthog("", host="") + + integration = PosthogCeleryIntegration(client=posthog) + integration.instrument() + +Both the producer process and each worker process must initialize the +PostHog client and instrument the integration because the worker needs +to bind to Celery signals, and the PostHog client may use background threads +to send captured events (depending on ``sync_mode``). Celery provides a signal +called ``worker_process_init`` that can be used to accomplish this. + +See ``examples/celery_integration.py`` for a complete working example. + +Supported task states for event emission: + - ``started`` + - ``success`` + - ``failure`` + - ``retry`` + - ``published`` + +Event properties: + All lifecycle and exception events include the following properties: + + - ``celery_task_id`` -- unique task ID + - ``celery_task_name`` -- registered task name + - ``celery_state`` -- lifecycle state (started, success, failure, etc.) + - ``celery_hostname`` -- worker hostname + - ``celery_exchange`` -- broker exchange + - ``celery_routing_key`` -- broker routing key + - ``celery_queue`` -- broker queue name + - ``celery_retry_count`` -- number of retries so far + - ``celery_version`` -- installed Celery library version + - ``celery_task_duration_ms`` -- task wall-clock duration in milliseconds + (present on terminal states: success, failure, retry) + + Additional properties on specific states: + + - **failure**: ``error_type``, ``error_message`` + - **retry**: ``celery_reason`` +""" + +import json +import logging +import time +from typing import Any, Callable, Optional + +from posthog import contexts +from posthog.client import Client + + +CONTEXT_DISTINCT_ID_HEADER = "X-POSTHOG-DISTINCT-ID" +CONTEXT_SESSION_ID_HEADER = "X-POSTHOG-SESSION-ID" +CONTEXT_TAGS_HEADER = "X-POSTHOG-CONTEXT-TAGS" + +logger = logging.getLogger("posthog") + + +class PosthogCeleryIntegration: + """Celery integration that captures task lifecycle events and exceptions. + + Args: + client: Optional ``Client`` instance. When provided, all events and + exceptions are captured through this client rather than the + global ``posthog`` module. + capture_exceptions: Whether to capture task exceptions via + ``capture_exception`` (default ``True``). + capture_task_lifecycle_events: Whether to emit lifecycle events of the task + such as "started", "success", "failure" etc. (default ``True``). + propagate_context: Whether to propagate PostHog context (distinct + ID, session ID, tags) from the producer to the worker via task + headers (default ``True``). + task_filter: Optional callback ``(task_name, task_properties) -> bool`` expected to + return ``False`` if a given task should not be tracked. + """ + + def __init__( + self, + client: Optional[Client] = None, + capture_exceptions: bool = True, + capture_task_lifecycle_events: bool = True, + propagate_context: bool = True, + task_filter: Optional[Callable[[Optional[str], dict[str, Any]], bool]] = None, + ): + self.client = client + self.capture_exceptions = capture_exceptions + self.capture_task_lifecycle_events = capture_task_lifecycle_events + self.propagate_context = propagate_context + self.task_filter = task_filter + + self._instrumented = False + self._signals: Optional[Any] = None + self._celery_version: Optional[str] = None + + def instrument(self) -> None: + if self._instrumented: + return + + from celery import signals + from celery import __version__ as celery_version + + self._signals = signals + self._celery_version = celery_version + + signals.task_prerun.connect(self._on_task_prerun, weak=False) + signals.task_success.connect(self._on_task_success, weak=False) + signals.task_failure.connect(self._on_task_failure, weak=False) + signals.task_retry.connect(self._on_task_retry, weak=False) + signals.before_task_publish.connect(self._on_before_task_publish, weak=False) + signals.after_task_publish.connect(self._on_after_task_publish, weak=False) + + self._instrumented = True + + def uninstrument(self) -> None: + if not self._instrumented or not self._signals: + return + + self._signals.task_prerun.disconnect(self._on_task_prerun) + self._signals.task_success.disconnect(self._on_task_success) + self._signals.task_failure.disconnect(self._on_task_failure) + self._signals.task_retry.disconnect(self._on_task_retry) + self._signals.before_task_publish.disconnect(self._on_before_task_publish) + self._signals.after_task_publish.disconnect(self._on_after_task_publish) + + self._signals = None + self._instrumented = False + + def _on_before_task_publish(self, *args, **kwargs): + try: + if not self.propagate_context: + return + + headers = kwargs.get("headers") + if not isinstance(headers, dict): + return + + distinct_id = contexts.get_context_distinct_id() + session_id = contexts.get_context_session_id() + tags = contexts.get_tags() + + posthog_headers: dict[str, str] = {} + if distinct_id: + posthog_headers[CONTEXT_DISTINCT_ID_HEADER] = distinct_id + if session_id: + posthog_headers[CONTEXT_SESSION_ID_HEADER] = session_id + if tags: + posthog_headers[CONTEXT_TAGS_HEADER] = json.dumps(tags, default=str) + + if posthog_headers: + headers.update(posthog_headers) + # https://github.com/celery/celery/issues/4875 + # In Celery protocol v2, top-level custom headers do not + # reliably appear in task.request.headers on the worker. + # Only headers nested inside headers["headers"] survive. + # Both sentry-sdk and dd-trace-py use this same workaround. + headers.setdefault("headers", {}).update(posthog_headers) + except Exception: + logger.exception("Failed to propagate PostHog context in before_task_publish") + + def _on_after_task_publish(self, *args, **kwargs): + try: + if not self.capture_task_lifecycle_events: + return + + sender = kwargs.get("sender") # contains task name for publish events, NOT task object + headers = kwargs.get("headers") + task_id = headers.get("id") if isinstance(headers, dict) else None + + sender_properties = { + "celery_task_id": task_id, + "celery_task_name": sender, + "celery_state": "published", + "celery_exchange": kwargs.get("exchange"), + "celery_routing_key": kwargs.get("routing_key"), + "celery_hostname": None, # Not available at publish time (no worker assigned yet) + "celery_retry_count": headers.get("retries") if isinstance(headers, dict) else None, + "celery_version": self._celery_version, + } + + if self._should_track(sender, sender_properties): + self._capture_event("celery task published", properties=sender_properties) + except Exception: + logger.exception("Failed to capture Celery after_task_publish lifecycle event") + + def _on_task_prerun(self, *args, **kwargs): + context_manager = None + try: + task_id = kwargs.get("task_id") + if not task_id: + return + + sender = kwargs.get("sender") + request = getattr(sender, "request", None) + context_tags = self._extract_propagated_tags(request) + task_properties = self._build_task_properties( + sender=sender, + task_id=task_id, + state="started", + ) + task_name = task_properties.get("celery_task_name") + + if request is not None: + context_manager = contexts.new_context( + fresh=True, # to prevent context bleed across tasks + capture_exceptions=False, # Celery catches task exceptions internally and + # delivers them via task_failure signal, so they + # never propagate through the context manager. + # We capture them in _on_task_failure. + client=self.client, + ) + context_manager.__enter__() + request._posthog_ctx = context_manager + request._posthog_start = time.monotonic() + + self._apply_propagated_identity(request) + + merged_tags = {**task_properties, **context_tags} + for key, value in merged_tags.items(): + contexts.tag(key, value) + + if self.capture_task_lifecycle_events and self._should_track(task_name, task_properties): + self._capture_event("celery task started", properties=task_properties) + except Exception: + logger.exception("Failed to process Celery task_prerun") + if context_manager is not None: + try: + context_manager.__exit__(None, None, None) + except Exception: + pass + + def _on_task_success(self, *args, **kwargs): + self._handle_task_end("success", **kwargs) + + def _on_task_failure(self, *args, **kwargs): + self._handle_task_end("failure", **kwargs) + + def _on_task_retry(self, *args, **kwargs): + self._handle_task_end("retry", extra_properties={ + "celery_reason": str(kwargs.get("reason")), + }, **kwargs) + + def _handle_task_end( + self, + state: str, + extra_properties: Optional[dict[str, Any]] = None, + **kwargs, + ) -> None: + sender = kwargs.get("sender") + request = getattr(sender, "request", None) + + try: + task_id = kwargs.get("task_id") + if task_id is None: + task_id = getattr(request, "id", None) + + task_properties = self._build_task_properties( + sender=sender, + task_id=task_id, + state=state, + ) + if extra_properties: + task_properties.update(extra_properties) + + self._add_duration(request, task_properties) + + exception = kwargs.get("exception") + if exception: + task_properties["error_type"] = type(exception).__name__ + task_properties["error_message"] = str(exception) + if self.capture_exceptions: + self._capture_exception(exception) + + task_name = task_properties.get("celery_task_name") + if self.capture_task_lifecycle_events and self._should_track(task_name, task_properties): + self._capture_event(f"celery task {state}", properties=task_properties) + except Exception: + logger.exception("Failed to process Celery %s", state) + finally: + ctx = getattr(request, "_posthog_ctx", None) + if ctx is not None: + ctx.__exit__(None, None, None) + + def _apply_propagated_identity(self, request: Any) -> None: + headers = self._extract_headers(request) + distinct_id = headers.get(CONTEXT_DISTINCT_ID_HEADER) + if distinct_id: + contexts.identify_context(str(distinct_id)) + + session_id = headers.get(CONTEXT_SESSION_ID_HEADER) + if session_id: + contexts.set_context_session(str(session_id)) + + def _extract_propagated_tags(self, request: Any) -> dict[str, Any]: + headers = self._extract_headers(request) + + try: + parsed = json.loads(headers.get(CONTEXT_TAGS_HEADER)) + except Exception: + return {} + + if isinstance(parsed, dict): + return parsed + return {} + + def _extract_headers(self, request: Any) -> dict[str, Any]: + if request is None: + return {} + + # On the Celery worker, request.headers maps to the nested + # message["headers"]["headers"] dict (see celery#4875), which is + # where _on_before_task_publish places PostHog context headers. + headers = getattr(request, "headers", None) + if isinstance(headers, dict): + return headers + + if isinstance(request, dict): + dict_headers = request.get("headers") + if isinstance(dict_headers, dict): + return dict_headers + + return {} + + def _build_task_properties( + self, + sender=None, + task_id=None, + state=None, + ) -> dict[str, Any]: + request = getattr(sender, "request", None) + delivery_info = getattr(request, "delivery_info", None) + delivery_info = delivery_info if isinstance(delivery_info, dict) else {} + + properties = { + "celery_task_id": task_id, + "celery_task_name": getattr(sender, "name", None), + "celery_state": state, + "celery_hostname": getattr(request, "hostname", None), + "celery_exchange": delivery_info.get("exchange"), + "celery_routing_key": delivery_info.get("routing_key"), + "celery_queue": delivery_info.get("queue"), + "celery_retry_count": getattr(request, "retries", None), + "celery_version": self._celery_version, + } + return properties + + def _add_duration(self, request: Any, task_properties: dict[str, Any]) -> None: + start_time = getattr(request, "_posthog_start", None) + if start_time is not None: + task_properties["celery_task_duration_ms"] = round( + (time.monotonic() - start_time) * 1000.0, 3 + ) + + def _should_track(self, task_name: Optional[str], task_properties: dict[str, Any]) -> bool: + if self.task_filter: + return bool(self.task_filter(task_name, task_properties)) + return True + + def _capture_event(self, event: str, properties: dict[str, Any]) -> None: + if self.client: + self.client.capture(event, properties=properties) + else: + from posthog import capture + + capture(event, properties=properties) + + def _capture_exception(self, exception: Exception) -> None: + if self.client: + self.client.capture_exception(exception) + else: + from posthog import capture_exception + + capture_exception(exception) + + +__all__ = [ + "PosthogCeleryIntegration", +] diff --git a/posthog/test/integrations/test_celery_integration.py b/posthog/test/integrations/test_celery_integration.py new file mode 100644 index 00000000..eb530b7a --- /dev/null +++ b/posthog/test/integrations/test_celery_integration.py @@ -0,0 +1,492 @@ +import unittest +from types import ModuleType +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from posthog import contexts +from posthog.integrations.celery import ( + CONTEXT_DISTINCT_ID_HEADER, + CONTEXT_SESSION_ID_HEADER, + CONTEXT_TAGS_HEADER, + PosthogCeleryIntegration, +) + + +class FakeSignal: + def __init__(self): + self.connected = [] + self.disconnected = [] + + def connect(self, handler, weak=False): + self.connected.append((handler, weak)) + + def disconnect(self, handler): + self.disconnected.append(handler) + + +class TestPosthogCeleryIntegration(unittest.TestCase): + def test_instrument_is_idempotent(self): + fake_signals = SimpleNamespace( + task_prerun=FakeSignal(), + task_success=FakeSignal(), + task_failure=FakeSignal(), + task_retry=FakeSignal(), + before_task_publish=FakeSignal(), + after_task_publish=FakeSignal(), + ) + + integration = PosthogCeleryIntegration() + fake_celery = ModuleType("celery") + fake_celery.signals = fake_signals + fake_celery.__version__ = "5.0.0" + + with patch.dict("sys.modules", {"celery": fake_celery}): + integration.instrument() + integration.instrument() + + for sig in [ + "task_prerun", + "task_success", + "task_failure", + "task_retry", + "before_task_publish", + "after_task_publish", + ]: + self.assertEqual(len(getattr(fake_signals, sig).connected), 1, f"{sig} connected multiple times") + + def test_instrument_and_uninstrument_connect_signals(self): + fake_signals = SimpleNamespace( + task_prerun=FakeSignal(), + task_success=FakeSignal(), + task_failure=FakeSignal(), + task_retry=FakeSignal(), + before_task_publish=FakeSignal(), + after_task_publish=FakeSignal(), + ) + + integration = PosthogCeleryIntegration() + + fake_celery = ModuleType("celery") + fake_celery.signals = fake_signals + fake_celery.__version__ = "5.0.0" + + with patch.dict("sys.modules", {"celery": fake_celery}): + integration.instrument() + integration.uninstrument() + + for sig in ["task_prerun", "task_success", "task_failure", + "task_retry", "before_task_publish", + "after_task_publish"]: + self.assertEqual(len(getattr(fake_signals, sig).connected), 1, f"{sig} not connected") + self.assertEqual(len(getattr(fake_signals, sig).disconnected), 1, f"{sig} not disconnected") + + def test_before_task_publish_propagates_context_headers(self): + integration = PosthogCeleryIntegration() + headers = {} + + with contexts.new_context(fresh=True): + contexts.identify_context("distinct-123") + contexts.set_context_session("session-456") + contexts.tag("request_id", "abc") + + integration._on_before_task_publish(sender="test.task", headers=headers) + + self.assertEqual(headers[CONTEXT_DISTINCT_ID_HEADER], "distinct-123") + self.assertEqual(headers[CONTEXT_SESSION_ID_HEADER], "session-456") + self.assertIn(CONTEXT_TAGS_HEADER, headers) + + # celery#4875: headers must also be nested inside headers["headers"] + # so they survive to task.request.headers on the worker + inner = headers["headers"] + self.assertEqual(inner[CONTEXT_DISTINCT_ID_HEADER], "distinct-123") + self.assertEqual(inner[CONTEXT_SESSION_ID_HEADER], "session-456") + self.assertIn(CONTEXT_TAGS_HEADER, inner) + + def test_before_task_publish_preserves_existing_nested_headers(self): + integration = PosthogCeleryIntegration() + headers = {"headers": {"sentry-trace": "abc-123"}} + + with contexts.new_context(fresh=True): + contexts.identify_context("distinct-123") + integration._on_before_task_publish(sender="test.task", headers=headers) + + inner = headers["headers"] + self.assertEqual(inner["sentry-trace"], "abc-123") + self.assertEqual(inner[CONTEXT_DISTINCT_ID_HEADER], "distinct-123") + + def test_before_task_publish_nested_headers_round_trips_to_worker(self): + integration = PosthogCeleryIntegration(client=Mock()) + headers = {} + + with contexts.new_context(fresh=True): + contexts.identify_context("user-rt") + contexts.set_context_session("sess-rt") + contexts.tag("env", "test") + integration._on_before_task_publish(sender="test.task", headers=headers) + + # Simulate Celery worker: task.request.headers is the nested dict + worker_request = SimpleNamespace( + headers=headers["headers"], + delivery_info={}, + hostname="worker-1", + retries=0, + ) + task = SimpleNamespace(name="test.task", request=worker_request) + + integration._on_task_prerun(sender=task, task_id="task-rt") + + self.assertEqual(contexts.get_context_distinct_id(), "user-rt") + self.assertEqual(contexts.get_context_session_id(), "sess-rt") + self.assertEqual(contexts.get_tags().get("env"), "test") + + integration._on_task_success(sender=task) + + def test_before_task_publish_respects_propagate_context_flag(self): + integration = PosthogCeleryIntegration(propagate_context=False) + headers = {} + + with contexts.new_context(fresh=True): + contexts.identify_context("distinct-123") + contexts.set_context_session("session-456") + contexts.tag("request_id", "abc") + + integration._on_before_task_publish(sender="test.task", headers=headers) + + self.assertEqual(headers, {}) + + def test_task_context_is_cleared_after_task_end(self): + integration = PosthogCeleryIntegration(client=Mock()) + + first_request = SimpleNamespace( + headers={ + CONTEXT_DISTINCT_ID_HEADER: "user-1", + CONTEXT_SESSION_ID_HEADER: "sess-1", + CONTEXT_TAGS_HEADER: '{"source": "api"}', + }, + delivery_info={}, + hostname="worker-1", + retries=0, + ) + first_task = SimpleNamespace(name="app.tasks.first", request=first_request) + + integration._on_task_prerun(sender=first_task, task_id="task-1") + + self.assertEqual(contexts.get_context_distinct_id(), "user-1") + self.assertEqual(contexts.get_context_session_id(), "sess-1") + self.assertEqual(contexts.get_tags().get("source"), "api") + + integration._on_task_success(sender=first_task, task_id="task-1") + + self.assertIsNone(contexts.get_context_distinct_id()) + self.assertIsNone(contexts.get_context_session_id()) + self.assertEqual(contexts.get_tags(), {}) + + second_request = SimpleNamespace( + headers={}, + delivery_info={}, + hostname="worker-1", + retries=0, + ) + second_task = SimpleNamespace(name="app.tasks.second", request=second_request) + + integration._on_task_prerun(sender=second_task, task_id="task-2") + + self.assertIsNone(contexts.get_context_distinct_id()) + self.assertIsNone(contexts.get_context_session_id()) + self.assertNotIn("source", contexts.get_tags()) + + integration._on_task_success(sender=second_task, task_id="task-2") + + def test_task_prerun_hydrates_context_and_postrun_cleans_up(self): + mock_client = Mock() + integration = PosthogCeleryIntegration(client=mock_client) + + request = SimpleNamespace( + headers={ + CONTEXT_DISTINCT_ID_HEADER: "user-1", + CONTEXT_SESSION_ID_HEADER: "sess-1", + CONTEXT_TAGS_HEADER: '{"source": "api"}', + }, + delivery_info={"exchange": "celery", "routing_key": "default", "queue": "default"}, + hostname="worker-1", + retries=0, + ) + task = SimpleNamespace(name="app.tasks.example", request=request) + + integration._on_task_prerun(sender=task, task_id="task-1") + + self.assertEqual(contexts.get_context_distinct_id(), "user-1") + self.assertEqual(contexts.get_context_session_id(), "sess-1") + self.assertEqual(contexts.get_tags().get("source"), "api") + self.assertTrue(hasattr(request, "_posthog_ctx")) + self.assertTrue(hasattr(request, "_posthog_start")) + + integration._on_task_success(sender=task) + + event_names = [call.args[0] for call in mock_client.capture.call_args_list] + self.assertIn("celery task started", event_names) + self.assertIn("celery task success", event_names) + + def test_postrun_includes_duration(self): + mock_client = Mock() + integration = PosthogCeleryIntegration(client=mock_client) + + request = SimpleNamespace( + headers={}, + delivery_info={}, + hostname="worker-1", + retries=0, + ) + task = SimpleNamespace(name="app.tasks.timed", request=request) + + integration._on_task_prerun(sender=task, task_id="task-t") + integration._on_task_success(sender=task) + + completed_call = [ + c for c in mock_client.capture.call_args_list if c.args[0] == "celery task success" + ] + self.assertEqual(len(completed_call), 1) + self.assertIn("celery_task_duration_ms", completed_call[0].kwargs["properties"]) + + def test_failure_includes_duration(self): + mock_client = Mock() + integration = PosthogCeleryIntegration(client=mock_client) + + request = SimpleNamespace( + headers={}, + delivery_info={}, + hostname="worker-1", + retries=0, + ) + task = SimpleNamespace(name="app.tasks.failing_timed", request=request) + + integration._on_task_prerun(sender=task, task_id="task-f") + integration._on_task_failure( + sender=task, task_id="task-f", exception=ValueError("boom") + ) + + failed_call = [ + c for c in mock_client.capture.call_args_list if c.args[0] == "celery task failure" + ] + self.assertEqual(len(failed_call), 1) + self.assertIn("celery_task_duration_ms", failed_call[0].kwargs["properties"]) + + def test_task_failure_captures_exception_and_failure_event(self): + mock_client = Mock() + integration = PosthogCeleryIntegration(client=mock_client) + + task = SimpleNamespace(name="app.tasks.failing", request=SimpleNamespace(delivery_info={})) + exception = ValueError("task failed") + + integration._on_task_failure( + sender=task, + task_id="task-2", + exception=exception, + ) + + mock_client.capture_exception.assert_called_once_with(exception) + event_names = [call.args[0] for call in mock_client.capture.call_args_list] + self.assertIn("celery task failure", event_names) + + def test_task_failure_event_includes_error_fields(self): + mock_client = Mock() + integration = PosthogCeleryIntegration(client=mock_client) + + task = SimpleNamespace( + name="app.tasks.failing", + request=SimpleNamespace(delivery_info={}), + ) + exception = ValueError("task failed") + + integration._on_task_failure( + sender=task, + task_id="task-2", + exception=exception, + ) + + failure_calls = [ + c for c in mock_client.capture.call_args_list if c.args[0] == "celery task failure" + ] + self.assertEqual(len(failure_calls), 1) + props = failure_calls[0].kwargs["properties"] + self.assertEqual(props["error_type"], "ValueError") + self.assertEqual(props["error_message"], "task failed") + + def test_task_failure_skips_exception_capture_when_disabled(self): + mock_client = Mock() + integration = PosthogCeleryIntegration(client=mock_client, capture_exceptions=False) + + task = SimpleNamespace(name="app.tasks.failing", request=SimpleNamespace(delivery_info={})) + exception = ValueError("task failed") + + integration._on_task_failure( + sender=task, + task_id="task-2", + exception=exception, + ) + + mock_client.capture_exception.assert_not_called() + event_names = [call.args[0] for call in mock_client.capture.call_args_list] + self.assertIn("celery task failure", event_names) + + def test_task_retry_captures_event_with_reason(self): + mock_client = Mock() + integration = PosthogCeleryIntegration(client=mock_client) + + task = SimpleNamespace(name="app.tasks.retrying", request=SimpleNamespace(delivery_info={})) + + integration._on_task_retry( + sender=task, + task_id="task-retry", + reason=ConnectionError("broker down"), + ) + + event_names = [call.args[0] for call in mock_client.capture.call_args_list] + self.assertIn("celery task retry", event_names) + retry_call = [c for c in mock_client.capture.call_args_list if c.args[0] == "celery task retry"][0] + props = retry_call.kwargs["properties"] + self.assertEqual(props["celery_reason"], "broker down") + + def test_task_filter_applies_to_worker_lifecycle_events(self): + mock_client = Mock() + integration = PosthogCeleryIntegration( + client=mock_client, + task_filter=lambda task_name, properties: False, + ) + + request = SimpleNamespace( + headers={}, + delivery_info={}, + hostname="worker-1", + retries=0, + ) + task = SimpleNamespace(name="app.tasks.filtered", request=request) + + integration._on_task_prerun(sender=task, task_id="task-3") + integration._on_task_success(sender=task, task_id="task-3") + + mock_client.capture.assert_not_called() + + def test_task_failure_captures_exception_when_lifecycle_events_disabled(self): + mock_client = Mock() + integration = PosthogCeleryIntegration( + client=mock_client, + capture_task_lifecycle_events=False, + ) + + task = SimpleNamespace( + name="app.tasks.failing", + request=SimpleNamespace(delivery_info={}), + ) + exception = ValueError("task failed") + + integration._on_task_failure( + sender=task, + task_id="task-4", + exception=exception, + ) + + mock_client.capture.assert_not_called() + mock_client.capture_exception.assert_called_once_with(exception) + + def test_after_task_publish_captures_published_event(self): + mock_client = Mock() + integration = PosthogCeleryIntegration(client=mock_client) + + integration._on_after_task_publish( + sender="app.tasks.published", + headers={"id": "task-3"}, + exchange="celery", + routing_key="default", + ) + + mock_client.capture.assert_called_once() + self.assertEqual(mock_client.capture.call_args.args[0], "celery task published") + props = mock_client.capture.call_args.kwargs["properties"] + self.assertIn("celery_version", props) + + def test_after_task_publish_respects_task_filter(self): + mock_client = Mock() + integration = PosthogCeleryIntegration( + client=mock_client, task_filter=lambda task_name, properties: False + ) + + integration._on_after_task_publish( + sender="app.tasks.filtered", + headers={"id": "task-3"}, + exchange="celery", + routing_key="default", + ) + + mock_client.capture.assert_not_called() + + def test_after_task_publish_skips_when_lifecycle_events_disabled(self): + mock_client = Mock() + integration = PosthogCeleryIntegration( + client=mock_client, + capture_task_lifecycle_events=False, + ) + + integration._on_after_task_publish( + sender="app.tasks.published", + headers={"id": "task-3"}, + exchange="celery", + routing_key="default", + ) + + mock_client.capture.assert_not_called() + + def test_capture_event_falls_back_to_global_capture(self): + integration = PosthogCeleryIntegration(client=None) + + with patch("posthog.capture") as mock_capture: + integration._capture_event("celery task started", properties={"celery_task_id": "t1"}) + + mock_capture.assert_called_once_with( + "celery task started", properties={"celery_task_id": "t1"} + ) + + def test_capture_exception_falls_back_to_global_capture_exception(self): + integration = PosthogCeleryIntegration(client=None) + exception = ValueError("boom") + + with patch("posthog.capture_exception") as mock_capture_exception: + integration._capture_exception(exception) + + mock_capture_exception.assert_called_once_with(exception) + + def test_extract_headers_supports_request_dict_shape(self): + integration = PosthogCeleryIntegration() + request = {"headers": {CONTEXT_DISTINCT_ID_HEADER: "user-1"}} + + headers = integration._extract_headers(request) + + self.assertEqual(headers, {CONTEXT_DISTINCT_ID_HEADER: "user-1"}) + + def test_prerun_exits_context_on_failure_after_entry(self): + mock_client = Mock() + integration = PosthogCeleryIntegration(client=mock_client) + + request = SimpleNamespace( + headers={}, + delivery_info={}, + hostname="worker-1", + retries=0, + ) + task = SimpleNamespace(name="app.tasks.boom", request=request) + + ctx_before = contexts._get_current_context() + + with patch.object(integration, "_apply_propagated_identity", side_effect=RuntimeError("boom")): + integration._on_task_prerun(sender=task, task_id="task-leak") + + ctx_after = contexts._get_current_context() + self.assertIs(ctx_after, ctx_before) + + def test_extract_propagated_tags_invalid_json_returns_empty_dict(self): + integration = PosthogCeleryIntegration() + request = SimpleNamespace(headers={CONTEXT_TAGS_HEADER: "{bad json"}) + + tags = integration._extract_propagated_tags(request) + + self.assertEqual(tags, {}) diff --git a/posthog/test/test_client.py b/posthog/test/test_client.py index 57c06867..81dd6f55 100644 --- a/posthog/test/test_client.py +++ b/posthog/test/test_client.py @@ -1,3 +1,6 @@ +import os +import subprocess +import sys import time import unittest from datetime import datetime @@ -2726,6 +2729,49 @@ def test_get_all_flags_and_payloads_with_empty_string(self, patch_batch_post): result["featureFlagPayloads"]["normal-payload-flag"], "normal payload" ) + @mock.patch("posthog.client.os.register_at_fork") + def test_registers_at_fork_hook(self, mock_register_at_fork): + client = Client(FAKE_TEST_API_KEY, on_error=self.set_fail) + + mock_register_at_fork.assert_called_once() + after_in_child = mock_register_at_fork.call_args.kwargs["after_in_child"] + + with mock.patch.object(client, "_reinit_after_fork") as mock_reinit: + after_in_child() + mock_reinit.assert_called_once() + + @mock.patch("posthog.client.os.register_at_fork") + def test_register_at_fork_noop_after_client_gc(self, mock_register_at_fork): + client = Client(FAKE_TEST_API_KEY, on_error=self.set_fail) + after_in_child = mock_register_at_fork.call_args.kwargs["after_in_child"] + del client + after_in_child() + + @parameterized.expand([(True, 1), (False, 0)]) + def test_reinit_after_fork_replaces_queue_and_consumers(self, send, expected_starts): + with mock.patch("posthog.client.Consumer.start") as mock_start: + client = Client(FAKE_TEST_API_KEY, on_error=self.set_fail, send=send, thread=1) + mock_start.reset_mock() + + old_queue = client.queue + old_consumers = list(client.consumers) + + client._reinit_after_fork() + + self.assertIsNot(client.queue, old_queue) + self.assertEqual(len(client.consumers), len(old_consumers)) + self.assertIsNot(client.consumers[0], old_consumers[0]) + self.assertIs(client.consumers[0].queue, client.queue) + self.assertEqual(mock_start.call_count, expected_starts) + + def test_reinit_after_fork_noop_for_sync_mode(self): + client = Client(FAKE_TEST_API_KEY, on_error=self.set_fail, sync_mode=True) + old_queue = client.queue + + client._reinit_after_fork() + + self.assertIs(client.queue, old_queue) + def test_context_tags_added(self): with mock.patch("posthog.client.batch_post") as mock_post: client = Client(FAKE_TEST_API_KEY, on_error=self.set_fail, sync_mode=True) @@ -2780,3 +2826,52 @@ def test_debug_flag_re_raises_exceptions(self, mock_enqueue): with self.assertRaises(Exception) as cm: method(*args, **kwargs) self.assertEqual(str(cm.exception), "Expected error") + + +@unittest.skipUnless( + hasattr(os, "fork") and hasattr(os, "register_at_fork"), + "requires os.fork and os.register_at_fork", +) +class TestClientForkBehavior(unittest.TestCase): + def test_register_at_fork_reinitializes_client_in_child_process(self): + script = f""" +import os +from posthog.client import Client + +client = Client("{FAKE_TEST_API_KEY}", send=False) +old_queue = client.queue +old_consumer = client.consumers[0] + +read_fd, write_fd = os.pipe() +pid = os.fork() +if pid == 0: + try: + reinitialized_queue = client.queue is not old_queue + replaced_consumer = client.consumers[0] is not old_consumer + consumer_points_to_new_queue = client.consumers[0].queue is client.queue + child_ok = reinitialized_queue and replaced_consumer and consumer_points_to_new_queue + os.write(write_fd, b"1" if child_ok else b"0") + finally: + os.close(write_fd) + os.close(read_fd) + os._exit(0) + +os.close(write_fd) +result = os.read(read_fd, 1) +os.close(read_fd) +_, status = os.waitpid(pid, 0) +if os.WIFEXITED(status) and os.WEXITSTATUS(status) == 0 and result == b"1": + print("ok") +else: + raise SystemExit(1) +""" + proc = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=10, + check=False, + ) + + self.assertEqual(proc.returncode, 0, msg=proc.stderr) + self.assertEqual(proc.stdout.strip(), "ok")