#!/usr/bin/env python3
"""RFC-001 Session Health Watchdog.

Runs `openclaw sessions --all-agents --json`, classifies sessions by lane, applies
context-usage thresholds, fires Telegram alerts, runs `openclaw sessions cleanup`
whenever a cron session reaches its prune threshold, and records the run in
memory/heartbeat-state.json.
"""
from __future__ import annotations

import json
import subprocess
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

try:
    from zoneinfo import ZoneInfo
except ImportError:  # pragma: no cover
    from backports.zoneinfo import ZoneInfo  # type: ignore

# Directory two levels above the directory that contains this script
# (parents[0] is the script's own directory).
WORKSPACE = Path(__file__).resolve().parents[2]
# JSON state file that update_heartbeat_log() reads and rewrites on each run.
HEARTBEAT_FILE = WORKSPACE / "memory" / "heartbeat-state.json"
# Value passed as --target to `openclaw message send`.
TELEGRAM_TARGET = "8032472383"
# Value passed as --channel to `openclaw message send`.
TELEGRAM_CHANNEL = "telegram"
# Zone used for the timestamp written into the heartbeat file.
TIMEZONE = ZoneInfo("America/Los_Angeles")

# Context-usage thresholds per lane, expressed as a fraction of the context
# window (tokens used / context tokens). The four non-cron lanes share the same
# ladder; each lane still gets its own dict so one can be tuned independently.
LANE_THRESHOLDS = {
    lane: {"warn": 0.30, "summarize": 0.40, "rollover": 0.50, "required": 0.60}
    for lane in ("main", "telegram", "worker", "ops-maintainer")
}
# Cron sessions are pruned rather than rolled over, so they use a shorter ladder.
LANE_THRESHOLDS["cron"] = {"warn": 0.30, "summarize": 0.40, "prune": 0.60}

@dataclass
class SessionUsage:
    """Result of evaluating one session's context usage against its lane thresholds."""

    # Session key as reported by the sessions listing.
    key: str
    # Lane the key was classified into (e.g. "main", "cron", "other").
    lane: str
    # tokens_used / context_tokens, or None when it could not be computed.
    ratio: Optional[float]
    # Raw token counts the ratio was derived from.
    tokens_used: Optional[int]
    context_tokens: Optional[int]
    action: str  # info|none|warn|summarize|rollover|required|cron-prune

    @property
    def pct(self) -> Optional[float]:
        """Return the usage ratio as a percentage, or None when the ratio is unknown."""
        if self.ratio is None:
            return None
        return self.ratio * 100


def run_command(cmd: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
    """Run *cmd* without a shell, capturing stdout and stderr as text.

    Args:
        cmd: Argument vector; the first element is the executable.
        check: When true (the default), a non-zero exit status raises.

    Returns:
        The completed process, including its captured output.

    Raises:
        RuntimeError: If ``check`` is true and the command exits non-zero.
            The message carries the command line, exit status and stripped stderr.
    """
    completed = subprocess.run(cmd, capture_output=True, text=True)
    if completed.returncode == 0 or not check:
        return completed

    command_line = " ".join(cmd)
    stderr_text = completed.stderr.strip()
    raise RuntimeError(
        f"Command {command_line} failed ({completed.returncode}): {stderr_text}"
    )


def load_sessions() -> List[Dict[str, Any]]:
    """Fetch the session list from ``openclaw sessions --all-agents --json``.

    Returns:
        The value of the payload's ``"sessions"`` key, or an empty list when
        that key is missing or null.

    Raises:
        RuntimeError: Propagated from ``run_command`` when the CLI exits non-zero.
        SystemExit: If stdout is not valid JSON, the top-level value is not a
            JSON object, or ``"sessions"`` is present but not a list.
    """
    proc = run_command(["openclaw", "sessions", "--all-agents", "--json"])
    try:
        payload = json.loads(proc.stdout)
    except json.JSONDecodeError as exc:  # pragma: no cover
        raise SystemExit(f"Failed to parse sessions JSON: {exc}") from exc

    # json.loads can legitimately return a list, string, number or None; fail
    # with a clear message instead of an AttributeError on `.get`.
    if not isinstance(payload, dict):
        raise SystemExit(
            f"Unexpected sessions payload: expected a JSON object, got {type(payload).__name__}"
        )

    sessions = payload.get("sessions")
    if sessions is None:
        # Covers both a missing key and an explicit `"sessions": null`.
        return []
    if not isinstance(sessions, list):
        raise SystemExit(
            f"Unexpected sessions payload: 'sessions' should be a list, got {type(sessions).__name__}"
        )
    return sessions


def classify_lane(session_key: str) -> str:
    """Map a session key to its lane name by prefix.

    Returns one of "telegram", "cron", "worker", "ops-maintainer", "main",
    or "other" when no known prefix matches.
    """
    # Order matters: the specific "agent:main:*" prefixes must be tried before
    # the generic "agent:main" fallback.
    prefix_table = (
        ("agent:main:telegram", "telegram"),
        ("agent:main:cron", "cron"),
        ("agent:worker", "worker"),
        ("agent:ops-maintainer", "ops-maintainer"),
        ("agent:main", "main"),
    )
    for prefix, lane in prefix_table:
        if session_key.startswith(prefix):
            return lane
    return "other"


def _numeric_or_none(value: Any) -> Any:
    """Return *value* if it is an int or float (not bool); otherwise None."""
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return value
    return None


def evaluate_usage(session: Dict[str, Any]) -> SessionUsage:
    """Classify one session record and decide which action its usage calls for.

    Args:
        session: One entry from the sessions payload. The fields read are
            ``key``, ``totalTokens`` and ``contextTokens``.

    Returns:
        A ``SessionUsage`` whose ``action`` is:
        - "info" when a token count is missing, zero, non-numeric, or the
          context size is not positive (no ratio can be computed);
        - "none" when the lane has no thresholds or no threshold is reached;
        - otherwise the highest threshold the ratio meets ("cron-prune" for
          the cron lane's "prune" threshold).
    """
    # A null or non-string key would crash classify_lane(); fall back to the
    # same placeholder used when the key is absent so one malformed record
    # does not abort the whole audit.
    raw_key = session.get("key", "unknown")
    key = raw_key if isinstance(raw_key, str) else "unknown"
    lane = classify_lane(key)

    # Non-numeric counts (e.g. strings) are treated like missing counts rather
    # than raising TypeError on comparison/division below.
    tokens_used = _numeric_or_none(session.get("totalTokens"))
    context_tokens = _numeric_or_none(session.get("contextTokens"))

    if not tokens_used or not context_tokens or context_tokens <= 0:
        return SessionUsage(key, lane, None, tokens_used, context_tokens, "info")

    ratio = tokens_used / context_tokens
    thresholds = LANE_THRESHOLDS.get(lane)

    if not thresholds:
        return SessionUsage(key, lane, ratio, tokens_used, context_tokens, "none")

    # (threshold name, resulting action), checked from most to least severe.
    if lane == "cron":
        ladder = (
            ("prune", "cron-prune"),
            ("summarize", "summarize"),
            ("warn", "warn"),
        )
    else:
        ladder = (
            ("required", "required"),
            ("rollover", "rollover"),
            ("summarize", "summarize"),
            ("warn", "warn"),
        )
    action = next(
        (result for name, result in ladder if ratio >= thresholds[name]),
        "none",
    )

    return SessionUsage(key, lane, ratio, tokens_used, context_tokens, action)


def send_telegram(message: str) -> None:
    """Send *message* via ``openclaw message send`` to the configured channel and target.

    Raises:
        RuntimeError: Propagated from ``run_command`` when the CLI exits non-zero.
    """
    argv = ["openclaw", "message", "send"]
    argv += ["--channel", TELEGRAM_CHANNEL]
    argv += ["--target", TELEGRAM_TARGET]
    argv += ["--message", message]
    run_command(argv)


def update_heartbeat_log(entries: Iterable[SessionUsage], cleanup_triggered: bool) -> None:
    """Record this audit under ``lastSessionAudit`` in the heartbeat state file.

    Existing top-level keys in the file are preserved. If the file is missing,
    unreadable as UTF-8, not valid JSON, or does not contain a JSON object, the
    state starts from an empty dict.

    Args:
        entries: Evaluated sessions to record (key, lane, rounded usage
            percentage and action for each).
        cleanup_triggered: Whether this run triggered the sessions cleanup.
    """
    data: Dict[str, Any] = {}
    if HEARTBEAT_FILE.exists():
        try:
            loaded = json.loads(HEARTBEAT_FILE.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            loaded = None
        # Valid JSON is not necessarily an object (e.g. `[]` or `null`); only a
        # dict can carry the keys written below.
        if isinstance(loaded, dict):
            data = loaded

    timestamp = datetime.now(tz=TIMEZONE).isoformat()
    data.setdefault("lastChecks", {})
    data["lastSessionAudit"] = {
        "timestamp": timestamp,
        "cleanupTriggered": cleanup_triggered,
        "sessions": [
            {
                "key": usage.key,
                "lane": usage.lane,
                "usagePct": None if usage.pct is None else round(usage.pct, 2),
                "action": usage.action,
            }
            for usage in entries
        ],
    }
    HEARTBEAT_FILE.parent.mkdir(parents=True, exist_ok=True)
    HEARTBEAT_FILE.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8")


def format_alert(usage: SessionUsage) -> str:
    """Build the one-line alert text for a session that crossed a threshold.

    An unknown (None) percentage is rendered as 0.0, and any action without a
    dedicated phrase falls back to "threshold met".
    """
    phrases = {
        "summarize": "summarize + prep rollover",
        "rollover": "rollover recommended",
        "required": "rollover required",
    }
    percent = usage.pct
    if not percent:
        percent = 0.0
    tail = phrases.get(usage.action, "threshold met")
    return f"Session {usage.key} ({usage.lane}) at {percent:.1f}% — {tail}."


def format_summary(entries: Iterable[SessionUsage]) -> str:
    """Summarize an audit as a total plus per-action counts, sorted by action name."""
    tally: Dict[str, int] = {}
    for item in entries:
        tally.setdefault(item.action, 0)
        tally[item.action] += 1

    if not tally:
        return "Session audit complete — no sessions found"

    breakdown = ", ".join(f"{action}:{tally[action]}" for action in sorted(tally))
    grand_total = sum(tally.values())
    return f"Session audit complete — total {grand_total}; {breakdown}"


def main() -> None:
    """Run one audit: evaluate sessions, alert, clean up if needed, record, summarize.

    Any exception from a step propagates immediately, so later steps (including
    the heartbeat write) do not run for a failed audit.
    """
    usages = list(map(evaluate_usage, load_sessions()))

    # "warn" is recorded in the heartbeat log but does not produce a message.
    alert_actions = ("summarize", "rollover", "required")
    for usage in usages:
        if usage.action in alert_actions:
            send_telegram(format_alert(usage))

    cleanup_needed = "cron-prune" in {usage.action for usage in usages}
    if cleanup_needed:
        run_command(["openclaw", "sessions", "cleanup"])
        send_telegram("Cron sessions exceeded 60% context — openclaw sessions cleanup executed.")

    update_heartbeat_log(usages, cleanup_needed)
    print(format_summary(usages))


if __name__ == "__main__":
    # Script entry point: report any failure on stderr and exit with status 1.
    try:
        main()
    except Exception as err:  # pragma: no cover
        sys.stderr.write(f"Session watchdog failed: {err}\n")
        raise SystemExit(1)
