1414import calendar
1515import logging
1616import time
17- from typing import TYPE_CHECKING , Dict
17+ from typing import TYPE_CHECKING , Dict , List , Tuple , cast
1818
1919from synapse .metrics import GaugeBucketCollector
2020from synapse .metrics .background_process_metrics import wrap_as_background_process
2121from synapse .storage ._base import SQLBaseStore
22- from synapse .storage .database import DatabasePool , LoggingDatabaseConnection
22+ from synapse .storage .database import (
23+ DatabasePool ,
24+ LoggingDatabaseConnection ,
25+ LoggingTransaction ,
26+ )
2327from synapse .storage .databases .main .event_push_actions import (
2428 EventPushActionsWorkerStore ,
2529)
26- from synapse .storage .types import Cursor
2730
2831if TYPE_CHECKING :
2932 from synapse .server import HomeServer
@@ -73,7 +76,7 @@ def __init__(
7376
7477 @wrap_as_background_process ("read_forward_extremities" )
7578 async def _read_forward_extremities (self ) -> None :
76- def fetch (txn ) :
79+ def fetch (txn : LoggingTransaction ) -> List [ Tuple [ int , int ]] :
7780 txn .execute (
7881 """
7982 SELECT t1.c, t2.c
@@ -86,7 +89,7 @@ def fetch(txn):
8689 ) t2 ON t1.room_id = t2.room_id
8790 """
8891 )
89- return txn .fetchall ()
92+ return cast ( List [ Tuple [ int , int ]], txn .fetchall () )
9093
9194 res = await self .db_pool .runInteraction ("read_forward_extremities" , fetch )
9295
@@ -104,20 +107,20 @@ async def count_daily_e2ee_messages(self) -> int:
104107 call to this function, it will return None.
105108 """
106109
107- def _count_messages (txn ) :
110+ def _count_messages (txn : LoggingTransaction ) -> int :
108111 sql = """
109112 SELECT COUNT(*) FROM events
110113 WHERE type = 'm.room.encrypted'
111114 AND stream_ordering > ?
112115 """
113116 txn .execute (sql , (self .stream_ordering_day_ago ,))
114- (count ,) = txn .fetchone ()
117+ (count ,) = cast ( Tuple [ int ], txn .fetchone () )
115118 return count
116119
117120 return await self .db_pool .runInteraction ("count_e2ee_messages" , _count_messages )
118121
119122 async def count_daily_sent_e2ee_messages (self ) -> int :
120- def _count_messages (txn ) :
123+ def _count_messages (txn : LoggingTransaction ) -> int :
121124 # This is good enough as if you have silly characters in your own
122125 # hostname then that's your own fault.
123126 like_clause = "%:" + self .hs .hostname
@@ -130,22 +133,22 @@ def _count_messages(txn):
130133 """
131134
132135 txn .execute (sql , (like_clause , self .stream_ordering_day_ago ))
133- (count ,) = txn .fetchone ()
136+ (count ,) = cast ( Tuple [ int ], txn .fetchone () )
134137 return count
135138
136139 return await self .db_pool .runInteraction (
137140 "count_daily_sent_e2ee_messages" , _count_messages
138141 )
139142
140143 async def count_daily_active_e2ee_rooms (self ) -> int :
141- def _count (txn ) :
144+ def _count (txn : LoggingTransaction ) -> int :
142145 sql = """
143146 SELECT COUNT(DISTINCT room_id) FROM events
144147 WHERE type = 'm.room.encrypted'
145148 AND stream_ordering > ?
146149 """
147150 txn .execute (sql , (self .stream_ordering_day_ago ,))
148- (count ,) = txn .fetchone ()
151+ (count ,) = cast ( Tuple [ int ], txn .fetchone () )
149152 return count
150153
151154 return await self .db_pool .runInteraction (
@@ -160,20 +163,20 @@ async def count_daily_messages(self) -> int:
160163 call to this function, it will return None.
161164 """
162165
163- def _count_messages (txn ) :
166+ def _count_messages (txn : LoggingTransaction ) -> int :
164167 sql = """
165168 SELECT COUNT(*) FROM events
166169 WHERE type = 'm.room.message'
167170 AND stream_ordering > ?
168171 """
169172 txn .execute (sql , (self .stream_ordering_day_ago ,))
170- (count ,) = txn .fetchone ()
173+ (count ,) = cast ( Tuple [ int ], txn .fetchone () )
171174 return count
172175
173176 return await self .db_pool .runInteraction ("count_messages" , _count_messages )
174177
175178 async def count_daily_sent_messages (self ) -> int :
176- def _count_messages (txn ) :
179+ def _count_messages (txn : LoggingTransaction ) -> int :
177180 # This is good enough as if you have silly characters in your own
178181 # hostname then that's your own fault.
179182 like_clause = "%:" + self .hs .hostname
@@ -186,22 +189,22 @@ def _count_messages(txn):
186189 """
187190
188191 txn .execute (sql , (like_clause , self .stream_ordering_day_ago ))
189- (count ,) = txn .fetchone ()
192+ (count ,) = cast ( Tuple [ int ], txn .fetchone () )
190193 return count
191194
192195 return await self .db_pool .runInteraction (
193196 "count_daily_sent_messages" , _count_messages
194197 )
195198
196199 async def count_daily_active_rooms (self ) -> int :
197- def _count (txn ) :
200+ def _count (txn : LoggingTransaction ) -> int :
198201 sql = """
199202 SELECT COUNT(DISTINCT room_id) FROM events
200203 WHERE type = 'm.room.message'
201204 AND stream_ordering > ?
202205 """
203206 txn .execute (sql , (self .stream_ordering_day_ago ,))
204- (count ,) = txn .fetchone ()
207+ (count ,) = cast ( Tuple [ int ], txn .fetchone () )
205208 return count
206209
207210 return await self .db_pool .runInteraction ("count_daily_active_rooms" , _count )
@@ -227,7 +230,7 @@ async def count_monthly_users(self) -> int:
227230 "count_monthly_users" , self ._count_users , thirty_days_ago
228231 )
229232
230- def _count_users (self , txn : Cursor , time_from : int ) -> int :
233+ def _count_users (self , txn : LoggingTransaction , time_from : int ) -> int :
231234 """
232235 Returns number of users seen in the past time_from period
233236 """
@@ -242,7 +245,7 @@ def _count_users(self, txn: Cursor, time_from: int) -> int:
242245 # Mypy knows that fetchone() might return None if there are no rows.
243246 # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
244247 # returns exactly one row.
245- (count ,) = txn .fetchone () # type: ignore[misc]
248+ (count ,) = cast ( Tuple [ int ], txn .fetchone ())
246249 return count
247250
248251 async def count_r30_users (self ) -> Dict [str , int ]:
@@ -256,7 +259,7 @@ async def count_r30_users(self) -> Dict[str, int]:
256259 A mapping of counts globally as well as broken out by platform.
257260 """
258261
259- def _count_r30_users (txn ) :
262+ def _count_r30_users (txn : LoggingTransaction ) -> Dict [ str , int ] :
260263 thirty_days_in_secs = 86400 * 30
261264 now = int (self ._clock .time ())
262265 thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -321,7 +324,7 @@ def _count_r30_users(txn):
321324
322325 txn .execute (sql , (thirty_days_ago_in_secs , thirty_days_ago_in_secs ))
323326
324- (count ,) = txn .fetchone ()
327+ (count ,) = cast ( Tuple [ int ], txn .fetchone () )
325328 results ["all" ] = count
326329
327330 return results
@@ -348,7 +351,7 @@ async def count_r30v2_users(self) -> Dict[str, int]:
348351 - "web" (any web application -- it's not possible to distinguish Element Web here)
349352 """
350353
351- def _count_r30v2_users (txn ) :
354+ def _count_r30v2_users (txn : LoggingTransaction ) -> Dict [ str , int ] :
352355 thirty_days_in_secs = 86400 * 30
353356 now = int (self ._clock .time ())
354357 sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -445,11 +448,8 @@ def _count_r30v2_users(txn):
445448 thirty_days_in_secs * 1000 ,
446449 ),
447450 )
448- row = txn .fetchone ()
449- if row is None :
450- results ["all" ] = 0
451- else :
452- results ["all" ] = row [0 ]
451+ (count ,) = cast (Tuple [int ], txn .fetchone ())
452+ results ["all" ] = count
453453
454454 return results
455455
@@ -471,7 +471,7 @@ async def generate_user_daily_visits(self) -> None:
471471 Generates daily visit data for use in cohort/ retention analysis
472472 """
473473
474- def _generate_user_daily_visits (txn ) :
474+ def _generate_user_daily_visits (txn : LoggingTransaction ) -> None :
475475 logger .info ("Calling _generate_user_daily_visits" )
476476 today_start = self ._get_start_of_day ()
477477 a_day_in_milliseconds = 24 * 60 * 60 * 1000
0 commit comments