1313# limitations under the License.
1414
1515import logging
16- from typing import Mapping , Optional
16+ from typing import TYPE_CHECKING , Any , Mapping , NoReturn , Optional , Tuple , cast
1717
1818from synapse .storage .engines ._base import (
1919 BaseDatabaseEngine ,
2020 IncorrectDatabaseSetup ,
2121 IsolationLevel ,
2222)
23- from synapse .storage .types import Connection
23+ from synapse .storage .types import Cursor
24+
25+ if TYPE_CHECKING :
26+ import psycopg2 # noqa: F401
27+
28+ from synapse .storage .database import LoggingDatabaseConnection
29+
2430
2531logger = logging .getLogger (__name__ )
2632
2733
28- class PostgresEngine (BaseDatabaseEngine ):
29- def __init__ (self , database_module , database_config ):
30- super ().__init__ (database_module , database_config )
31- self .module .extensions .register_type (self .module .extensions .UNICODE )
34+ class PostgresEngine (BaseDatabaseEngine ["psycopg2.connection" ]):
35+ def __init__ (self , database_config : Mapping [str , Any ]):
36+ import psycopg2 .extensions
37+
38+ super ().__init__ (psycopg2 , database_config )
39+ psycopg2 .extensions .register_type (psycopg2 .extensions .UNICODE )
3240
3341 # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
3442 # actually want to use bytes than wrap it in `bytearray`.
35- def _disable_bytes_adapter (_ ) :
43+ def _disable_bytes_adapter (_ : bytes ) -> NoReturn :
3644 raise Exception ("Passing bytes to DB is disabled." )
3745
38- self . module .extensions .register_adapter (bytes , _disable_bytes_adapter )
39- self .synchronous_commit = database_config .get ("synchronous_commit" , True )
40- self ._version = None # unknown as yet
46+ psycopg2 .extensions .register_adapter (bytes , _disable_bytes_adapter )
47+ self .synchronous_commit : bool = database_config .get ("synchronous_commit" , True )
48+ self ._version : Optional [ int ] = None # unknown as yet
4149
4250 self .isolation_level_map : Mapping [int , int ] = {
43- IsolationLevel .READ_COMMITTED : self . module .extensions .ISOLATION_LEVEL_READ_COMMITTED ,
44- IsolationLevel .REPEATABLE_READ : self . module .extensions .ISOLATION_LEVEL_REPEATABLE_READ ,
45- IsolationLevel .SERIALIZABLE : self . module .extensions .ISOLATION_LEVEL_SERIALIZABLE ,
51+ IsolationLevel .READ_COMMITTED : psycopg2 .extensions .ISOLATION_LEVEL_READ_COMMITTED ,
52+ IsolationLevel .REPEATABLE_READ : psycopg2 .extensions .ISOLATION_LEVEL_REPEATABLE_READ ,
53+ IsolationLevel .SERIALIZABLE : psycopg2 .extensions .ISOLATION_LEVEL_SERIALIZABLE ,
4654 }
4755 self .default_isolation_level = (
48- self . module .extensions .ISOLATION_LEVEL_REPEATABLE_READ
56+ psycopg2 .extensions .ISOLATION_LEVEL_REPEATABLE_READ
4957 )
5058 self .config = database_config
5159
5260 @property
5361 def single_threaded (self ) -> bool :
5462 return False
5563
56- def get_db_locale (self , txn ) :
64+ def get_db_locale (self , txn : Cursor ) -> Tuple [ str , str ] :
5765 txn .execute (
5866 "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
5967 )
60- collation , ctype = txn .fetchone ()
68+ collation , ctype = cast ( Tuple [ str , str ], txn .fetchone () )
6169 return collation , ctype
6270
63- def check_database (self , db_conn , allow_outdated_version : bool = False ):
71+ def check_database (
72+ self , db_conn : "psycopg2.connection" , allow_outdated_version : bool = False
73+ ) -> None :
6474 # Get the version of PostgreSQL that we're using. As per the psycopg2
6575 # docs: The number is formed by converting the major, minor, and
6676 # revision numbers into two-decimal-digit numbers and appending them
6777 # together. For example, version 8.1.5 will be returned as 80105
68- self ._version = db_conn .server_version
78+ self ._version = cast ( int , db_conn .server_version )
6979 allow_unsafe_locale = self .config .get ("allow_unsafe_locale" , False )
7080
7181 # Are we on a supported PostgreSQL version?
@@ -108,7 +118,7 @@ def check_database(self, db_conn, allow_outdated_version: bool = False):
108118 ctype ,
109119 )
110120
111- def check_new_database (self , txn ) :
121+ def check_new_database (self , txn : Cursor ) -> None :
112122 """Gets called when setting up a brand new database. This allows us to
113123 apply stricter checks on new databases versus existing database.
114124 """
@@ -129,10 +139,10 @@ def check_new_database(self, txn):
129139 "See docs/postgres.md for more information." % ("\n " .join (errors ))
130140 )
131141
132- def convert_param_style (self , sql ) :
142+ def convert_param_style (self , sql : str ) -> str :
133143 return sql .replace ("?" , "%s" )
134144
135- def on_new_connection (self , db_conn ) :
145+ def on_new_connection (self , db_conn : "LoggingDatabaseConnection" ) -> None :
136146 db_conn .set_isolation_level (self .default_isolation_level )
137147
138148 # Set the bytea output to escape, vs the default of hex
@@ -149,14 +159,14 @@ def on_new_connection(self, db_conn):
149159 db_conn .commit ()
150160
151161 @property
152- def can_native_upsert (self ):
162+ def can_native_upsert (self ) -> bool :
153163 """
154164 Can we use native UPSERTs?
155165 """
156166 return True
157167
158168 @property
159- def supports_using_any_list (self ):
169+ def supports_using_any_list (self ) -> bool :
160170 """Do we support using `a = ANY(?)` and passing a list"""
161171 return True
162172
@@ -165,27 +175,25 @@ def supports_returning(self) -> bool:
165175 """Do we support the `RETURNING` clause in insert/update/delete?"""
166176 return True
167177
168- def is_deadlock (self , error ):
169- if isinstance (error , self .module .DatabaseError ):
178+ def is_deadlock (self , error : Exception ) -> bool :
179+ import psycopg2 .extensions
180+
181+ if isinstance (error , psycopg2 .DatabaseError ):
170182 # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
171183 # "40001" serialization_failure
172184 # "40P01" deadlock_detected
173185 return error .pgcode in ["40001" , "40P01" ]
174186 return False
175187
176- def is_connection_closed (self , conn ) :
188+ def is_connection_closed (self , conn : "psycopg2.connection" ) -> bool :
177189 return bool (conn .closed )
178190
179- def lock_table (self , txn , table ) :
191+ def lock_table (self , txn : Cursor , table : str ) -> None :
180192 txn .execute ("LOCK TABLE %s in EXCLUSIVE MODE" % (table ,))
181193
182194 @property
183- def server_version (self ):
184- """Returns a string giving the server version. For example: '8.1.5'
185-
186- Returns:
187- string
188- """
195+ def server_version (self ) -> str :
196+ """Returns a string giving the server version. For example: '8.1.5'."""
189197 # note that this is a bit of a hack because it relies on check_database
190198 # having been called. Still, that should be a safe bet here.
191199 numver = self ._version
@@ -197,17 +205,21 @@ def server_version(self):
197205 else :
198206 return "%i.%i.%i" % (numver / 10000 , (numver % 10000 ) / 100 , numver % 100 )
199207
200- def in_transaction (self , conn : Connection ) -> bool :
201- return conn .status != self .module .extensions .STATUS_READY # type: ignore
208+ def in_transaction (self , conn : "psycopg2.connection" ) -> bool :
209+ import psycopg2 .extensions
210+
211+ return conn .status != psycopg2 .extensions .STATUS_READY
202212
203- def attempt_to_set_autocommit (self , conn : Connection , autocommit : bool ):
204- return conn .set_session (autocommit = autocommit ) # type: ignore
213+ def attempt_to_set_autocommit (
214+ self , conn : "psycopg2.connection" , autocommit : bool
215+ ) -> None :
216+ return conn .set_session (autocommit = autocommit )
205217
206218 def attempt_to_set_isolation_level (
207- self , conn : Connection , isolation_level : Optional [int ]
208- ):
219+ self , conn : "psycopg2.connection" , isolation_level : Optional [int ]
220+ ) -> None :
209221 if isolation_level is None :
210222 isolation_level = self .default_isolation_level
211223 else :
212224 isolation_level = self .isolation_level_map [isolation_level ]
213- return conn .set_isolation_level (isolation_level ) # type: ignore
225+ return conn .set_isolation_level (isolation_level )
0 commit comments