Skip to content

Commit d0abacf

Browse files
committed
Wrap strean response in proxy that manage span lifetime
1 parent 4eb24c2 commit d0abacf

4 files changed

Lines changed: 131 additions & 48 deletions

File tree

instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_client.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121

2222
import logging
2323
from collections import OrderedDict
24-
from functools import partial
2524
from typing import Callable, MutableMapping
2625

2726
import grpc
27+
import wrapt
28+
2829
from opentelemetry import context, trace
2930
from opentelemetry.instrumentation.grpc import grpcext
3031
from opentelemetry.instrumentation.grpc._utilities import RpcInfo
@@ -72,11 +73,59 @@ def _safe_invoke(function: Callable, *args):
7273
"Error when invoking function '%s'", function_name, exc_info=ex
7374
)
7475

76+
77+
class OpenTelemetryStreamWrapper(wrapt.ObjectProxy):
78+
def __init__(self, wrapped, span: trace.Span):
79+
super().__init__(wrapped)
80+
self._self_span = span
81+
82+
def _end_span_if_not_already_ended(self, status_code=None, status=None):
83+
if self._self_span.end_time is None:
84+
self._self_span.end()
85+
if status_code is not None:
86+
self._self_span.set_attribute(
87+
SpanAttributes.RPC_GRPC_STATUS_CODE, status_code
88+
)
89+
if status is not None:
90+
self._self_span.set_status(status)
91+
92+
def __del__(self):
93+
self._end_span_if_not_already_ended()
94+
self.__wrapped__.__del__()
95+
96+
def __iter__(self):
97+
return self
98+
99+
def cancel(self):
100+
self._end_span_if_not_already_ended(
101+
status_code=grpc.StatusCode.CANCELLED.value[0]
102+
)
103+
return self.__wrapped__.cancel()
104+
105+
def __next__(self):
106+
return self._next()
107+
108+
def next(self):
109+
return self._next()
110+
111+
def _next(self):
112+
try:
113+
return self.__wrapped__._next()
114+
except StopIteration:
115+
self._end_span_if_not_already_ended()
116+
raise
117+
except grpc.RpcError as err:
118+
self._end_span_if_not_already_ended(
119+
err.code().value[0], Status(StatusCode.ERROR)
120+
)
121+
raise err
122+
123+
75124
class OpenTelemetryClientInterceptor(
76125
grpcext.UnaryClientInterceptor, grpcext.StreamClientInterceptor
77126
):
78127
def __init__(
79-
self, tracer, filter_=None, request_hook=None, response_hook=None
128+
self, tracer, filter_=None, request_hook=None, response_hook=None
80129
):
81130
self._tracer = tracer
82131
self._filter = filter_
@@ -130,10 +179,10 @@ def _intercept(self, request, metadata, client_info, invoker):
130179
else:
131180
mutable_metadata = OrderedDict(metadata)
132181
with self._start_span(
133-
client_info.full_method,
134-
end_on_exit=False,
135-
record_exception=False,
136-
set_status_on_exception=False,
182+
client_info.full_method,
183+
end_on_exit=False,
184+
record_exception=False,
185+
set_status_on_exception=False,
137186
) as span:
138187
result = None
139188
try:
@@ -187,16 +236,15 @@ def intercept_unary(self, request, metadata, client_info, invoker):
187236
# the span across the generated responses and detect any errors, we wrap
188237
# the result in a new generator that yields the response values.
189238
def _intercept_server_stream(
190-
self, request_or_iterator, metadata, client_info, invoker
239+
self, request_or_iterator, metadata, client_info, invoker
191240
):
192241
if not metadata:
193242
mutable_metadata = OrderedDict()
194243
else:
195244
mutable_metadata = OrderedDict(metadata)
196245

197246
with self._start_span(
198-
client_info.full_method,
199-
end_on_exit=False
247+
client_info.full_method, end_on_exit=False
200248
) as span:
201249
inject(mutable_metadata, setter=_carrier_setter)
202250
metadata = tuple(mutable_metadata.items())
@@ -211,27 +259,10 @@ def _intercept_server_stream(
211259

212260
stream = invoker(request_or_iterator, metadata)
213261

214-
def done_callback(future, span_):
215-
try:
216-
future.result()
217-
except grpc.FutureCancelledError:
218-
span_.set_status(Status(StatusCode.OK))
219-
span_.set_attribute(
220-
SpanAttributes.RPC_GRPC_STATUS_CODE, grpc.StatusCode.CANCELLED.value[0]
221-
)
222-
except grpc.RpcError as err:
223-
span_.set_status(Status(StatusCode.ERROR))
224-
span_.set_attribute(
225-
SpanAttributes.RPC_GRPC_STATUS_CODE, err.code().value[0]
226-
)
227-
finally:
228-
span_.end()
229-
230-
stream.add_done_callback(partial(done_callback, span_=span))
231-
return stream
262+
return OpenTelemetryStreamWrapper(stream, span)
232263

233264
def intercept_stream(
234-
self, request_or_iterator, metadata, client_info, invoker
265+
self, request_or_iterator, metadata, client_info, invoker
235266
):
236267
if context.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
237268
return invoker(request_or_iterator, metadata)

instrumentation/opentelemetry-instrumentation-grpc/tests/_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def server_streaming_method(stub, error=False, serialize=True):
5353
return response_iterator
5454

5555

56-
def bidirectional_streaming_method(stub, error=False):
56+
def bidirectional_streaming_method(stub, error=False, serialize=True):
5757
def request_messages():
5858
for _ in range(5):
5959
request = Request(
@@ -62,5 +62,6 @@ def request_messages():
6262
yield request
6363

6464
response_iterator = stub.BidirectionalStreamingMethod(request_messages())
65-
66-
list(response_iterator)
65+
if serialize:
66+
list(response_iterator)
67+
return response_iterator

instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor.py

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# pylint:disable=cyclic-import
15-
from time import sleep
1615

1716
import grpc
1817
from tests.protobuf import ( # pylint: disable=no-name-in-module
@@ -56,32 +55,32 @@ def __init__(self):
5655
pass
5756

5857
def intercept_unary_unary(
59-
self, continuation, client_call_details, request
58+
self, continuation, client_call_details, request
6059
):
6160
return self._intercept_call(continuation, client_call_details, request)
6261

6362
def intercept_unary_stream(
64-
self, continuation, client_call_details, request
63+
self, continuation, client_call_details, request
6564
):
6665
return self._intercept_call(continuation, client_call_details, request)
6766

6867
def intercept_stream_unary(
69-
self, continuation, client_call_details, request_iterator
68+
self, continuation, client_call_details, request_iterator
7069
):
7170
return self._intercept_call(
7271
continuation, client_call_details, request_iterator
7372
)
7473

7574
def intercept_stream_stream(
76-
self, continuation, client_call_details, request_iterator
75+
self, continuation, client_call_details, request_iterator
7776
):
7877
return self._intercept_call(
7978
continuation, client_call_details, request_iterator
8079
)
8180

8281
@staticmethod
8382
def _intercept_call(
84-
continuation, client_call_details, request_or_iterator
83+
continuation, client_call_details, request_or_iterator
8584
):
8685
return continuation(client_call_details, request_or_iterator)
8786

@@ -94,9 +93,7 @@ def setUp(self):
9493
self.server.start()
9594
# use a user defined interceptor along with the opentelemetry client interceptor
9695
interceptors = [Interceptor()]
97-
self.channel = grpc.insecure_channel("localhost:25565", options=[
98-
# (grpc.experimental.ChannelOptions.SingleThreadedUnaryStream, 1)
99-
])
96+
self.channel = grpc.insecure_channel("localhost:25565")
10097
self.channel = grpc.intercept_channel(self.channel, *interceptors)
10198
self._stub = test_server_pb2_grpc.GRPCTestServerStub(self.channel)
10299

@@ -173,14 +170,11 @@ def test_unary_stream(self):
173170
)
174171

175172
def test_unary_stream_can_be_cancel(self):
176-
responses = server_streaming_method(self._stub)
173+
responses = server_streaming_method(self._stub, serialize=False)
177174
for i, _ in enumerate(responses):
178175
if i == 1:
179176
responses.cancel()
180177
break
181-
sleep(10)
182-
self.server.stop(None)
183-
self.channel.close()
184178
spans = self.memory_exporter.get_finished_spans()
185179
self.assertEqual(len(spans), 1)
186180
span = spans[0]
@@ -205,6 +199,33 @@ def test_unary_stream_can_be_cancel(self):
205199
},
206200
)
207201

202+
def test_finished_stream_cancel_does_not_change_status_of_span(self):
203+
responses = server_streaming_method(self._stub, serialize=True)
204+
responses.cancel()
205+
spans = self.memory_exporter.get_finished_spans()
206+
self.assertEqual(len(spans), 1)
207+
span = spans[0]
208+
209+
self.assertEqual(span.name, "/GRPCTestServer/ServerStreamingMethod")
210+
self.assertIs(span.kind, trace.SpanKind.CLIENT)
211+
212+
# Check version and name in span's instrumentation info
213+
self.assertEqualSpanInstrumentationInfo(
214+
span, opentelemetry.instrumentation.grpc
215+
)
216+
217+
self.assertSpanHasAttributes(
218+
span,
219+
{
220+
SpanAttributes.RPC_METHOD: "ServerStreamingMethod",
221+
SpanAttributes.RPC_SERVICE: "GRPCTestServer",
222+
SpanAttributes.RPC_SYSTEM: "grpc",
223+
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[
224+
0
225+
],
226+
},
227+
)
228+
208229
def test_stream_unary(self):
209230
client_streaming_method(self._stub)
210231
spans = self.memory_exporter.get_finished_spans()
@@ -259,6 +280,38 @@ def test_stream_stream(self):
259280
},
260281
)
261282

283+
def test_stream_stream_can_be_cancel(self):
284+
responses = bidirectional_streaming_method(self._stub, serialize=False)
285+
for i, _ in enumerate(responses):
286+
if i == 1:
287+
responses.cancel()
288+
break
289+
spans = self.memory_exporter.get_finished_spans()
290+
self.assertEqual(len(spans), 1)
291+
span = spans[0]
292+
293+
self.assertEqual(
294+
span.name, "/GRPCTestServer/BidirectionalStreamingMethod"
295+
)
296+
self.assertIs(span.kind, trace.SpanKind.CLIENT)
297+
298+
# Check version and name in span's instrumentation info
299+
self.assertEqualSpanInstrumentationInfo(
300+
span, opentelemetry.instrumentation.grpc
301+
)
302+
303+
self.assertSpanHasAttributes(
304+
span,
305+
{
306+
SpanAttributes.RPC_METHOD: "BidirectionalStreamingMethod",
307+
SpanAttributes.RPC_SERVICE: "GRPCTestServer",
308+
SpanAttributes.RPC_SYSTEM: "grpc",
309+
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.CANCELLED.value[
310+
0
311+
],
312+
},
313+
)
314+
262315
def test_error_simple(self):
263316
with self.assertRaises(grpc.RpcError):
264317
simple_method(self._stub, error=True)
@@ -308,7 +361,7 @@ def test_error_stream_stream(self):
308361
)
309362

310363
def test_client_interceptor_trace_context_propagation(
311-
self,
364+
self,
312365
): # pylint: disable=no-self-use
313366
"""ensure that client interceptor correctly inject trace context into all outgoing requests."""
314367
previous_propagator = get_global_textmap()

instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor_filter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,7 @@ def setUp(self):
9898
self.server.start()
9999
# use a user defined interceptor along with the opentelemetry client interceptor
100100
interceptors = [Interceptor()]
101-
self.channel = grpc.insecure_channel("localhost:25565",options=[
102-
(grpc.experimental.ChannelOptions.SingleThreadedUnaryStream, 1)
103-
])
101+
self.channel = grpc.insecure_channel("localhost:25565")
104102
self.channel = grpc.intercept_channel(self.channel, *interceptors)
105103
self._stub = test_server_pb2_grpc.GRPCTestServerStub(self.channel)
106104

0 commit comments

Comments
 (0)