@@ -51,10 +51,10 @@ async def main():
5151
5252import re
5353from typing import Collection
54-
54+
5555import asyncpg
5656import wrapt
57-
57+
5858from opentelemetry import trace
5959from opentelemetry .instrumentation .asyncpg .package import _instruments
6060from opentelemetry .instrumentation .asyncpg .version import __version__
@@ -75,23 +75,23 @@ async def main():
7575)
7676from opentelemetry .trace import SpanKind
7777from opentelemetry .trace .status import Status , StatusCode
78-
79-
78+
79+
8080def _hydrate_span_from_args (connection , query , parameters ) -> dict :
8181 """Get network and database attributes from connection."""
8282 span_attributes = {DB_SYSTEM : DbSystemValues .POSTGRESQL .value }
83-
83+
8484 # connection contains _params attribute which is a namedtuple ConnectionParameters.
8585 # https://github.com/MagicStack/asyncpg/blob/master/asyncpg/connection.py#L68
86-
86+
8787 params = getattr (connection , "_params" , None )
8888 dbname = getattr (params , "database" , None )
8989 if dbname :
9090 span_attributes [DB_NAME ] = dbname
9191 user = getattr (params , "user" , None )
9292 if user :
9393 span_attributes [DB_USER ] = user
94-
94+
9595 # connection contains _addr attribute which is either a host/port tuple, or unix socket string
9696 # https://magicstack.github.io/asyncpg/current/_modules/asyncpg/connection.html
9797 addr = getattr (connection , "_addr" , None )
@@ -102,27 +102,39 @@ def _hydrate_span_from_args(connection, query, parameters) -> dict:
102102 elif isinstance (addr , str ):
103103 span_attributes [NET_PEER_NAME ] = addr
104104 span_attributes [NET_TRANSPORT ] = NetTransportValues .OTHER .value
105-
105+
106106 if query is not None :
107107 span_attributes [DB_STATEMENT ] = query
108-
108+
109109 if parameters is not None and len (parameters ) > 0 :
110110 span_attributes ["db.statement.parameters" ] = str (parameters )
111-
111+
112112 return span_attributes
113-
114-
113+
114+
115115class AsyncPGInstrumentor (BaseInstrumentor ):
116116 _leading_comment_remover = re .compile (r"^/\*.*?\*/" )
117+ _CLEANUP_QUERIES = frozenset ([
118+ "SELECT pg_advisory_unlock_all()" ,
119+ "CLOSE ALL" ,
120+ "UNLISTEN *" ,
121+ "RESET ALL" ,
122+ ])
117123 _tracer = None
118-
119- def __init__ (self , capture_parameters = False ):
124+
125+ def _is_cleanup_query (self , query : str ) -> bool :
126+ if query is None :
127+ return False
128+ return any (q in query for q in self ._CLEANUP_QUERIES )
129+
130+ def __init__ (self , capture_parameters = False , capture_connection_cleanup = True ):
120131 super ().__init__ ()
121132 self .capture_parameters = capture_parameters
122-
133+ self .capture_connection_cleanup = capture_connection_cleanup
134+
123135 def instrumentation_dependencies (self ) -> Collection [str ]:
124136 return _instruments
125-
137+
126138 def _instrument (self , ** kwargs ):
127139 tracer_provider = kwargs .get ("tracer_provider" )
128140 self ._tracer = trace .get_tracer (
@@ -131,7 +143,7 @@ def _instrument(self, **kwargs):
131143 tracer_provider ,
132144 schema_url = "https://opentelemetry.io/schemas/1.11.0" ,
133145 )
134-
146+
135147 for method in [
136148 "Connection.execute" ,
137149 "Connection.executemany" ,
@@ -142,7 +154,7 @@ def _instrument(self, **kwargs):
142154 wrapt .wrap_function_wrapper (
143155 "asyncpg.connection" , method , self ._do_execute
144156 )
145-
157+
146158 for method in [
147159 "Cursor.fetch" ,
148160 "Cursor.forward" ,
@@ -152,7 +164,7 @@ def _instrument(self, **kwargs):
152164 wrapt .wrap_function_wrapper (
153165 "asyncpg.cursor" , method , self ._do_cursor_execute
154166 )
155-
167+
156168 def _uninstrument (self , ** __ ):
157169 for cls , methods in [
158170 (
@@ -164,27 +176,30 @@ def _uninstrument(self, **__):
164176 ]:
165177 for method_name in methods :
166178 unwrap (cls , method_name )
167-
179+
168180 async def _do_execute (self , func , instance , args , kwargs ):
169181 exception = None
170182 params = getattr (instance , "_params" , None )
171183 name = (
172184 args [0 ] if args [0 ] else getattr (params , "database" , "postgresql" )
173185 )
174-
186+
175187 try :
176188 # Strip leading comments so we get the operation name.
177189 name = self ._leading_comment_remover .sub ("" , name ).split ()[0 ]
178190 except IndexError :
179191 name = ""
180-
192+
181193 # Hydrate attributes before span creation to enable filtering
182194 span_attributes = _hydrate_span_from_args (
183195 instance ,
184196 args [0 ],
185197 args [1 :] if self .capture_parameters else None ,
186198 )
187-
199+
200+ if not self .capture_connection_cleanup and self ._is_cleanup_query (args [0 ]):
201+ return await func (* args , ** kwargs )
202+
188203 with self ._tracer .start_as_current_span (
189204 name , kind = SpanKind .CLIENT , attributes = span_attributes
190205 ) as span :
@@ -196,9 +211,9 @@ async def _do_execute(self, func, instance, args, kwargs):
196211 finally :
197212 if span .is_recording () and exception is not None :
198213 span .set_status (Status (StatusCode .ERROR ))
199-
214+
200215 return result
201-
216+
202217 async def _do_cursor_execute (self , func , instance , args , kwargs ):
203218 """Wrap cursor based functions. For every call this will generate a new span."""
204219 exception = None
@@ -208,20 +223,20 @@ async def _do_cursor_execute(self, func, instance, args, kwargs):
208223 if instance ._query
209224 else getattr (params , "database" , "postgresql" )
210225 )
211-
226+
212227 try :
213228 # Strip leading comments so we get the operation name.
214229 name = self ._leading_comment_remover .sub ("" , name ).split ()[0 ]
215230 except IndexError :
216231 name = ""
217-
232+
218233 # Hydrate attributes before span creation to enable filtering
219234 span_attributes = _hydrate_span_from_args (
220235 instance ._connection ,
221236 instance ._query ,
222237 instance ._args if self .capture_parameters else None ,
223238 )
224-
239+
225240 stop = False
226241 with self ._tracer .start_as_current_span (
227242 f"CURSOR: { name } " ,
@@ -239,7 +254,8 @@ async def _do_cursor_execute(self, func, instance, args, kwargs):
239254 finally :
240255 if span .is_recording () and exception is not None :
241256 span .set_status (Status (StatusCode .ERROR ))
242-
257+
243258 if not stop :
244259 return result
245260 raise StopAsyncIteration
261+
0 commit comments