Skip to content

Commit d08b82b

Browse files
committed
Refactor test cases
1 parent 76548b2 commit d08b82b

4 files changed

Lines changed: 184 additions & 40 deletions

File tree

ext/opentelemetry-ext-flask/src/opentelemetry/ext/flask/__init__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def hello():
4747
---
4848
"""
4949

50-
import logging
50+
from logging import getLogger
5151

5252
import flask
5353

@@ -57,7 +57,7 @@ def hello():
5757
from opentelemetry.ext.flask.version import __version__
5858
from opentelemetry.util import time_ns
5959

60-
logger = logging.getLogger(__name__)
60+
logger = getLogger(__name__)
6161

6262
_ENVIRON_STARTTIME_KEY = "opentelemetry-flask.starttime_key"
6363
_ENVIRON_SPAN_KEY = "opentelemetry-flask.span_key"
@@ -96,8 +96,6 @@ def _start_response(status, response_headers, *args, **kwargs):
9696

9797

9898
def _before_request():
99-
from ipdb import set_trace
100-
set_trace
10199
environ = flask.request.environ
102100
span_name = (
103101
flask.request.endpoint
@@ -168,10 +166,11 @@ def _instrument(self, **kwargs):
168166
app = kwargs.get("app")
169167

170168
if app is None:
169+
self._original_flask = flask.Flask
171170
flask.Flask = _InstrumentedFlask
172171

173172
else:
174-
173+
app._original_wsgi_app = app.wsgi_app
175174
app.wsgi_app = _rewrapped_app(app.wsgi_app)
176175

177176
app.before_request(_before_request)
@@ -181,7 +180,13 @@ def _uninstrument(self, **kwargs):
181180
app = kwargs.get("app")
182181

183182
if app is None:
184-
pass
183+
flask.Flask = self._original_flask
185184

186185
else:
187-
pass
186+
app.wsgi_app = app._original_wsgi_app
187+
188+
# FIXME add support for other Flask blueprints that are not None
189+
app.before_request_funcs[None].remove(_before_request)
190+
app.teardown_request_funcs[None].remove(_teardown_request)
191+
192+
del app._original_wsgi_app

ext/opentelemetry-ext-flask/tests/test_flask_instrumentation.py renamed to ext/opentelemetry-ext-flask/tests/base_test.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import unittest
16-
17-
from flask import Flask, request
18-
from werkzeug.test import Client
19-
from werkzeug.wrappers import BaseResponse
15+
from flask import request
2016

2117
from opentelemetry import trace as trace_api
22-
from opentelemetry.ext.flask import FlaskInstrumentor
23-
from opentelemetry.test.wsgitestutil import WsgiTestBase
2418

2519

2620
def expected_attributes(override_attributes):
@@ -41,27 +35,7 @@ def expected_attributes(override_attributes):
4135
return default_attributes
4236

4337

44-
class TestFlaskInstrumentation(WsgiTestBase):
45-
def setUp(self):
46-
# No instrumentation code is here because it is present in the
47-
# conftest.py file next to this file.
48-
super().setUp()
49-
50-
self.app = Flask(__name__)
51-
52-
FlaskInstrumentor().instrument(app=self.app)
53-
54-
def hello_endpoint(helloid):
55-
if helloid == 500:
56-
raise ValueError(":-(")
57-
return "Hello: " + str(helloid)
58-
59-
self.app.route("/hello/<int:helloid>")(hello_endpoint)
60-
61-
self.client = Client(self.app, BaseResponse)
62-
63-
def tearDown(self):
64-
FlaskInstrumentor().uninstrument(app=self.app)
38+
class InstrumentationTest:
6539

6640
def test_only_strings_in_environ(self):
6741
"""
@@ -82,7 +56,7 @@ def assert_environ():
8256
self.client.get("/assert_environ")
8357
self.assertEqual(nonstring_keys, set())
8458

85-
def test_simple(self):
59+
def test_simple_uninstrument(self):
8660
expected_attrs = expected_attributes(
8761
{"http.target": "/hello/123", "http.route": "/hello/<int:helloid>"}
8862
)
@@ -131,7 +105,3 @@ def test_internal_error(self):
131105
self.assertEqual(span_list[0].name, "hello_endpoint")
132106
self.assertEqual(span_list[0].kind, trace_api.SpanKind.SERVER)
133107
self.assertEqual(span_list[0].attributes, expected_attrs)
134-
135-
136-
if __name__ == "__main__":
137-
unittest.main()
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from logging import NOTSET, WARNING, disable
16+
from unittest import main
17+
18+
# This is used instead of from flask import Flask, request because if not then
19+
# FlaskInstrumentor().instrument() would need to be called before importing
20+
# Flask. This is just an intrinsic limitation due the fact that we are testing
21+
# the instrumentor in a way that mimics how it would be called with the
22+
# opentelemetry-auto-instrumentation command. This does not mean that the
23+
# instrumentor should be used in this way in end user applications. For those
24+
# cases, FlaskInstrumentor.instrument(app=app) should be used.
25+
import flask
26+
from werkzeug.test import Client
27+
from werkzeug.wrappers import BaseResponse
28+
29+
from opentelemetry import trace as trace_api
30+
from opentelemetry.ext.flask import FlaskInstrumentor
31+
from opentelemetry.test.wsgitestutil import WsgiTestBase
32+
33+
from .base_test import InstrumentationTest, expected_attributes
34+
35+
36+
class TestAutomatic(WsgiTestBase, InstrumentationTest):
37+
def setUp(self):
38+
super().setUp()
39+
40+
FlaskInstrumentor().instrument()
41+
42+
self.app = flask.Flask(__name__)
43+
44+
def hello_endpoint(helloid):
45+
if helloid == 500:
46+
raise ValueError(":-(")
47+
return "Hello: " + str(helloid)
48+
49+
self.app.route("/hello/<int:helloid>")(hello_endpoint)
50+
51+
self.client = Client(self.app, BaseResponse)
52+
53+
def tearDown(self):
54+
disable(WARNING)
55+
FlaskInstrumentor().uninstrument()
56+
disable(NOTSET)
57+
58+
def test_uninstrument(self):
59+
expected_attrs = expected_attributes(
60+
{"http.target": "/hello/123", "http.route": "/hello/<int:helloid>"}
61+
)
62+
resp = self.client.get("/hello/123")
63+
self.assertEqual(200, resp.status_code)
64+
self.assertEqual([b"Hello: 123"], list(resp.response))
65+
span_list = self.memory_exporter.get_finished_spans()
66+
self.assertEqual(len(span_list), 1)
67+
self.assertEqual(span_list[0].name, "hello_endpoint")
68+
self.assertEqual(span_list[0].kind, trace_api.SpanKind.SERVER)
69+
self.assertEqual(span_list[0].attributes, expected_attrs)
70+
71+
FlaskInstrumentor().uninstrument()
72+
73+
expected_attrs = expected_attributes(
74+
{"http.target": "/hello/123", "http.route": "/hello/<int:helloid>"}
75+
)
76+
resp = self.client.get("/hello/123")
77+
self.assertEqual(200, resp.status_code)
78+
self.assertEqual([b"Hello: 123"], list(resp.response))
79+
span_list = self.memory_exporter.get_finished_spans()
80+
self.assertEqual(len(span_list), 1)
81+
self.assertEqual(span_list[0].name, "hello_endpoint")
82+
self.assertEqual(span_list[0].kind, trace_api.SpanKind.SERVER)
83+
self.assertEqual(span_list[0].attributes, expected_attrs)
84+
85+
86+
if __name__ == "__main__":
87+
main()
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from logging import NOTSET, WARNING, disable
16+
import unittest
17+
18+
from flask import Flask
19+
from werkzeug.test import Client
20+
from werkzeug.wrappers import BaseResponse
21+
22+
from opentelemetry import trace as trace_api
23+
from opentelemetry.ext.flask import FlaskInstrumentor
24+
from opentelemetry.test.wsgitestutil import WsgiTestBase
25+
26+
from .base_test import InstrumentationTest, expected_attributes
27+
28+
29+
class TestProgrammatic(WsgiTestBase, InstrumentationTest):
30+
def setUp(self):
31+
# No instrumentation code is here because it is present in the
32+
# conftest.py file next to this file.
33+
super().setUp()
34+
35+
self.app = Flask(__name__)
36+
37+
FlaskInstrumentor().instrument(app=self.app)
38+
39+
def hello_endpoint(helloid):
40+
if helloid == 500:
41+
raise ValueError(":-(")
42+
return "Hello: " + str(helloid)
43+
44+
self.app.route("/hello/<int:helloid>")(hello_endpoint)
45+
46+
self.client = Client(self.app, BaseResponse)
47+
48+
def tearDown(self):
49+
disable(WARNING)
50+
FlaskInstrumentor().uninstrument(app=self.app)
51+
disable(NOTSET)
52+
53+
def test_uninstrument(self):
54+
expected_attrs = expected_attributes(
55+
{"http.target": "/hello/123", "http.route": "/hello/<int:helloid>"}
56+
)
57+
resp = self.client.get("/hello/123")
58+
self.assertEqual(200, resp.status_code)
59+
self.assertEqual([b"Hello: 123"], list(resp.response))
60+
span_list = self.memory_exporter.get_finished_spans()
61+
self.assertEqual(len(span_list), 1)
62+
self.assertEqual(span_list[0].name, "hello_endpoint")
63+
self.assertEqual(span_list[0].kind, trace_api.SpanKind.SERVER)
64+
self.assertEqual(span_list[0].attributes, expected_attrs)
65+
66+
FlaskInstrumentor().uninstrument(app=self.app)
67+
68+
expected_attrs = expected_attributes(
69+
{"http.target": "/hello/123", "http.route": "/hello/<int:helloid>"}
70+
)
71+
resp = self.client.get("/hello/123")
72+
self.assertEqual(200, resp.status_code)
73+
self.assertEqual([b"Hello: 123"], list(resp.response))
74+
span_list = self.memory_exporter.get_finished_spans()
75+
self.assertEqual(len(span_list), 1)
76+
self.assertEqual(span_list[0].name, "hello_endpoint")
77+
self.assertEqual(span_list[0].kind, trace_api.SpanKind.SERVER)
78+
self.assertEqual(span_list[0].attributes, expected_attrs)
79+
80+
81+
if __name__ == "__main__":
82+
unittest.main()

0 commit comments

Comments
 (0)