@@ -51,10 +51,10 @@ async def main():
5151
5252import re
5353from typing import Collection
54-
54+
5555import asyncpg
5656import wrapt
57-
57+
5858from opentelemetry import trace
5959from opentelemetry .instrumentation .asyncpg .package import _instruments
6060from opentelemetry .instrumentation .asyncpg .version import __version__
@@ -75,23 +75,23 @@ async def main():
7575)
7676from opentelemetry .trace import SpanKind
7777from opentelemetry .trace .status import Status , StatusCode
78-
79-
78+
79+
8080def _hydrate_span_from_args (connection , query , parameters ) -> dict :
8181 """Get network and database attributes from connection."""
8282 span_attributes = {DB_SYSTEM : DbSystemValues .POSTGRESQL .value }
83-
83+
8484 # connection contains _params attribute which is a namedtuple ConnectionParameters.
8585 # https://github.com/MagicStack/asyncpg/blob/master/asyncpg/connection.py#L68
86-
86+
8787 params = getattr (connection , "_params" , None )
8888 dbname = getattr (params , "database" , None )
8989 if dbname :
9090 span_attributes [DB_NAME ] = dbname
9191 user = getattr (params , "user" , None )
9292 if user :
9393 span_attributes [DB_USER ] = user
94-
94+
9595 # connection contains _addr attribute which is either a host/port tuple, or unix socket string
9696 # https://magicstack.github.io/asyncpg/current/_modules/asyncpg/connection.html
9797 addr = getattr (connection , "_addr" , None )
@@ -102,16 +102,16 @@ def _hydrate_span_from_args(connection, query, parameters) -> dict:
102102 elif isinstance (addr , str ):
103103 span_attributes [NET_PEER_NAME ] = addr
104104 span_attributes [NET_TRANSPORT ] = NetTransportValues .OTHER .value
105-
105+
106106 if query is not None :
107107 span_attributes [DB_STATEMENT ] = query
108-
108+
109109 if parameters is not None and len (parameters ) > 0 :
110110 span_attributes ["db.statement.parameters" ] = str (parameters )
111-
111+
112112 return span_attributes
113-
114-
113+
114+
115115class AsyncPGInstrumentor (BaseInstrumentor ):
116116 _leading_comment_remover = re .compile (r"^/\*.*?\*/" )
117117 _CLEANUP_QUERIES = frozenset ([
@@ -121,20 +121,20 @@ class AsyncPGInstrumentor(BaseInstrumentor):
121121 "RESET ALL" ,
122122 ])
123123 _tracer = None
124-
124+
125125 def _is_cleanup_query (self , query : str ) -> bool :
126126 if query is None :
127127 return False
128128 return any (q in query for q in self ._CLEANUP_QUERIES )
129-
129+
130130 def __init__ (self , capture_parameters = False , capture_connection_cleanup = True ):
131131 super ().__init__ ()
132132 self .capture_parameters = capture_parameters
133133 self .capture_connection_cleanup = capture_connection_cleanup
134-
134+
135135 def instrumentation_dependencies (self ) -> Collection [str ]:
136136 return _instruments
137-
137+
138138 def _instrument (self , ** kwargs ):
139139 tracer_provider = kwargs .get ("tracer_provider" )
140140 self ._tracer = trace .get_tracer (
@@ -143,7 +143,7 @@ def _instrument(self, **kwargs):
143143 tracer_provider ,
144144 schema_url = "https://opentelemetry.io/schemas/1.11.0" ,
145145 )
146-
146+
147147 for method in [
148148 "Connection.execute" ,
149149 "Connection.executemany" ,
@@ -154,7 +154,7 @@ def _instrument(self, **kwargs):
154154 wrapt .wrap_function_wrapper (
155155 "asyncpg.connection" , method , self ._do_execute
156156 )
157-
157+
158158 for method in [
159159 "Cursor.fetch" ,
160160 "Cursor.forward" ,
@@ -164,7 +164,7 @@ def _instrument(self, **kwargs):
164164 wrapt .wrap_function_wrapper (
165165 "asyncpg.cursor" , method , self ._do_cursor_execute
166166 )
167-
167+
168168 def _uninstrument (self , ** __ ):
169169 for cls , methods in [
170170 (
@@ -176,30 +176,29 @@ def _uninstrument(self, **__):
176176 ]:
177177 for method_name in methods :
178178 unwrap (cls , method_name )
179-
179+
180180 async def _do_execute (self , func , instance , args , kwargs ):
181181 exception = None
182182 params = getattr (instance , "_params" , None )
183183 name = (
184184 args [0 ] if args [0 ] else getattr (params , "database" , "postgresql" )
185185 )
186-
186+
187187 try :
188188 # Strip leading comments so we get the operation name.
189189 name = self ._leading_comment_remover .sub ("" , name ).split ()[0 ]
190190 except IndexError :
191191 name = ""
192-
192+
193193 # Hydrate attributes before span creation to enable filtering
194194 span_attributes = _hydrate_span_from_args (
195195 instance ,
196196 args [0 ],
197197 args [1 :] if self .capture_parameters else None ,
198198 )
199-
200199 if not self .capture_connection_cleanup and self ._is_cleanup_query (args [0 ]):
201200 return await func (* args , ** kwargs )
202-
201+
203202 with self ._tracer .start_as_current_span (
204203 name , kind = SpanKind .CLIENT , attributes = span_attributes
205204 ) as span :
@@ -211,9 +210,9 @@ async def _do_execute(self, func, instance, args, kwargs):
211210 finally :
212211 if span .is_recording () and exception is not None :
213212 span .set_status (Status (StatusCode .ERROR ))
214-
213+
215214 return result
216-
215+
217216 async def _do_cursor_execute (self , func , instance , args , kwargs ):
218217 """Wrap cursor based functions. For every call this will generate a new span."""
219218 exception = None
@@ -223,20 +222,20 @@ async def _do_cursor_execute(self, func, instance, args, kwargs):
223222 if instance ._query
224223 else getattr (params , "database" , "postgresql" )
225224 )
226-
225+
227226 try :
228227 # Strip leading comments so we get the operation name.
229228 name = self ._leading_comment_remover .sub ("" , name ).split ()[0 ]
230229 except IndexError :
231230 name = ""
232-
231+
233232 # Hydrate attributes before span creation to enable filtering
234233 span_attributes = _hydrate_span_from_args (
235234 instance ._connection ,
236235 instance ._query ,
237236 instance ._args if self .capture_parameters else None ,
238237 )
239-
238+
240239 stop = False
241240 with self ._tracer .start_as_current_span (
242241 f"CURSOR: { name } " ,
@@ -254,8 +253,7 @@ async def _do_cursor_execute(self, func, instance, args, kwargs):
254253 finally :
255254 if span .is_recording () and exception is not None :
256255 span .set_status (Status (StatusCode .ERROR ))
257-
256+
258257 if not stop :
259258 return result
260259 raise StopAsyncIteration
261-
0 commit comments