Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,12 @@ def __init__(
)
),
resource=resource,
metric_readers=metric_readers,
views=views,
)
self._metric_readers = metric_readers
self._measurement_consumer = SynchronousMeasurementConsumer(
sdk_config=self._sdk_config
sdk_config=self._sdk_config,
metric_readers=metric_readers,
)
disabled = environ.get(OTEL_SDK_DISABLED, "")
self._disabled = disabled.lower().strip() == "true"
Expand All @@ -456,7 +457,7 @@ def __init__(
self._shutdown_once = Once()
self._shutdown = False

for metric_reader in self._sdk_config.metric_readers:
for metric_reader in self._metric_readers:
with self._all_metric_readers_lock:
if metric_reader in self._all_metric_readers:
# pylint: disable=broad-exception-raised
Expand All @@ -476,7 +477,7 @@ def force_flush(self, timeout_millis: float = 10_000) -> bool:

metric_reader_error = {}

for metric_reader in self._sdk_config.metric_readers:
for metric_reader in self._metric_readers:
current_ts = time_ns()
try:
if current_ts >= deadline_ns:
Expand Down Expand Up @@ -521,7 +522,7 @@ def _shutdown():

metric_reader_error = {}

for metric_reader in self._sdk_config.metric_readers:
for metric_reader in self._metric_readers:
current_ts = time_ns()
try:
if current_ts >= deadline_ns:
Expand Down Expand Up @@ -588,3 +589,33 @@ def get_meter(
self._measurement_consumer,
)
return self._meters[info]

def add_metric_reader(
    self, metric_reader: "opentelemetry.sdk.metrics.export.MetricReader"
) -> None:
    """Register an additional metric reader on this provider.

    The reader gets its own storage in the measurement consumer, is wired
    to the consumer's ``collect`` callback, and is recorded as registered.

    Args:
        metric_reader: the reader to register.

    Raises:
        ValueError: if ``metric_reader`` is already registered.
    """
    # NOTE(review): this is a new public method; a reviewer asked whether
    # adding public API here is a concern -- confirm before release.
    with self._all_metric_readers_lock:
        if metric_reader in self._all_metric_readers:
            raise ValueError(
                f"MetricReader {metric_reader} has been registered already!"
            )
        self._measurement_consumer.add_metric_reader(metric_reader)
        # pylint: disable-next=protected-access
        metric_reader._set_collect_callback(
            self._measurement_consumer.collect
        )
        self._all_metric_readers.add(metric_reader)
        # force_flush() and shutdown() walk self._metric_readers, so the
        # new reader has to be listed there too or it would never be
        # flushed or shut down. Rebind to a fresh tuple instead of
        # mutating in place: the constructor argument may be immutable,
        # and a loop already iterating the old sequence is left alone.
        self._metric_readers = (*self._metric_readers, metric_reader)

def remove_metric_reader(
    self,
    metric_reader: "opentelemetry.sdk.metrics.export.MetricReader",
) -> None:
    """Unregister a metric reader from this provider and shut it down.

    The reader's storage is dropped from the measurement consumer, its
    collect callback is cleared, and ``shutdown()`` is called on it.

    Args:
        metric_reader: the reader to unregister.

    Raises:
        ValueError: if ``metric_reader`` is not currently registered.
    """
    with self._all_metric_readers_lock:
        if metric_reader not in self._all_metric_readers:
            raise ValueError(
                f"MetricReader {metric_reader} has not been registered!"
            )
        self._measurement_consumer.remove_metric_reader(metric_reader)
        # NOTE(review): the callback is cleared before shutdown(); if a
        # reader performs a final collection while shutting down it will
        # no longer reach the consumer -- confirm this order is intended.
        # pylint: disable-next=protected-access
        metric_reader._set_collect_callback(None)
        # Finish the provider-side bookkeeping before calling shutdown():
        # the reader is already detached from the consumer, so a failing
        # shutdown() must not leave it marked as registered.
        self._all_metric_readers.remove(metric_reader)
        # force_flush() and shutdown() walk self._metric_readers; drop the
        # reader there as well so it is not flushed or shut down again.
        self._metric_readers = tuple(
            reader
            for reader in self._metric_readers
            if reader is not metric_reader
        )
        metric_reader.shutdown()
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@

# pylint: disable=unused-import

import weakref
from abc import ABC, abstractmethod
from threading import Lock
from time import time_ns
from typing import List, Mapping, Optional
from typing import Iterable, List, Mapping, Optional

# This kind of import is needed to avoid Sphinx errors.
import opentelemetry.sdk.metrics
import opentelemetry.sdk.metrics._internal.instrument
import opentelemetry.sdk.metrics._internal.sdk_configuration
from opentelemetry.metrics._internal.instrument import CallbackOptions
from opentelemetry.sdk.metrics._internal.exceptions import MetricsTimeoutError
from opentelemetry.sdk.metrics._internal.measurement import Measurement
Expand Down Expand Up @@ -59,10 +59,10 @@ class SynchronousMeasurementConsumer(MeasurementConsumer):
def __init__(
self,
sdk_config: "opentelemetry.sdk.metrics._internal.SdkConfiguration",
metric_readers: Iterable["opentelemetry.sdk.metrics.MetricReader"],
) -> None:
self._lock = Lock()
self._sdk_config = sdk_config
# should never be mutated
self._reader_storages: Mapping[
"opentelemetry.sdk.metrics.MetricReader", MetricReaderStorage
] = {
Expand All @@ -71,7 +71,7 @@ def __init__(
reader._instrument_class_temporality,
reader._instrument_class_aggregation,
)
for reader in sdk_config.metric_readers
for reader in metric_readers
}
self._async_instruments: List[
"opentelemetry.sdk.metrics._internal.instrument._Asynchronous"
Expand All @@ -86,7 +86,9 @@ def consume_measurement(self, measurement: Measurement) -> None:
measurement.context,
)
)
for reader_storage in self._reader_storages.values():
with self._lock:
reader_storages = weakref.WeakSet(self._reader_storages.values())
for reader_storage in reader_storages:
reader_storage.consume_measurement(
measurement, should_sample_exemplar
)
Expand Down Expand Up @@ -143,3 +145,23 @@ def collect(
result = self._reader_storages[metric_reader].collect()

return result

def add_metric_reader(
    self, metric_reader: "opentelemetry.sdk.metrics.MetricReader"
) -> None:
    """Create a storage for ``metric_reader`` and start tracking it.

    The storage is built from this consumer's SDK configuration and the
    reader's temporality and aggregation preferences, then stored under
    the reader while the consumer lock is held.
    """
    with self._lock:
        # pylint: disable-next=protected-access
        temporality = metric_reader._instrument_class_temporality
        # pylint: disable-next=protected-access
        aggregation = metric_reader._instrument_class_aggregation
        storage = MetricReaderStorage(
            self._sdk_config, temporality, aggregation
        )
        self._reader_storages[metric_reader] = storage

def remove_metric_reader(
    self, metric_reader: "opentelemetry.sdk.metrics.MetricReader"
) -> None:
    """Stop tracking ``metric_reader`` and discard its storage.

    Raises:
        KeyError: if no storage is held for ``metric_reader``.
    """
    with self._lock:
        del self._reader_storages[metric_reader]
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,4 @@
class SdkConfiguration:
exemplar_filter: "opentelemetry.sdk.metrics.ExemplarFilter"
resource: "opentelemetry.sdk.resources.Resource"
metric_readers: Sequence["opentelemetry.sdk.metrics.MetricReader"]
views: Sequence["opentelemetry.sdk.metrics.View"]
114 changes: 101 additions & 13 deletions opentelemetry-sdk/tests/metrics/test_measurement_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# pylint: disable=invalid-name,no-self-use

from threading import Event, Thread
from time import sleep
from unittest import TestCase
from unittest.mock import MagicMock, Mock, patch
Expand All @@ -34,7 +35,8 @@
class TestSynchronousMeasurementConsumer(TestCase):
def test_parent(self, _):
self.assertIsInstance(
SynchronousMeasurementConsumer(MagicMock()), MeasurementConsumer
SynchronousMeasurementConsumer(MagicMock(), metric_readers=()),
MeasurementConsumer,
)

def test_creates_metric_reader_storages(self, MockMetricReaderStorage):
Expand All @@ -44,9 +46,9 @@ def test_creates_metric_reader_storages(self, MockMetricReaderStorage):
SdkConfiguration(
exemplar_filter=Mock(),
resource=Mock(),
metric_readers=reader_mocks,
views=Mock(),
)
),
metric_readers=reader_mocks,
)
self.assertEqual(len(MockMetricReaderStorage.mock_calls), 5)

Expand All @@ -61,9 +63,9 @@ def test_measurements_passed_to_each_reader_storage(
SdkConfiguration(
exemplar_filter=Mock(should_sample=Mock(return_value=False)),
resource=Mock(),
metric_readers=reader_mocks,
views=Mock(),
)
),
metric_readers=reader_mocks,
)
measurement_mock = Mock()
consumer.consume_measurement(measurement_mock)
Expand All @@ -83,9 +85,9 @@ def test_collect_passed_to_reader_stage(self, MockMetricReaderStorage):
SdkConfiguration(
exemplar_filter=Mock(),
resource=Mock(),
metric_readers=reader_mocks,
views=Mock(),
)
),
metric_readers=reader_mocks,
)
for r_mock, rs_mock in zip(reader_mocks, reader_storage_mocks):
rs_mock.collect.assert_not_called()
Expand All @@ -102,9 +104,9 @@ def test_collect_calls_async_instruments(self, MockMetricReaderStorage):
SdkConfiguration(
exemplar_filter=Mock(should_sample=Mock(return_value=False)),
resource=Mock(),
metric_readers=[reader_mock],
views=Mock(),
)
),
metric_readers=[reader_mock],
)
async_instrument_mocks = [MagicMock() for _ in range(5)]
for i_mock in async_instrument_mocks:
Expand Down Expand Up @@ -133,9 +135,9 @@ def test_collect_timeout(self, MockMetricReaderStorage):
SdkConfiguration(
exemplar_filter=Mock(),
resource=Mock(),
metric_readers=[reader_mock],
views=Mock(),
)
),
metric_readers=[reader_mock],
)

def sleep_1(*args, **kwargs):
Expand Down Expand Up @@ -166,9 +168,9 @@ def test_collect_deadline(
SdkConfiguration(
exemplar_filter=Mock(),
resource=Mock(),
metric_readers=[reader_mock],
views=Mock(),
)
),
metric_readers=[reader_mock],
)

def sleep_1(*args, **kwargs):
Expand All @@ -192,3 +194,89 @@ def sleep_1(*args, **kwargs):
callback_options_time_call,
10000,
)


class TestSynchronousMeasurementConsumerConcurrency(TestCase):
    """Concurrency checks for reader registration on the consumer."""

    def test_concurrent_changes_to_metric_readers(self):
        """Readers can be added and removed while a measurement is consumed.

        Phase 1 proves the harness works: clearing ``_reader_storages``
        directly while it is being iterated raises ``RuntimeError``.
        Phase 2 repeats the scenario through ``add_metric_reader`` and
        ``remove_metric_reader`` and expects no error.
        """
        timeout = 1
        failure = None
        iteration_started = Event()
        # Nothing in this test sets this event, so every wait on it runs
        # for the full timeout. That holds the iteration open long enough
        # for the other thread to attempt its mutation.
        mutation_done = Event()
        iteration_timeout_error = "Timed out waiting for iteration to start"
        mutation_timeout_error = "Timed out waiting for mutation to be done"

        consumer = SynchronousMeasurementConsumer(
            SdkConfiguration(
                exemplar_filter=MagicMock(),
                resource=MagicMock(),
                views=MagicMock(),
            ),
            metric_readers=[MagicMock()],
        )

        def _hooked_iter(iterable):
            """Signal that iteration began, pause, then iterate as usual."""
            nonlocal failure

            iterable = iter(iterable)
            iteration_started.set()
            if not mutation_done.wait(timeout):
                failure = mutation_timeout_error
            # Delegate directly: an empty mapping must yield nothing
            # rather than a spurious None element.
            yield from iterable

        class HookedDict(dict):
            """dict whose iteration entry points go through _hooked_iter."""

            def __iter__(self):
                return _hooked_iter(super().__iter__())

            def keys(self):
                return _hooked_iter(super().keys())

            def values(self):
                return _hooked_iter(super().values())

            def items(self):
                return _hooked_iter(super().items())

        with patch.object(
            # pylint: disable-next=protected-access
            consumer, "_reader_storages", HookedDict(consumer._reader_storages)
        ):

            def mutate():
                """Directly mutate _reader_storages after iteration starts"""
                nonlocal failure
                if not iteration_started.wait(timeout):
                    failure = iteration_timeout_error
                # pylint: disable-next=protected-access
                consumer._reader_storages.clear()

            # Verify that test setup works (direct mutation with no synchronization fails)
            with self.assertRaises(RuntimeError) as cm:
                t = Thread(target=mutate)
                t.start()
                try:
                    consumer.consume_measurement(MagicMock())
                finally:
                    t.join()
            self.assertEqual(
                "dictionary changed size during iteration", str(cm.exception)
            )

            # Reset the shared state so phase 2 really waits for its own
            # iteration and the final assertion reflects phase 2 only.
            failure = None
            iteration_started.clear()

            def add_and_remove_readers():
                """Modifies _reader_storages after iteration starts"""
                nonlocal failure
                if not iteration_started.wait(timeout):
                    failure = iteration_timeout_error
                reader = MagicMock()
                consumer.add_metric_reader(reader)
                consumer.remove_metric_reader(reader)

            # Verify the API calls do not attempt concurrent modification of reader storages
            t = Thread(target=add_and_remove_readers)
            t.start()
            try:
                consumer.add_metric_reader(MagicMock())
                consumer.consume_measurement(MagicMock())
            finally:
                t.join()
            # The hooked iteration ran (its wait timed out as designed) and
            # the worker thread did not time out waiting for it to start.
            self.assertEqual(mutation_timeout_error, failure)
Loading
Loading