Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
#: W3C Trace Context header names used for distributed trace propagation.
_W3C_HEADERS = ("traceparent", "tracestate")

#: Baggage key whose value overrides the parent span ID.
_LEAF_CUSTOMER_SPAN_ID = "leaf_customer_span_id"

# ------------------------------------------------------------------
# GenAI semantic convention attribute keys
Expand Down Expand Up @@ -162,36 +160,21 @@ def __init__(
def _extract_context(
self,
carrier: Optional[dict[str, str]],
baggage_header: Optional[str] = None,
) -> Any:
"""Extract parent trace context from a W3C carrier dict.

When a ``baggage`` header is provided and contains a
``leaf_customer_span_id`` key, the parent span ID is overridden
so that the server's root span is parented under the leaf customer
span rather than the span referenced in the ``traceparent`` header.
Uses the standard ``traceparent`` / ``tracestate`` headers via
the OpenTelemetry :class:`TraceContextTextMapPropagator`.

:param carrier: W3C trace-context headers or None.
:type carrier: Optional[dict[str, str]]
:param baggage_header: Raw ``baggage`` header value or None.
:type baggage_header: Optional[str]
:return: The extracted OTel context, or None.
:rtype: Any
"""
if not carrier or self._propagator is None:
return None

ctx = self._propagator.extract(carrier=carrier)

if not baggage_header:
return ctx

leaf_span_id = _parse_baggage_key(
baggage_header, _LEAF_CUSTOMER_SPAN_ID)
if not leaf_span_id:
return ctx

return _override_parent_span_id(ctx, leaf_span_id)
return self._propagator.extract(carrier=carrier)

@staticmethod
def _setup_azure_monitor(connection_string: str, resource: Any, trace_provider: Any) -> None:
Expand Down Expand Up @@ -305,7 +288,6 @@ def span(
name: str,
attributes: Optional[dict[str, str]] = None,
carrier: Optional[dict[str, str]] = None,
baggage_header: Optional[str] = None,
) -> Iterator[Any]:
"""Create a traced span if tracing is enabled, otherwise no-op.

Expand All @@ -319,17 +301,14 @@ def span(
:type attributes: Optional[dict[str, str]]
:param carrier: Incoming HTTP headers for W3C trace-context propagation.
:type carrier: Optional[dict[str, str]]
:param baggage_header: Raw ``baggage`` header value for
``leaf_customer_span_id`` extraction.
:type baggage_header: Optional[str]
:return: Context manager that yields the OTel span or *None*.
:rtype: Iterator[Any]
"""
if not self._enabled or self._tracer is None:
yield None
return

ctx = self._extract_context(carrier, baggage_header)
ctx = self._extract_context(carrier)

with self._tracer.start_as_current_span(
name=name,
Expand All @@ -344,7 +323,6 @@ def start_span(
name: str,
attributes: Optional[dict[str, str]] = None,
carrier: Optional[dict[str, str]] = None,
baggage_header: Optional[str] = None,
) -> Any:
"""Start a span without a context manager.

Expand All @@ -358,16 +336,13 @@ def start_span(
:type attributes: Optional[dict[str, str]]
:param carrier: Incoming HTTP headers for W3C trace-context propagation.
:type carrier: Optional[dict[str, str]]
:param baggage_header: Raw ``baggage`` header value for
``leaf_customer_span_id`` extraction.
:type baggage_header: Optional[str]
:return: The OTel span, or *None* when tracing is disabled.
:rtype: Any
"""
if not self._enabled or self._tracer is None:
return None

ctx = self._extract_context(carrier, baggage_header)
ctx = self._extract_context(carrier)

return self._tracer.start_span(
name=name,
Expand All @@ -387,7 +362,7 @@ def _prepare_request_span_args(
span_operation: str,
operation_name: Optional[str] = None,
session_id: str = "",
) -> tuple[str, dict[str, str], dict[str, str], Optional[str]]:
) -> tuple[str, dict[str, str], dict[str, str]]:
"""Extract headers and build span arguments for a request.

Shared pipeline used by :meth:`start_request_span` and
Expand All @@ -405,16 +380,15 @@ def _prepare_request_span_args(
:param session_id: Session ID from the ``agent_session_id`` query
parameter. Defaults to ``""`` (no session).
:type session_id: str
:return: ``(name, attributes, carrier, baggage)`` ready for
:return: ``(name, attributes, carrier)`` ready for
:meth:`span` or :meth:`start_span`.
:rtype: tuple[str, dict[str, str], dict[str, str], Optional[str]]
:rtype: tuple[str, dict[str, str], dict[str, str]]
"""
carrier = _extract_w3c_carrier(headers)
baggage = headers.get("baggage")
span_attrs = self.build_span_attrs(
invocation_id, session_id, operation_name=operation_name
)
return self.span_name(span_operation), span_attrs, carrier, baggage
return self.span_name(span_operation), span_attrs, carrier

def start_request_span(
self,
Expand Down Expand Up @@ -449,11 +423,11 @@ def start_request_span(
:return: The OTel span, or *None* when tracing is disabled.
:rtype: Any
"""
name, attrs, carrier, baggage = self._prepare_request_span_args(
name, attrs, carrier = self._prepare_request_span_args(
headers, invocation_id, span_operation, operation_name,
session_id=session_id,
)
return self.start_span(name, attributes=attrs, carrier=carrier, baggage_header=baggage)
return self.start_span(name, attributes=attrs, carrier=carrier)

@contextmanager
def request_span(
Expand Down Expand Up @@ -488,11 +462,11 @@ def request_span(
:return: Context manager that yields the OTel span or *None*.
:rtype: Iterator[Any]
"""
name, attrs, carrier, baggage = self._prepare_request_span_args(
name, attrs, carrier = self._prepare_request_span_args(
headers, invocation_id, span_operation, operation_name,
session_id=session_id,
)
with self.span(name, attributes=attrs, carrier=carrier, baggage_header=baggage) as otel_span:
with self.span(name, attributes=attrs, carrier=carrier) as otel_span:
yield otel_span

# ------------------------------------------------------------------
Expand Down Expand Up @@ -907,95 +881,3 @@ def _extract_w3c_carrier(headers: Mapping[str, str]) -> dict[str, str]:
result: dict[str, str] = {k: v for k in _W3C_HEADERS if (
v := headers.get(k)) is not None}
return result


def _parse_baggage_key(baggage: str, key: str) -> Optional[str]:
"""Parse a single key from a W3C Baggage header value.

The `W3C Baggage`_ format is a comma-separated list of
``key=value`` pairs with optional properties after a ``;``.

Example::

leaf_customer_span_id=abc123,other=val

.. _W3C Baggage: https://www.w3.org/TR/baggage/

:param baggage: The raw header value.
:type baggage: str
:param key: The baggage key to look up.
:type key: str
:return: The value for *key*, or *None* if not found.
:rtype: Optional[str]
"""
for member in baggage.split(","):
member = member.strip()
if not member:
continue
# Split on first '=' only; value may contain '='
kv_part = member.split(";", 1)[0] # strip optional properties
eq_idx = kv_part.find("=")
if eq_idx < 0:
continue
k = kv_part[:eq_idx].strip()
v = kv_part[eq_idx + 1:].strip()
if k == key:
return v
return None


def _override_parent_span_id(ctx: Any, hex_span_id: str) -> Any:
"""Create a new context with the same trace ID but a different parent span ID.

Constructs a :class:`~opentelemetry.trace.SpanContext` with the trace ID
taken from the existing context and the span ID replaced by
*hex_span_id*. The resulting context can be used as the ``context``
argument to ``start_span`` / ``start_as_current_span``.

Returns the original *ctx* unchanged if *hex_span_id* is invalid or
``opentelemetry-api`` is not installed.

Per invocation-protocol-spec.

:param ctx: An OTel context produced by ``TraceContextTextMapPropagator.extract()``.
:type ctx: Any
:param hex_span_id: 16-character lower-case hex string representing the
desired parent span ID.
:type hex_span_id: str
:return: A context with the overridden parent span ID, or the original.
:rtype: Any
"""
if not _HAS_OTEL:
return ctx

# A valid OTel span ID is exactly 16 hex characters (8 bytes).
if len(hex_span_id) != 16:
logger.warning(
"Invalid leaf_customer_span_id length in baggage: %r (expected 16 hex chars)", hex_span_id)
return ctx

try:
new_span_id = int(hex_span_id, 16)
except (ValueError, TypeError):
logger.warning(
"Invalid leaf_customer_span_id in baggage: %r", hex_span_id)
return ctx

if new_span_id == 0:
return ctx

# Grab the trace ID from the current parent span in ctx.
current_span = trace.get_current_span(ctx)
current_ctx = current_span.get_span_context()
if current_ctx is None or not current_ctx.is_valid:
return ctx

custom_span_ctx = trace.SpanContext(
trace_id=current_ctx.trace_id,
span_id=new_span_id,
is_remote=True,
trace_flags=current_ctx.trace_flags,
trace_state=current_ctx.trace_state,
)
custom_parent = trace.NonRecordingSpan(custom_span_ctx)
return trace.set_span_in_context(custom_parent, ctx)
39 changes: 0 additions & 39 deletions sdk/agentserver/azure-ai-agentserver-core/tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
resolve_appinsights_connection_string,
)
from azure.ai.agentserver.core._constants import Constants
from azure.ai.agentserver.core._tracing import _parse_baggage_key


# ------------------------------------------------------------------ #
Expand Down Expand Up @@ -184,41 +183,3 @@ def test_agent_version_default_empty(self) -> None:
env.pop(Constants.FOUNDRY_AGENT_VERSION, None)
with mock.patch.dict(os.environ, env, clear=True):
assert resolve_agent_version() == ""


# ------------------------------------------------------------------ #
# Baggage parsing (unit tests for _parse_baggage_key)
# ------------------------------------------------------------------ #


class TestParseBaggageKey:
"""Unit tests for _parse_baggage_key()."""

def test_single_key(self) -> None:
assert _parse_baggage_key("leaf_customer_span_id=abc123", "leaf_customer_span_id") == "abc123"

def test_multiple_keys(self) -> None:
baggage = "key1=val1,leaf_customer_span_id=def456,key2=val2"
assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "def456"

def test_key_not_found(self) -> None:
assert _parse_baggage_key("key1=val1,key2=val2", "leaf_customer_span_id") is None

def test_empty_baggage(self) -> None:
assert _parse_baggage_key("", "leaf_customer_span_id") is None

def test_key_with_properties(self) -> None:
baggage = "leaf_customer_span_id=abc123;prop1=x"
assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "abc123"

def test_whitespace_handling(self) -> None:
baggage = " leaf_customer_span_id = abc123 , other = val "
assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "abc123"

def test_value_with_equals(self) -> None:
baggage = "leaf_customer_span_id=abc=123"
assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "abc=123"

def test_no_equals_in_member(self) -> None:
baggage = "malformed_entry,leaf_customer_span_id=good"
assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "good"
Original file line number Diff line number Diff line change
Expand Up @@ -428,42 +428,6 @@ async def test_namespaced_invocation_id_attribute():
assert attrs.get("azure.ai.agentserver.invocations.invocation_id") == inv_id


# ---------------------------------------------------------------------------
# Baggage tests
# ---------------------------------------------------------------------------

@pytest.mark.asyncio
async def test_baggage_leaf_customer_span_id():
"""Baggage leaf_customer_span_id overrides parent span ID."""
server = _make_tracing_server()
transport = ASGITransport(app=server.app)

trace_id_hex = uuid.uuid4().hex
original_span_id = uuid.uuid4().hex[:16]
leaf_span_id = uuid.uuid4().hex[:16]
traceparent = f"00-{trace_id_hex}-{original_span_id}-01"
baggage = f"leaf_customer_span_id={leaf_span_id}"

async with AsyncClient(transport=transport, base_url="http://testserver") as client:
await client.post(
"/invocations",
content=b"test",
headers={
"traceparent": traceparent,
"baggage": baggage,
},
)

spans = _get_spans()
invoke_spans = [s for s in spans if "invoke_agent" in s.name]
assert len(invoke_spans) >= 1
span = invoke_spans[0]
# The parent span ID should be overridden to leaf_span_id
if span.parent is not None:
actual_parent_span_id = format(span.parent.span_id, "016x")
assert actual_parent_span_id == leaf_span_id


# ---------------------------------------------------------------------------
# Agent name/version in span names
# ---------------------------------------------------------------------------
Expand Down