#!/usr/bin/env python3
"""
Legitimate Client - Demonstrates SSE stream vulnerability
"""
import json
import requests
import threading
import time
import sys

SERVER_URL = "http://localhost:9393"

# Follows the MCP spec to initialize a session and obtain the session ID
def initialize_session():
    """Send the MCP ``initialize`` request and return the server's session ID.

    Posts a JSON-RPC ``initialize`` message to ``SERVER_URL`` and reads the
    ``Mcp-Session-Id`` response header. The ID is printed together with the
    command line for the companion attacker script.

    Returns:
        The value of the ``Mcp-Session-Id`` header, or ``None`` when the
        server did not send one.

    Raises:
        requests.exceptions.HTTPError: the server answered with a 4xx/5xx
            status.
        requests.exceptions.RequestException: the request could not be
            completed (connection refused, timeout, ...).
    """
    payload = {
        "jsonrpc": "2.0",
        "method": "initialize",
        "id": "init-1",
        "params": {
            "protocolVersion": "2024-11-05",
            "capabilities": {"logging": {}},
            "clientInfo": {"name": "legitimate-client", "version": "1.0.0"},
        },
    }

    # Same 5 s timeout the tool-call requests use; without it an unresponsive
    # server would block the demo forever.
    response = requests.post(
        SERVER_URL,
        json=payload,
        headers={
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream",
        },
        timeout=5,
    )
    # Surface HTTP errors here instead of reporting "Session ID: None".
    response.raise_for_status()

    session_id = response.headers.get("Mcp-Session-Id")
    print(f"\nConnected - Session ID: {session_id}\n")
    print(f"To hijack: python3 attacker_client_ruby_server.py {session_id}\n")
    return session_id

# Listen for the response
def _first_content_item(raw_line):
    """Return the first ``result.content`` item carried by one SSE line.

    Args:
        raw_line: One line of the event stream as bytes.

    Returns:
        The first element of ``message["result"]["content"]`` when the line is
        a ``data:`` field holding a JSON object of that shape and the element
        is a dict; otherwise ``None``.
    """
    if not raw_line:
        return None  # blank line (event separator / keep-alive)

    try:
        line = raw_line.decode("utf-8")
    except UnicodeDecodeError:
        return None

    # The space after "data:" is optional in the SSE format; json.loads
    # ignores leading whitespace, so both spellings parse the same way.
    if not line.startswith("data:"):
        return None

    try:
        message = json.loads(line[len("data:"):])
    except json.JSONDecodeError:
        return None

    # Valid JSON of an unexpected shape must be skipped, not allowed to raise
    # and tear down the listener.
    if not isinstance(message, dict):
        return None
    result = message.get("result")
    if not isinstance(result, dict):
        return None
    content = result.get("content")
    if not content or not isinstance(content, list):
        return None
    first = content[0]
    return first if isinstance(first, dict) else None


def sse_listener(session_id, stop_event):
    """Read the server's SSE stream and print tool results as they arrive.

    Opens a streaming GET on ``SERVER_URL`` for the given session and prints
    ``[LEGITIMATE] <text>`` for every ``data:`` line whose JSON payload has a
    non-empty ``result.content`` list. Lines that are not such payloads are
    ignored.

    Args:
        session_id: Value sent in the ``Mcp-Session-Id`` request header.
        stop_event: Event checked after each received line; when set, the
            listener stops reading.

    Raises:
        requests.exceptions.RequestException: the initial GET failed. Errors
            raised while reading the stream are reported on stdout instead.
    """
    response = requests.get(SERVER_URL, headers={
        "Mcp-Session-Id": session_id,
        "Accept": "text/event-stream",
        "Cache-Control": "no-cache"
    }, stream=True, timeout=None)  # no timeout: the stream is long-lived

    print("SSE stream connected\n")

    try:
        for raw_line in response.iter_lines():
            if stop_event.is_set():
                break

            item = _first_content_item(raw_line)
            if item is not None:
                print(f"[LEGITIMATE] {item.get('text', '')}")

    except requests.exceptions.ChunkedEncodingError:
        print("\nSTREAM HIJACKED - Connection lost!\n")
    except Exception as e:
        # Errors after a requested stop are expected and stay silent.
        if not stop_event.is_set():
            print(f"\nConnection lost: {e}\n")
    finally:
        # A streamed response holds its connection until explicitly closed.
        response.close()

# Use the sample server's notification tool call
def send_tool_calls(session_id, stop_event):
    """Post a ``notification_tool`` call every 3 seconds until stopped.

    Each call carries one of a fixed set of sample messages, chosen by the
    running call counter, and is sent as a JSON-RPC ``tools/call`` request to
    ``SERVER_URL``. The HTTP response is not inspected.

    Args:
        session_id: Value sent in the ``Mcp-Session-Id`` request header.
        stop_event: ``threading.Event`` that ends the loop when set; it is
            also used for the pause between calls so a stop request takes
            effect without waiting out the full interval.

    Raises:
        requests.exceptions.RequestException: a POST failed or exceeded the
            5 second timeout.
    """
    # Simulated sensitive data from the legitimate client.
    # Templates are formatted with two positional values:
    #   {0} = call counter, {1} = six-digit code (123400 + counter).
    # Explicit indices are required: with auto-numbered "{}" the 2FA template
    # would receive the bare counter and the six-digit code would go unused.
    messages = [
        "Transaction: $5,000 transfer",
        "Security alert: New device login",
        "Password reset code: ABC-{0}XYZ",
        "Payment: $2,500 completed",
        "Bank transfer: $10,000",
        "2FA code: {1}",
    ]
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
        "Mcp-Session-Id": session_id,
    }

    call_num = 0
    while not stop_event.is_set():
        call_num += 1
        template = messages[call_num % len(messages)]
        message = template.format(call_num, 123400 + call_num)

        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "id": f"call-{call_num}",
            "params": {
                "name": "notification_tool",
                "arguments": {"message": message, "delay": 0},
            },
        }

        requests.post(SERVER_URL, json=payload, headers=headers, timeout=5)

        # Sleeps up to 3 s but returns immediately once the event is set.
        stop_event.wait(3)

# Connects to the server and listens to events
def main():
    """Run the legitimate-client demo from start to finish.

    Prints a banner, initializes a session, waits for the user to press
    Enter, starts the SSE listener in a background thread and then sends tool
    calls in the foreground. Ctrl+C during sending sets the stop event and
    prints an exit message.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("LEGITIMATE CLIENT")
    print(rule)

    session_id = initialize_session()
    input("Press Enter to start...\n")

    stop_event = threading.Event()

    # Daemon thread: it does not keep the interpreter alive once the main
    # thread finishes.
    listener = threading.Thread(
        target=sse_listener,
        args=(session_id, stop_event),
        daemon=True,
    )
    listener.start()
    # Presumably gives the listener time to connect before the first call.
    time.sleep(2)

    print("Sending tool calls every 3 seconds...\n")

    try:
        send_tool_calls(session_id, stop_event)
    except KeyboardInterrupt:
        stop_event.set()
        print("\n Exiting...\n")

# Script entry point: run the demo and treat Ctrl+C as a clean exit.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Exit with status 0 rather than showing a traceback.
        raise SystemExit(0)
