#!/usr/bin/env python3
"""cost-ceiling-actuator.py — THR-109 Inc 3c: B2 ACTUATING cost ceiling.

The RFC-002 §1 circuit-breaker only ALERTS on 5h-bucket cost (it has no kill path).
This ADDS the actuation: on a hard %/rate breach of the 5-hour agentic meter, it
TERMINATES the offending running session(s) on THIS host — preserving the session key
per the Incident Protocol (THR-107: status->done, abortedLastRun=true, endedAt=now) —
and fires a fail-loud alert.

SWAPPABLE BY DESIGN: detection (read_usage / evaluate_breach) and action
(terminate_running_sessions) are separate. When a native
`agents.guardrails.maxCostPerSession` ships (tracked as THR-110 upgrade), delete this
script + its plist; nothing else depends on it.

Data source: oauth-usage-history.jsonl (the same 5h-meter §1 collects; day_pct =
the 5h bucket's usedPercent per oauth_usage_collect.py:92).

SAFETY: default DRY-RUN — reports what it WOULD do, no termination, no real alert.
Pass --apply to actually terminate + alert. Per-host only (terminates sessions that
are `running` in the local store); a shared-bucket drain from another consumer is
reported but cannot be terminated from here.
"""

import argparse
import json
import os
import shutil
import subprocess
import sys
import time
from datetime import datetime, timezone

# --- default locations (each is the default of the matching CLI flag in main()) ---
# 5h-meter history, one JSON object per line; default for --usage-jsonl.
USAGE_JSONL   = "/Users/openclaw/.openclaw/workspace/state/oauth-usage-history.jsonl"
# Session store whose `running` entries are terminated; default for --session-store.
SESSION_STORE = "/Users/openclaw/.openclaw/agents/main/sessions/sessions.json"
# Last-fire marker consulted by within_cooldown(); default for --sentinel.
SENTINEL      = "/Users/openclaw/.openclaw/workspace/state/cost-actuator-last-fire.flag"
# Append-only log file; default for --log-path.
LOG_PATH      = "/tmp/openclaw/cost-ceiling-actuator.log"
# CLI used to deliver alerts: resolved from PATH, else the hard-coded Homebrew path.
OPENCLAW_BIN  = shutil.which("openclaw") or "/opt/homebrew/bin/openclaw"
# Value passed as --target to `openclaw message send --channel telegram`.
TELEGRAM_TARGET = "8032472383"

# --- thresholds (overridable for tests) ---
HARD_CEILING_PCT = 95.0   # day_pct >= this -> hard breach (near-exhaustion; well past §1 CRITICAL 80)
RATE_WINDOW      = 4       # evaluable rows (~20 min @ 5-min cadence) for the slope check
RATE_DELTA_PCT   = 30.0    # >= this rise across the window -> rate breach. THR-108's steepest
                           # 20-min rise was ~34-38pts, so 30 catches that signature; secondary
                           # to the hard ceiling (the §1 breaker already ALERTS on slope).
RATE_FLOOR_PCT   = 70.0    # ...only when already >= this consumed (near-top steep climb only;
                           # avoids terminating a mid-level burst that may still level off).
STALE_SECS       = 1800    # latest row older than this -> do not act (cannot trust)
COOLDOWN_SECS    = 3600    # do not re-fire within this window

# Path log() actually writes to; run() rebinds it from --log-path.
_LOG = LOG_PATH

def log(msg):
    """Append one UTC-timestamped line to the active log file.

    The destination is the module-level ``_LOG`` path. Its parent directory is
    created on demand.

    Args:
        msg: Text to record; written as ``[<YYYY-MM-DD HH:MM:SS UTC>] <msg>``.

    Raises:
        OSError: If the parent directory cannot be created or the file cannot
            be opened for appending.
    """
    parent = os.path.dirname(_LOG)
    # A bare filename (e.g. ``--log-path actuator.log``) has no directory part,
    # and os.makedirs("") raises FileNotFoundError; only create a real parent.
    if parent:
        os.makedirs(parent, exist_ok=True)
    ts = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
    # Explicit encoding so session keys / reasons log the same under any locale.
    with open(_LOG, "a", encoding="utf-8") as f:
        f.write(f"[{ts}] {msg}\n")

def tail_jsonl(path, n):
    """Return the decodable JSON values among the last ``n`` lines of a file.

    The file is read backwards in 64 KiB blocks until more than ``n`` newlines
    have been collected (or the start of the file is reached), so the amount
    read depends on ``n`` rather than on the file size.

    Args:
        path: Path of the JSONL file.
        n: Number of trailing lines to consider.

    Returns:
        A list of decoded JSON values in file order. Blank lines and lines
        that are not valid JSON (including lines that are not valid UTF-8) are
        skipped, so the list may be shorter than ``n``. Returns ``[]`` when the
        file does not exist or ``n`` is not positive.
    """
    if n <= 0:
        # ``lines[-0:]`` would select *every* line read, not zero lines.
        return []
    block = 64 * 1024
    try:
        with open(path, "rb") as f:
            f.seek(0, 2)
            size = f.tell()
            data = b""
            # Need strictly more than n newlines so the first kept line is
            # guaranteed to be complete rather than cut at a block boundary.
            while size > 0 and data.count(b"\n") <= n:
                rd = min(block, size)
                size -= rd
                f.seek(size)
                data = f.read(rd) + data
    except FileNotFoundError:
        # Open-and-handle instead of exists()-then-open: no window in which the
        # file can vanish between the check and the read.
        return []
    out = []
    for ln in data.splitlines()[-n:]:
        ln = ln.strip()
        if not ln:
            continue
        try:
            out.append(json.loads(ln))
        except ValueError:
            # ValueError covers json.JSONDecodeError *and* the UnicodeDecodeError
            # json.loads raises for bytes that are not valid UTF-8; one corrupt
            # or half-written line must not take the whole run down.
            continue
    return out

def parse_ts(s):
    """Parse an ISO-8601 timestamp into a timezone-aware ``datetime``.

    A trailing ``Z`` is rewritten to ``+00:00`` before parsing, and a timestamp
    without any offset is stamped as UTC.

    Args:
        s: Timestamp text (any falsy value is treated as "no timestamp").

    Returns:
        An aware ``datetime``, or ``None`` when ``s`` is falsy, is not a
        string, or is not parseable by ``datetime.fromisoformat``.
    """
    if not s:
        return None
    text = s
    if isinstance(text, str) and text.endswith("Z"):
        text = f"{text[:-1]}+00:00"
    try:
        parsed = datetime.fromisoformat(text)
    except (ValueError, TypeError):
        return None
    if parsed.tzinfo is None:
        # No offset in the input: interpret it as UTC.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed

# --- detection (swap point #1) -------------------------------------------------

def read_usage(path, window):
    """Load the recent usage rows that are usable for breach evaluation.

    Reads the last ``max(window, 12)`` lines of the usage JSONL and keeps the
    rows that can be evaluated.

    Args:
        path: Path of the usage-history JSONL file.
        window: Requested slope window, in rows; at least 12 lines are always
            read so that filtered-out rows do not starve the window.

    Returns:
        A list of dict rows, oldest first, each of which is not flagged
        ``"error": true`` and has a real-number ``day_pct``. May be empty.
    """
    rows = tail_jsonl(path, max(window, 12))
    return [
        r for r in rows
        # A line holding valid JSON that is not an object (``null``, ``[]``, a
        # bare number) has no ``.get`` and would crash the run.
        if isinstance(r, dict)
        and r.get("error") is not True
        and isinstance(r.get("day_pct"), (int, float))
        # bool is a subclass of int; ``"day_pct": true`` is not a percentage.
        and not isinstance(r.get("day_pct"), bool)
    ]

def evaluate_breach(evaluable, now, args):
    """Decide whether the newest usage row constitutes a breach.

    Two checks run in order on the newest row: the hard percentage ceiling,
    then the rise across the last ``args.rate_window`` rows (which only counts
    when the newest value is also at or above ``args.rate_floor_pct``).

    Args:
        evaluable: Usage rows, oldest first, each with a numeric ``day_pct``.
        now: Aware ``datetime`` used to judge the newest row's age.
        args: Object exposing ``stale_secs``, ``hard_ceiling_pct``,
            ``rate_window``, ``rate_delta_pct`` and ``rate_floor_pct``.

    Returns:
        ``(breached, reason, latest_row)``. ``reason`` is ``None`` unless
        breached; ``latest_row`` is ``None`` only when ``evaluable`` is empty.
        A newest row whose timestamp is missing, unparseable or older than
        ``args.stale_secs`` is logged and never treated as a breach.
    """
    if not evaluable:
        return (False, None, None)

    latest = evaluable[-1]
    raw_ts = latest.get("timestamp")
    seen_at = parse_ts(raw_ts)
    fresh = seen_at is not None and (now - seen_at).total_seconds() <= args.stale_secs
    if not fresh:
        log(f"SKIP stale/unparseable latest row ts={raw_ts}")
        return (False, None, latest)

    current = float(latest["day_pct"])

    # Check 1: absolute ceiling.
    if current >= args.hard_ceiling_pct:
        return (True, f"hard_ceiling day_pct={current:.0f}% >= {args.hard_ceiling_pct:.0f}%", latest)

    # Check 2: steep climb, first-to-last across the trailing window.
    recent = evaluable[-args.rate_window:]
    if len(recent) < 2:
        return (False, None, latest)
    climb = float(recent[-1]["day_pct"]) - float(recent[0]["day_pct"])
    if climb >= args.rate_delta_pct and current >= args.rate_floor_pct:
        return (True, f"rate day_pct +{climb:.0f}% over {len(recent)} rows (latest {current:.0f}% >= floor {args.rate_floor_pct:.0f}%)", latest)
    return (False, None, latest)

# --- action (swap point #2) ----------------------------------------------------

def terminate_running_sessions(store_path, apply, now_ms):
    """Mark every ``running`` session in the local store as terminated.

    Session entries are kept under their original keys (identity preserved);
    only ``status``, ``abortedLastRun`` and ``endedAt`` are rewritten.

    Args:
        store_path: Path of the JSON session store (an object keyed by
            session key).
        apply: When false, nothing is written; the would-be targets are only
            reported. When true, the store is backed up next to itself and
            then replaced atomically with the updated content.
        now_ms: Value stored in ``endedAt`` for each terminated session.

    Returns:
        A list of ``(session_key, sessionId[:18], previous_status)`` tuples for
        the sessions acted on (or that would be acted on), with
        ``agent:main:main`` first and the rest in key order. Empty when the
        store is missing, is not a JSON object, or has no running session.

    Raises:
        OSError, ValueError: If the store cannot be read, parsed or rewritten.
    """
    if not os.path.exists(store_path):
        log(f"SESSION_STORE missing: {store_path}")
        return []
    # Context manager: the original ``json.load(open(...))`` leaked the handle.
    with open(store_path) as f:
        store = json.load(f)
    if not isinstance(store, dict):
        # Without this, ``.items()`` raises and the caller never gets to send
        # the breach alert. Report it and let the alert go out.
        log(f"SESSION_STORE malformed (expected object, got {type(store).__name__}): {store_path}")
        return []

    targets = [key for key, ent in store.items()
               if isinstance(ent, dict) and ent.get("status") == "running"]
    # main lane first (the THR-108 vector), then others
    targets.sort(key=lambda k: (0 if k == "agent:main:main" else 1, k))
    acted = [(key, str(store[key].get("sessionId", ""))[:18], store[key].get("status"))
             for key in targets]
    if not (targets and apply):
        return acted

    # Keep a timestamped copy of the untouched store before mutating anything.
    bak = f"{store_path}.bak-cost-actuator-{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ')}"
    shutil.copy2(store_path, bak)
    log(f"SESSION_STORE backed up -> {bak}")

    for key in targets:
        ent = store[key]
        ent["status"] = "done"          # Incident Protocol: preserve key + identity
        ent["abortedLastRun"] = True
        ent["endedAt"] = now_ms

    # Write-then-rename so readers never observe a half-written store.
    tmp = store_path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(store, f, indent=2)
    os.replace(tmp, store_path)
    log(f"TERMINATED {len(acted)} running session(s): {[a[0] for a in acted]}")
    return acted

def send_alert(msg, apply, dry_run_tg):
    """Deliver the breach alert over Telegram via the ``openclaw`` CLI.

    Delivery is best-effort: a failure is logged, never raised, so it cannot
    prevent the caller from recording the cooldown sentinel.

    Args:
        msg: Alert text.
        apply: When false (dry-run) nothing is sent; the message is only logged.
        dry_run_tg: When true the send is suppressed even in apply mode.

    Returns:
        None.
    """
    if not apply or dry_run_tg:
        log(f"ALERT (suppressed: apply={apply} dry_run_tg={dry_run_tg}): {msg}")
        return
    try:
        r = subprocess.run([OPENCLAW_BIN, "message", "send", "--channel", "telegram",
                            "--target", TELEGRAM_TARGET, "--message", msg],
                           capture_output=True, text=True, timeout=20)
    except Exception as e:
        # Missing binary, timeout, ... — best-effort, so record and move on.
        log(f"ALERT EXCEPTION {type(e).__name__}: {e}")
        return
    if r.returncode == 0:
        log(f"ALERT delivered rc={r.returncode}")
    else:
        # subprocess.run does not raise on a non-zero exit without check=True;
        # logging that as "delivered" would hide a failed alert.
        detail = (r.stderr or r.stdout or "").strip()[:300]
        log(f"ALERT FAILED rc={r.returncode} output={detail!r}")

# --- cooldown ------------------------------------------------------------------

def within_cooldown(sentinel, now, secs):
    """Report whether the last recorded fire is still inside the cooldown.

    Args:
        sentinel: Path of the JSON sentinel file, expected to hold an object
            with an ISO-8601 ``ts`` field.
        now: Timezone-aware ``datetime`` to compare against (the stored
            timestamp is parsed as an aware value).
        secs: Cooldown length in seconds.

    Returns:
        True only when the sentinel holds a parseable ``ts`` less than ``secs``
        seconds before ``now``. A missing, unreadable, non-JSON or non-object
        sentinel yields False, i.e. the actuator is allowed to fire.
    """
    if not os.path.exists(sentinel):
        return False
    try:
        # Context manager: the original ``json.load(open(...))`` leaked the handle.
        with open(sentinel) as f:
            d = json.load(f)
    except (OSError, ValueError):
        # ValueError covers JSONDecodeError and undecodable bytes alike.
        return False
    if not isinstance(d, dict):
        # e.g. ``[]`` or ``null``: valid JSON but no ``ts`` to read.
        return False
    t = parse_ts(d.get("ts"))
    return t is not None and (now - t).total_seconds() < secs

def write_sentinel(sentinel, payload):
    """Atomically write ``payload`` as JSON to the sentinel file.

    The content goes to ``<sentinel>.tmp`` first and is then renamed over the
    target, so a reader never sees a partially written sentinel.

    Args:
        sentinel: Destination path; its parent directory is created if needed.
        payload: JSON-serialisable object to store.

    Raises:
        OSError: If the directory or file cannot be written.
        TypeError: If ``payload`` is not JSON-serialisable.
    """
    parent = os.path.dirname(sentinel)
    # A bare filename has no directory part and os.makedirs("") raises
    # FileNotFoundError — which would leave the cooldown unrecorded *after*
    # sessions were already terminated.
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp = sentinel + ".tmp"
    with open(tmp, "w") as f:
        json.dump(payload, f, indent=2)
    os.replace(tmp, sentinel)

# --- main ----------------------------------------------------------------------

def run(args):
    """Execute one actuator pass: detect, (maybe) terminate, alert, record.

    Args:
        args: Parsed CLI namespace (paths, thresholds, ``apply``,
            ``telegram_dry_run``).

    Returns:
        Process exit code: ``10`` when running in apply mode and at least one
        session was terminated, otherwise ``0`` (no breach, breach suppressed
        by cooldown, dry-run, or nothing to terminate).
    """
    global _LOG
    _LOG = args.log_path  # route log() to the requested file for this run
    now = datetime.now(timezone.utc)
    now_ms = int(time.time() * 1000)

    # Detection.
    rows = read_usage(args.usage_jsonl, args.rate_window)
    breached, reason, latest = evaluate_breach(rows, now, args)
    if not breached:
        day = "n/a" if not latest else latest.get("day_pct")
        log(f"OK no breach (latest day_pct={day})")
        print(f"cost-actuator: no breach (day_pct={day})")
        return 0

    # A recent fire suppresses both termination and alerting.
    if within_cooldown(args.sentinel, now, args.cooldown_secs):
        log(f"SUPPRESS breach within cooldown; reason={reason}")
        print("cost-actuator: breach but within cooldown")
        return 0

    # Action.
    acted = terminate_running_sessions(args.session_store, args.apply, now_ms)
    keys = [k for k, _, _ in acted]
    mode = "APPLY" if args.apply else "DRY-RUN"
    if not acted:
        body = (f"[cost-ceiling] {mode} BREACH: {reason}. No running session on this host to terminate "
                f"(shared-bucket drain may be another consumer). owner:Dorian")
    else:
        body = (f"[cost-ceiling] {mode} BREACH: {reason}. Terminated {len(acted)} running session(s): "
                f"{', '.join(keys)} (keys preserved). owner:Dorian")
    log(f"FIRE {mode} reason={reason} acted={acted}")
    send_alert(body, args.apply, args.telegram_dry_run)

    # Only a real fire starts the cooldown; dry-runs leave no sentinel.
    if args.apply:
        write_sentinel(args.sentinel, {"ts": now.isoformat(), "reason": reason,
                                       "acted": keys})
    print(f"cost-actuator: {mode} fired — {reason}; acted={keys}")
    if args.apply and acted:
        return 10
    return 0

def main(argv=None):
    """Parse command-line options and run one actuator pass.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        The exit code produced by ``run``.
    """
    parser = argparse.ArgumentParser()

    # File locations, each defaulting to its module-level constant.
    path_options = (
        ("--usage-jsonl", USAGE_JSONL),
        ("--session-store", SESSION_STORE),
        ("--sentinel", SENTINEL),
        ("--log-path", LOG_PATH),
    )
    for flag, default in path_options:
        parser.add_argument(flag, default=default)

    # Mode switches.
    parser.add_argument("--apply", action="store_true", help="actually terminate + alert (default: dry-run)")
    parser.add_argument("--telegram-dry-run", action="store_true")

    # Tunable thresholds: (flag, converter, default).
    threshold_options = (
        ("--hard-ceiling-pct", float, HARD_CEILING_PCT),
        ("--rate-window", int, RATE_WINDOW),
        ("--rate-delta-pct", float, RATE_DELTA_PCT),
        ("--rate-floor-pct", float, RATE_FLOOR_PCT),
        ("--stale-secs", int, STALE_SECS),
        ("--cooldown-secs", int, COOLDOWN_SECS),
    )
    for flag, convert, default in threshold_options:
        parser.add_argument(flag, type=convert, default=default)

    args = parser.parse_args(argv)
    return run(args)

if __name__ == "__main__":
    # Top-level boundary: any unexpected error becomes exit code 2.
    # SystemExit raised by sys.exit(main()) is not an Exception subclass and
    # passes straight through with main()'s own code.
    try:
        sys.exit(main())
    except Exception as e:
        fatal = f"FATAL {type(e).__name__}: {e}"
        # Report on stderr as well: the failure may be the log file itself
        # (bad --log-path, unwritable directory), in which case a log-only
        # report would vanish without a trace.
        print(f"cost-actuator: {fatal}", file=sys.stderr)
        try:
            log(fatal)
        except Exception as log_err:
            print(f"cost-actuator: could not write log: {type(log_err).__name__}: {log_err}",
                  file=sys.stderr)
        sys.exit(2)
