@@ -17,33 +17,72 @@ def __init__(self, services):
1717 self .log = self .create_logger ('contact_tcp' )
1818 self .contact_svc = services .get ('contact_svc' )
1919 self .tcp_handler = TcpSessionHandler (services , self .log )
20+ self .server_task = None
21+ self .op_loop_task = None
22+ self .server = None
2023
2124 async def start (self ):
2225 loop = asyncio .get_event_loop ()
2326 tcp = self .get_config ('app.contact.tcp' )
24- loop .create_task (asyncio .start_server (self .tcp_handler .accept , * tcp .split (':' )))
25- loop .create_task (self .operation_loop ())
27+ self .server_task = loop .create_task (self .start_server (* tcp .split (':' )))
28+ self .op_loop_task = loop .create_task (self .operation_loop ())
29+
30+ async def stop (self ):
31+ tasks_to_stop = [t for t in (self .server_task , self .op_loop_task ) if t is not None ]
32+ for t in tasks_to_stop :
33+ if t :
34+ t .cancel ()
35+ if tasks_to_stop :
36+ _ = await asyncio .gather (* tasks_to_stop , return_exceptions = True )
37+
38+ async def start_server (self , host , port ):
39+ try :
40+ self .server = await asyncio .start_server (self .tcp_handler .accept , host , port )
41+ async with self .server :
42+ await self .server .serve_forever ()
43+ except asyncio .CancelledError :
44+ self .log .debug ('Canceling TCP contact server task.' )
45+ if self .server :
46+ self .log .debug ('Closing TCP contact server.' )
47+ self .server .close ()
48+ await self .server .wait_closed ()
49+ self .log .debug ('Closed TCP contact server.' )
50+ raise
2651
2752 async def operation_loop (self ):
28- while True :
29- await self .tcp_handler .refresh ()
30- for session in self .tcp_handler .sessions :
31- _ , instructions = await self .contact_svc .handle_heartbeat (paw = session .paw )
32- for instruction in instructions :
33- try :
34- self .log .debug ('TCP instruction: %s' % instruction .id )
35- status , _ , response , agent_reported_time = await self .tcp_handler .send (
36- session .id ,
37- self .decode_bytes (instruction .command ),
38- timeout = instruction .timeout
39- )
40- beacon = dict (paw = session .paw ,
41- results = [dict (id = instruction .id , output = self .encode_string (response ), status = status , agent_reported_time = agent_reported_time )])
42- await self .contact_svc .handle_heartbeat (** beacon )
43- await asyncio .sleep (instruction .sleep )
44- except Exception as e :
45- self .log .debug ('[-] operation exception: %s' % e )
46- await asyncio .sleep (20 )
53+ try :
54+ while True :
55+ await self .tcp_handler .refresh ()
56+ await self .handle_sessions ()
57+ await asyncio .sleep (20 )
58+ except asyncio .CancelledError :
59+ self .log .debug ('Canceling TCP contact operation loop task.' )
60+ for sess in self .tcp_handler .sessions :
61+ self .log .debug (f'Closing session { sess .id } .' )
62+ sess .writer .close ()
63+ await sess .writer .wait_closed ()
64+ self .log .debug ('Closed TCP contact sessions.' )
65+ raise
66+
67+ async def handle_sessions (self ):
68+ for session in self .tcp_handler .sessions :
69+ _ , instructions = await self .contact_svc .handle_heartbeat (paw = session .paw )
70+ for instruction in instructions :
71+ try :
72+ self .log .debug ('TCP instruction: %s' % instruction .id )
73+ status , _ , response , agent_reported_time = await self .tcp_handler .send (
74+ session .id ,
75+ self .decode_bytes (instruction .command ),
76+ timeout = instruction .timeout
77+ )
78+ beacon = dict (paw = session .paw ,
79+ results = [dict (id = instruction .id , output = self .encode_string (response ), status = status , agent_reported_time = agent_reported_time )])
80+ await self .contact_svc .handle_heartbeat (** beacon )
81+ await asyncio .sleep (instruction .sleep )
82+ except asyncio .CancelledError :
83+ raise
84+ except Exception as e :
85+ self .log .debug ('[-] operation exception: %s' % e )
4786
4887
4988class TcpSessionHandler (BaseWorld ):
@@ -68,6 +107,7 @@ async def refresh(self):
68107 index += 1
69108
70109 async def accept (self , reader , writer ):
110+ self .log .debug ('Accepting connection.' )
71111 try :
72112 profile = await self ._handshake (reader )
73113 except Exception as e :
0 commit comments