1313# limitations under the License.
1414
1515import logging
16+ from collections import defaultdict
1617from typing import (
1718 Collection ,
1819 Dict ,
@@ -768,9 +769,18 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
768769 )
769770
770771 @cached (iterable = True )
772+ async def get_mutual_event_relations_for_rel_type (
773+ self , event_id : str , relation_type : str
774+ ) -> Set [Tuple [str , str ]]:
775+ raise NotImplementedError ()
776+
777+ @cachedList (
778+ cached_method_name = "get_mutual_event_relations_for_rel_type" ,
779+ list_name = "relation_types" ,
780+ )
771781 async def get_mutual_event_relations (
772- self , event_id : str
773- ) -> Set [Tuple [str , str , str ]]:
782+ self , event_id : str , relation_types : Collection [ str ]
783+ ) -> Dict [ str , Set [Tuple [str , str ] ]]:
774784 """
775785 Fetch event metadata for events which related to the same event as the given event.
776786
@@ -780,20 +790,29 @@ async def get_mutual_event_relations(
780790 event_id: The event ID which is targeted by relations.
781791
782792 Returns:
783- A set of tuples of :
784- The relation type
785- The sender
786- The event type
793+ A dictionary of relation type to :
794+ A set of tuples of:
795+ The sender
796+ The event type
787797 """
788- sql = """
798+ rel_type_sql , rel_type_args = make_in_list_sql_clause (
799+ self .database_engine , "rel_type" , relation_types
800+ )
801+
802+ sql = f"""
789803 SELECT DISTINCT relation_type, sender, type FROM event_relations
790804 INNER JOIN events USING (event_id)
791- WHERE relates_to_id = ?
805+ WHERE relates_to_id = ? AND { rel_type_sql }
792806 """
793807
794- def _get_event_relations (txn : LoggingTransaction ) -> Set [Tuple [str , str , str ]]:
795- txn .execute (sql , (event_id ,))
796- return set (cast (List [Tuple [str , str , str ]], txn .fetchall ()))
808+ def _get_event_relations (
809+ txn : LoggingTransaction ,
810+ ) -> Dict [str , Set [Tuple [str , str ]]]:
811+ txn .execute (sql , [event_id ] + rel_type_args )
812+ result = defaultdict (set )
813+ for rel_type , sender , type in txn .fetchall ():
814+ result [rel_type ].add ((sender , type ))
815+ return result
797816
798817 return await self .db_pool .runInteraction (
799818 "get_event_relations" , _get_event_relations
0 commit comments