1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import Mock, patch
from opentelemetry import context, trace
from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
from opentelemetry.test.globals_test import TraceGlobalsTest
from opentelemetry.trace.status import Status, StatusCode
class SpanTest(trace.NonRecordingSpan):
    """Non-recording span double that remembers what was done to it.

    Each hook stores its argument on the instance so tests can assert on it
    afterwards; the class-level values below are the "nothing happened yet"
    defaults seen before any hook runs.
    """

    # Becomes True once end() is called.
    has_ended = False
    # Last exception passed to record_exception(), or None.
    recorded_exception = None
    # Last status passed to set_status(); starts out UNSET.
    recorded_status = Status(status_code=StatusCode.UNSET)

    def set_status(self, status, description=None):
        """Store ``status``, wrapping a bare status code in a ``Status``."""
        self.recorded_status = (
            status
            if isinstance(status, Status)
            else Status(status_code=status, description=description)
        )

    def end(self, end_time=None):
        """Flag the span as ended; ``end_time`` is ignored."""
        self.has_ended = True

    def is_recording(self):
        """Report the span as recording until end() has been called."""
        return not self.has_ended

    def record_exception(
        self, exception, attributes=None, timestamp=None, escaped=False
    ):
        """Remember ``exception``; all other arguments are ignored."""
        self.recorded_exception = exception
class TestGlobals(TraceGlobalsTest, unittest.TestCase):
    """Tests for the module-level ``trace.get_tracer`` convenience function."""

    @staticmethod
    @patch("opentelemetry.trace._TRACER_PROVIDER")
    def test_get_tracer(mock_tracer_provider):  # type: ignore
        """trace.get_tracer should proxy to the global tracer provider."""
        expected_args = ("foo", "var", None, None)

        # With no provider argument, the (patched) global provider is asked.
        trace.get_tracer("foo", "var")
        mock_tracer_provider.get_tracer.assert_called_with(*expected_args)

        # An explicitly supplied provider receives the call.
        explicit_provider = Mock()
        trace.get_tracer("foo", "var", explicit_provider)
        explicit_provider.get_tracer.assert_called_with(*expected_args)
class TestGlobalsConcurrency(TraceGlobalsTest, ConcurrencyTestBase):
    """Concurrency tests for installing the global tracer provider."""

    @patch("opentelemetry.trace.logger")
    def test_set_tracer_provider_many_threads(self, mock_logger) -> None:  # type: ignore
        # Swap in a MockFunc so the number of warnings can be asserted below.
        mock_logger.warning = MockFunc()
        thread_count = 100

        def race_to_set_provider() -> Mock:
            # Obtain a tracer from a fresh proxy provider.
            proxy = trace.ProxyTracerProvider().get_tracer("foo")
            # Each thread attempts to install its own provider...
            candidate = Mock(get_tracer=MockFunc())
            trace.set_tracer_provider(candidate)
            # ...then starts a span, which the proxy forwards to the
            # provider that actually became global.
            proxy.start_span("foo")
            return candidate

        candidates = self.run_with_many_threads(
            race_to_set_provider,
            num_threads=thread_count,
        )

        # Exactly one candidate should have stuck, and every proxied
        # start_span() call should have gone through it.
        used = [c for c in candidates if c.get_tracer.call_count > 0]
        self.assertEqual(len(used), 1)
        self.assertEqual(used[0].get_tracer.call_count, thread_count)

        # Every set attempt except the successful one should have warned.
        self.assertEqual(mock_logger.warning.call_count, thread_count - 1)
class TestTracer(unittest.TestCase):
    """Tests for looking up the active span via ``trace.get_current_span``."""

    def setUp(self):
        # NOTE(review): self.tracer is not used by test_get_current_span;
        # kept so any subclass or future test relying on it keeps working.
        self.tracer = trace.NoOpTracer()

    def test_get_current_span(self):
        """A span attached to the context is returned by get_current_span,
        and INVALID_SPAN is returned again once the context is detached.
        """
        self.assertEqual(trace.get_current_span(), trace.INVALID_SPAN)
        # A freshly constructed span is a distinct object from INVALID_SPAN,
        # so the identity check below proves the attached span is returned.
        span = trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)
        ctx = trace.set_span_in_context(span)
        token = context.attach(ctx)
        try:
            self.assertIs(trace.get_current_span(), span)
        finally:
            # Always restore the previous context so other tests are unaffected.
            context.detach(token)
        self.assertEqual(trace.get_current_span(), trace.INVALID_SPAN)
class TestUseTracer(unittest.TestCase):
    """Tests for the ``trace.use_span`` context manager."""

    def test_use_span(self):
        self.assertEqual(trace.get_current_span(), trace.INVALID_SPAN)
        active = trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)
        # Inside the block the span is current; afterwards the default returns.
        with trace.use_span(active):
            self.assertIs(trace.get_current_span(), active)
        self.assertEqual(trace.get_current_span(), trace.INVALID_SPAN)

    def test_use_span_end_on_exit(self):
        span = SpanTest(trace.INVALID_SPAN_CONTEXT)

        # By default, leaving the block does not end the span.
        with trace.use_span(span):
            pass
        self.assertFalse(span.has_ended)

        # With end_on_exit=True, leaving the block ends it.
        with trace.use_span(span, end_on_exit=True):
            pass
        self.assertTrue(span.has_ended)

    def test_use_span_exception(self):
        class TestUseSpanException(Exception):
            pass

        span = SpanTest(trace.INVALID_SPAN_CONTEXT)
        error = TestUseSpanException("test exception")
        # The exception must propagate and also be recorded on the span.
        with self.assertRaises(TestUseSpanException), trace.use_span(span):
            raise error
        self.assertEqual(span.recorded_exception, error)

    def test_use_span_set_status(self):
        class TestUseSpanException(Exception):
            pass

        span = SpanTest(trace.INVALID_SPAN_CONTEXT)
        with self.assertRaises(TestUseSpanException), trace.use_span(span):
            raise TestUseSpanException("test error")

        # Expected description format: "<exception type name>: <message>".
        status = span.recorded_status
        self.assertEqual(status.status_code, StatusCode.ERROR)
        self.assertEqual(
            status.description,
            "TestUseSpanException: test error",
        )

    def test_use_span_base_exceptions(self):
        # BaseException itself and its non-Exception subclasses must
        # propagate without touching the span's status or recorded exception.
        non_exception_types = (
            BaseException,
            GeneratorExit,
            SystemExit,
            KeyboardInterrupt,
        )
        for exc_cls in non_exception_types:
            with self.subTest(exc=exc_cls.__name__):
                span = SpanTest(trace.INVALID_SPAN_CONTEXT)
                with self.assertRaises(exc_cls), trace.use_span(span):
                    raise exc_cls()
                status = span.recorded_status
                self.assertEqual(status.status_code, StatusCode.UNSET)
                self.assertIsNone(status.description)
                self.assertIsNone(span.recorded_exception)
|