#!/usr/bin/env python3
"""cost-ceiling-actuator.py — THR-109 Inc 3c: B2 ACTUATING cost ceiling.

The RFC-002 §1 circuit-breaker only ALERTS on 5h-bucket cost (it has no kill path).
This ADDS the actuation: on a hard %/rate breach of the 5-hour agentic meter, it
TERMINATES the offending running session(s) on THIS host — preserving the session key
per the Incident Protocol (THR-107: status->done, abortedLastRun=true, endedAt=now) —
and fires a fail-loud alert.

SWAPPABLE BY DESIGN: detection (read_usage / evaluate_breach) and action
(terminate_running_sessions) are separate. When a native
`agents.guardrails.maxCostPerSession` ships (tracked as THR-110 upgrade), delete this
script + its plist; nothing else depends on it.

Data source: oauth-usage-history.jsonl (the same 5h-meter §1 collects; day_pct =
the 5h bucket's usedPercent per oauth_usage_collect.py:92).

SAFETY: default DRY-RUN — reports what it WOULD do: no termination and no breach alert.
(The B5/THR-112 monitor-blind health-ping is the one exception: it is sent even in
dry-run; --telegram-dry-run suppresses every send.) Pass --apply to actually terminate
+ alert. Per-host only (terminates sessions that are `running` in the local store); a
shared-bucket drain from another consumer is reported but cannot be terminated from here.
"""

import argparse
import json
import os
import shutil
import subprocess
import sys
import time
from datetime import datetime, timezone

# Default host paths. USAGE_JSONL / SESSION_STORE / SENTINEL / LOG_PATH (and BLINDNESS_STATE
# below) are only the defaults for the matching --flag in main().
# 5h-meter history (see module docstring); only ever read by this script.
USAGE_JSONL   = "/Users/openclaw/.openclaw/workspace/state/oauth-usage-history.jsonl"
# Local session store; rewritten in place (after a timestamped backup) on --apply.
SESSION_STORE = "/Users/openclaw/.openclaw/agents/main/sessions/sessions.json"
# JSON {"ts", "reason", "acted"} written after each --apply fire; drives COOLDOWN_SECS.
SENTINEL      = "/Users/openclaw/.openclaw/workspace/state/cost-actuator-last-fire.flag"
LOG_PATH      = "/tmp/openclaw/cost-ceiling-actuator.log"
# Resolved once at import; falls back to the Homebrew path when `openclaw` is not on PATH.
OPENCLAW_BIN  = shutil.which("openclaw") or "/opt/homebrew/bin/openclaw"
# --target passed to `openclaw message send --channel telegram`.
TELEGRAM_TARGET = "8032472383"

# --- thresholds (overridable for tests) ---
# Each is the default for the matching --flag in main() (e.g. --hard-ceiling-pct).
HARD_CEILING_PCT = 95.0   # day_pct >= this -> hard breach (near-exhaustion; well past §1 CRITICAL 80)
RATE_WINDOW      = 4       # evaluable rows for the slope check (rise = last row - first row). 4 rows
                           # span 3 intervals (~15 min @ 5-min cadence) — longer when error rows or
                           # missed polls fall inside, since the window counts rows, not time.
                           # NOTE(review): THR-108's signature below is a 20-min rise; confirm 4 (not
                           # 5) rows is the intended window.
RATE_DELTA_PCT   = 30.0    # >= this rise across the window -> rate breach. THR-108's steepest
                           # 20-min rise was ~34-38pts, so 30 catches that signature; secondary
                           # to the hard ceiling (the §1 breaker already ALERTS on slope).
RATE_FLOOR_PCT   = 70.0    # ...only when already >= this consumed (near-top steep climb only;
                           # avoids terminating a mid-level burst that may still level off).
STALE_SECS       = 1800    # latest row older than this -> do not act (cannot trust); also the B5 "blind" test
COOLDOWN_SECS    = 3600    # do not re-fire within this window (measured from the last --apply fire)

# B5 (THR-112) — consecutive-blindness alert. Fires regardless of --apply (force=True):
# a blind monitor must never sit silent, INCL. during shadow (the THR-112 81-min blind
# gap sat silent precisely because shadow is quiet). Independent of COOLDOWN_SECS. The
# blindness alert is a health ping only — it never terminates (actuation still needs --apply).
# NOTE: N_CONSECUTIVE_BLIND and BLIND_REFIRE_EVERY have no CLI flag. A read is "blind" only
# once the newest evaluable row is > STALE_SECS old (or there is none), so with the defaults
# and a ~5-min cadence the first ping lands roughly STALE_SECS + (N_CONSECUTIVE_BLIND - 1) * 5
# min (~40 min) after the last fresh reading.
N_CONSECUTIVE_BLIND = 3    # blind reads in a row before first alert (~15 min @5-min)
BLIND_REFIRE_EVERY  = 3    # re-fire every N further consecutive blind reads
BLINDNESS_STATE     = "/Users/openclaw/.openclaw/workspace/state/cost-actuator-blindness-streak.json"

# Active log file used by log(); run() rebinds it from --log-path.
_LOG = LOG_PATH

def log(msg):
    """Append one UTC-timestamped line to the active log file (module global ``_LOG``).

    ``run()`` rebinds ``_LOG`` from ``--log-path``. The parent directory is created on
    demand; a bare filename (no directory part) is written relative to the CWD. The file
    is written as UTF-8 so alert text containing non-ASCII (e.g. "—") can always be logged,
    whatever the process locale is.
    """
    parent = os.path.dirname(_LOG)
    if parent:  # os.makedirs("") raises FileNotFoundError for a bare filename
        os.makedirs(parent, exist_ok=True)
    ts = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
    with open(_LOG, "a", encoding="utf-8") as f:
        f.write(f"[{ts}] {msg}\n")

def tail_jsonl(path, n):
    """Return the JSON values parsed from the last ``n`` lines of ``path``.

    Reads backwards in 64 KiB blocks, so only the tail of a large file is loaded. Blank
    lines and lines that fail to parse are skipped, including lines that are not valid
    UTF-8 (e.g. a corrupted line). As a result, fewer than ``n`` values may come back.
    A missing file, or n <= 0, yields [].
    """
    if n <= 0:
        return []  # data.splitlines()[-0:] would otherwise return EVERY line read
    chunk = 64 * 1024
    try:
        f = open(path, "rb")
    except FileNotFoundError:
        return []
    with f:
        pos = f.seek(0, os.SEEK_END)
        data = b""
        # Holding > n newlines guarantees the last n lines are complete (only the first
        # fragment can be a partial line), or we stop once the whole file has been read.
        while pos > 0 and data.count(b"\n") <= n:
            step = min(chunk, pos)
            pos -= step
            f.seek(pos)
            data = f.read(step) + data
    out = []
    for raw in data.splitlines()[-n:]:
        raw = raw.strip()
        if not raw:
            continue
        try:
            out.append(json.loads(raw))
        except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
            continue
    return out

def parse_ts(s):
    """Parse an ISO-8601 timestamp into a timezone-aware datetime.

    A trailing "Z" is rewritten to "+00:00" (datetime.fromisoformat only accepts "Z"
    from Python 3.11). A naive result is assumed to be UTC. Returns None for falsy input
    or anything fromisoformat rejects (including non-string values).
    """
    if not s:
        return None
    text = s[:-1] + "+00:00" if isinstance(s, str) and s.endswith("Z") else s
    try:
        parsed = datetime.fromisoformat(text)
    except (ValueError, TypeError):
        return None
    if parsed.tzinfo is None:
        return parsed.replace(tzinfo=timezone.utc)
    return parsed

# --- detection (swap point #1) -------------------------------------------------

def read_usage(path, window):
    """Return the recent *evaluable* meter rows from the usage JSONL, in file order.

    Reads the last max(window, 12) lines; the floor of 12 presumably gives headroom so
    skipped rows do not starve the slope window. Only JSON objects are kept, and only
    those whose "error" is not True and whose "day_pct" is a real number. Bools are
    excluded, because JSON ``true`` would otherwise pass the int check and read as 1%.
    """
    rows = tail_jsonl(path, max(window, 12))
    return [
        r for r in rows
        if isinstance(r, dict)  # a stray list/number row would crash on .get()
        and r.get("error") is not True
        and isinstance(r.get("day_pct"), (int, float))
        and not isinstance(r.get("day_pct"), bool)
    ]

def evaluate_breach(evaluable, now, args):
    """Decide whether the newest meter reading constitutes a breach.

    evaluable: rows from read_usage() in file order (last = latest), each with a numeric
        "day_pct".
    now: aware UTC datetime.
    args: CLI namespace (stale_secs, hard_ceiling_pct, rate_window, rate_delta_pct,
        rate_floor_pct).

    Returns (breached: bool, reason: str|None, latest_row|None). A stale or unparseable
    latest row never breaches; it is logged and returned as (False, None, latest). The
    hard ceiling is checked before the rate/slope rule.
    """
    if not evaluable:
        return (False, None, None)
    latest = evaluable[-1]
    latest_ts = parse_ts(latest.get("timestamp"))
    if latest_ts is None or (now - latest_ts).total_seconds() > args.stale_secs:
        log(f"SKIP stale/unparseable latest row ts={latest.get('timestamp')}")
        return (False, None, latest)
    day = float(latest["day_pct"])
    # hard % ceiling
    if day >= args.hard_ceiling_pct:
        return (True, f"hard_ceiling day_pct={day:.0f}% >= {args.hard_ceiling_pct:.0f}%", latest)
    # rate / slope: endpoint rise (last - first) across the last rate_window evaluable rows.
    # rate_window < 1 must not reach the slice: evaluable[-0:] is the WHOLE tail and a negative
    # value slices from the front, silently widening the window instead of disabling the check.
    # NOTE(review): the window is counted in rows, not time, so skipped error rows stretch it
    # past ~rate_window*5 min — confirm whether a timestamp bound is wanted.
    window = evaluable[-args.rate_window:] if args.rate_window > 0 else []
    if len(window) >= 2:
        rise = float(window[-1]["day_pct"]) - float(window[0]["day_pct"])
        if rise >= args.rate_delta_pct and day >= args.rate_floor_pct:
            return (True, f"rate day_pct +{rise:.0f}% over {len(window)} rows (latest {day:.0f}% >= floor {args.rate_floor_pct:.0f}%)", latest)
    return (False, None, latest)

# --- action (swap point #2) ----------------------------------------------------

def terminate_running_sessions(store_path, apply, now_ms):
    """Terminate every session whose status is "running", preserving key and identity.

    Incident Protocol (THR-107): each entry stays under its original key. Only ``status``
    is set to "done", ``abortedLastRun`` to True and ``endedAt`` to ``now_ms``.
    "agent:main:main" (the THR-108 vector) comes first, then the rest in key order.

    apply=False: nothing is written; the would-act list is still returned.
    apply=True (and something is running): the store is first copied to a timestamped
    ``.bak-cost-actuator-*`` file. It is then rewritten via a temp file + os.replace,
    keeping the original file's permission bits.

    Returns a list of (session_key, sessionId truncated to 18 chars, prev_status).
    Returns [] if the store file is missing.
    Raises ValueError if the store is not a JSON object; JSON/OS errors propagate.
    """
    if not os.path.exists(store_path):
        log(f"SESSION_STORE missing: {store_path}")
        return []
    with open(store_path, encoding="utf-8") as f:
        store = json.load(f)
    if not isinstance(store, dict):
        # fail with a clear message rather than an opaque AttributeError on .items()
        raise ValueError(f"SESSION_STORE is not a JSON object ({type(store).__name__}): {store_path}")
    # main lane first (the THR-108 vector), then the rest in key order
    targets = sorted(
        (key for key, ent in store.items()
         if isinstance(ent, dict) and ent.get("status") == "running"),
        key=lambda k: (0 if k == "agent:main:main" else 1, k),
    )
    acted = [(key, str(store[key].get("sessionId", ""))[:18], store[key].get("status"))
             for key in targets]
    if not (targets and apply):
        return acted  # dry-run, or nothing running: report only, touch nothing
    bak = f"{store_path}.bak-cost-actuator-{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ')}"
    shutil.copy2(store_path, bak)
    log(f"SESSION_STORE backed up -> {bak}")
    for key in targets:
        ent = store[key]
        ent["status"] = "done"          # Incident Protocol: preserve key + identity
        ent["abortedLastRun"] = True
        ent["endedAt"] = now_ms
    # NOTE(review): plain read-modify-write with no lock — if another process writes this
    # store concurrently, one side's update can be lost. Confirm whether a lock exists.
    tmp = store_path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(store, f, indent=2)
        f.flush()
        os.fsync(f.fileno())  # make the new content durable before it replaces the store
    # the temp file was created with umask defaults; keep the store's own permission bits
    shutil.copymode(store_path, tmp)
    os.replace(tmp, store_path)
    log(f"TERMINATED {len(acted)} running session(s): {[a[0] for a in acted]}")
    return acted

def send_alert(msg, apply, dry_run_tg, force=False):
    """Send ``msg`` to Telegram via ``openclaw message send``.

    Returns True only when the command exits 0.

    A real send requires ``apply`` or ``force``. force=True (the B5 blindness health-ping)
    bypasses the --apply gate, so a blind actuator reports even in shadow. ``dry_run_tg``
    always suppresses the actual send. Suppressed alerts are logged and return False.

    Failures are logged (and echoed to stderr for a non-zero exit) and return False; they
    are not raised. Failures include a non-zero exit, a timeout, or a spawn error.
    """
    if (not apply and not force) or dry_run_tg:
        log(f"ALERT (suppressed: apply={apply} force={force} dry_run_tg={dry_run_tg}): {msg}")
        return False
    try:
        r = subprocess.run([OPENCLAW_BIN, "message", "send", "--channel", "telegram",
                            "--target", TELEGRAM_TARGET, "--message", msg],
                           capture_output=True, text=True, timeout=20)
    except Exception as e:  # e.g. missing binary (OSError) or subprocess.TimeoutExpired
        log(f"ALERT EXCEPTION {type(e).__name__}: {e}")
        return False
    if r.returncode != 0:
        # a non-zero exit means the alert was NOT delivered — record why, and say so loudly
        detail = (r.stderr or r.stdout or "").strip()[-500:]
        log(f"ALERT FAILED rc={r.returncode}: {detail}")
        print(f"cost-actuator: ALERT FAILED rc={r.returncode}", file=sys.stderr)
        return False
    log(f"ALERT delivered rc={r.returncode}")
    return True

# --- cooldown ------------------------------------------------------------------

def within_cooldown(sentinel, now, secs):
    """Return True if the fire recorded in ``sentinel`` happened less than ``secs`` ago.

    A missing, unreadable, undecodable or malformed sentinel means "not in cooldown": the
    actuator errs toward firing (fail-loud) rather than staying silent. A sentinel
    timestamp in the future (clock skew or a manual edit) is also ignored. Otherwise it
    would suppress actuation until the wall clock caught up with it.
    """
    try:
        with open(sentinel, encoding="utf-8") as f:
            d = json.load(f)
    except (OSError, ValueError):  # missing/unreadable file, bad JSON, bad UTF-8
        return False
    if not isinstance(d, dict):
        return False
    t = parse_ts(d.get("ts"))
    if t is None:
        return False
    age = (now - t).total_seconds()
    return 0 <= age < secs

def write_sentinel(sentinel, payload):
    """Atomically write ``payload`` as indented JSON to ``sentinel``.

    The JSON is written to ``<sentinel>.tmp``, fsynced, and then os.replace'd over the
    target, so readers see either the old or the new file and never a half-written one.
    The parent directory is created on demand; a bare filename is written relative to
    the CWD. Also used for the B5 blindness-streak state.
    """
    parent = os.path.dirname(sentinel)
    if parent:  # os.makedirs("") raises FileNotFoundError for a bare filename
        os.makedirs(parent, exist_ok=True)
    tmp = sentinel + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, sentinel)

# --- B5 (THR-112): consecutive-blindness tracking ------------------------------

def _is_blind(evaluable, now, args):
    """Blind = no evaluable rows, OR the latest evaluable reading is stale."""
    if not evaluable:
        return True
    latest_ts = parse_ts(evaluable[-1].get("timestamp"))
    return latest_ts is None or (now - latest_ts).total_seconds() > args.stale_secs

def _read_blind_state(path):
    try:
        d = json.load(open(path))
        return {"count": int(d.get("count", 0)), "first": d.get("first"),
                "last_alert": int(d.get("last_alert", 0))}
    except (OSError, json.JSONDecodeError, ValueError, TypeError):
        return {"count": 0, "first": None, "last_alert": 0}

def update_blind_streak(path, blind, now):
    """Advance or reset the persisted consecutive-blind counter and decide on alerting.

    Returns (count, should_alert, first_blind_iso). An alert is due once the streak
    reaches N_CONSECUTIVE_BLIND. After that it is due every BLIND_REFIRE_EVERY further
    blind reads (3, 6, 9, … with the defaults). This is independent of COOLDOWN_SECS.
    A non-blind read resets the streak to 0; the state file is only rewritten if the
    streak was non-zero.
    """
    state = _read_blind_state(path)
    if blind:
        streak = state["count"] + 1
        since = state["first"] or now.isoformat()
        alerted_at = state["last_alert"]
        due = streak >= N_CONSECUTIVE_BLIND and streak - alerted_at >= BLIND_REFIRE_EVERY
        if due:
            alerted_at = streak
        write_sentinel(path, {"count": streak, "first": since, "last_alert": alerted_at})
        return (streak, due, since)
    if state["count"]:
        write_sentinel(path, {"count": 0, "first": None, "last_alert": 0})
    return (0, False, None)

# --- main ----------------------------------------------------------------------

def _report_blindness(evaluable, now, args):
    """B5 (THR-112): advance the consecutive-blind streak and send the health-ping when due.

    Called before, and independently of, the cooldown and --apply gates (force=True), so
    a blind actuator cannot sit silent, including in shadow (the THR-112 81-min blind gap).
    """
    blind = _is_blind(evaluable, now, args)
    bcount, balert, bfirst = update_blind_streak(args.blindness_state, blind, now)
    if not balert:
        return
    mins = bcount * 5  # ~5-min cadence
    log(f"FIRE BLINDNESS-ALERT consecutive={bcount} (~{mins}min) not_cooldown_gated")
    send_alert(f"[cost-ceiling] MONITOR BLIND {bcount} consecutive reads (~{mins} min) — "
               f"no fresh evaluable 5h-meter data since {bfirst}; cannot evaluate breach. "
               f"owner:Dorian", args.apply, args.telegram_dry_run, force=True)


def run(args):
    """Run one actuator pass: read the meter, track blindness, and act on a breach.

    Exit codes:
      0  = no breach, breach suppressed by cooldown, dry-run fire, or --apply with
           nothing running to terminate.
      10 = --apply terminated >= 1 session.
      2  = a breach was detected but termination raised. The alert is still sent, and no
           sentinel is written, so the next pass retries.
    """
    global _LOG
    _LOG = args.log_path
    now = datetime.now(timezone.utc)
    now_ms = int(time.time() * 1000)
    evaluable = read_usage(args.usage_jsonl, args.rate_window)
    breached, reason, latest = evaluate_breach(evaluable, now, args)

    _report_blindness(evaluable, now, args)

    if not breached:
        day = latest.get("day_pct") if latest else "n/a"
        log(f"OK no breach (latest day_pct={day})")
        print(f"cost-actuator: no breach (day_pct={day})")
        return 0
    if within_cooldown(args.sentinel, now, args.cooldown_secs):
        log(f"SUPPRESS breach within cooldown; reason={reason}")
        print("cost-actuator: breach but within cooldown")
        return 0
    mode = "APPLY" if args.apply else "DRY-RUN"
    try:
        acted = terminate_running_sessions(args.session_store, args.apply, now_ms)
    except Exception as e:
        # Fail-loud boundary: a detected breach must still alert even when the session store
        # cannot be read or rewritten. No sentinel is written, so the next pass retries.
        err = f"{type(e).__name__}: {e}"
        log(f"FIRE {mode} reason={reason} TERMINATION-FAILED {err}")
        send_alert(f"[cost-ceiling] {mode} BREACH: {reason}. TERMINATION FAILED ({err}) — running "
                   f"session(s) on this host may still be draining the bucket. owner:Dorian",
                   args.apply, args.telegram_dry_run)
        print(f"cost-actuator: {mode} fired — {reason}; termination failed: {err}", file=sys.stderr)
        return 2
    keys = [k for k, _, _ in acted]
    if acted:
        body = (f"[cost-ceiling] {mode} BREACH: {reason}. Terminated {len(acted)} running session(s): "
                f"{', '.join(keys)} (keys preserved). owner:Dorian")
    else:
        body = (f"[cost-ceiling] {mode} BREACH: {reason}. No running session on this host to terminate "
                f"(shared-bucket drain may be another consumer). owner:Dorian")
    log(f"FIRE {mode} reason={reason} acted={acted}")
    send_alert(body, args.apply, args.telegram_dry_run)
    if args.apply:
        # NOTE(review): written even when nothing was terminated. A session that starts
        # running within COOLDOWN_SECS of this fire is therefore not terminated by a later
        # pass — confirm intended.
        write_sentinel(args.sentinel, {"ts": now.isoformat(), "reason": reason, "acted": keys})
    print(f"cost-actuator: {mode} fired — {reason}; acted={keys}")
    return 10 if args.apply and acted else 0

def main(argv=None):
    """Parse the CLI (defaults are the module constants) and run one actuator pass.

    argv: argument list; None means sys.argv[1:] (argparse default).
    Returns run()'s exit code. Invalid flags exit with status 2 via argparse.
    """
    p = argparse.ArgumentParser(
        description="THR-109 B2 cost-ceiling actuator: terminate running sessions on a "
                    "5h-meter breach (dry-run unless --apply).")
    p.add_argument("--usage-jsonl", default=USAGE_JSONL, help="5h-meter history JSONL to read")
    p.add_argument("--session-store", default=SESSION_STORE, help="local sessions.json to act on")
    p.add_argument("--sentinel", default=SENTINEL, help="cooldown sentinel (written on an --apply fire)")
    p.add_argument("--blindness-state", default=BLINDNESS_STATE,
                   help="B5 consecutive-blind streak state file")
    p.add_argument("--log-path", default=LOG_PATH, help="actuator log file")
    p.add_argument("--apply", action="store_true", help="actually terminate + alert (default: dry-run)")
    p.add_argument("--telegram-dry-run", action="store_true",
                   help="never invoke the Telegram send (log only), incl. the blindness ping")
    p.add_argument("--hard-ceiling-pct", type=float, default=HARD_CEILING_PCT,
                   help="hard breach when day_pct >= this (default: %(default)s)")
    p.add_argument("--rate-window", type=int, default=RATE_WINDOW,
                   help="evaluable rows in the slope window; 1 disables the rate check "
                        "(default: %(default)s)")
    p.add_argument("--rate-delta-pct", type=float, default=RATE_DELTA_PCT,
                   help="rate breach when day_pct rises >= this across the window (default: %(default)s)")
    p.add_argument("--rate-floor-pct", type=float, default=RATE_FLOOR_PCT,
                   help="...and only if the latest day_pct >= this (default: %(default)s)")
    p.add_argument("--stale-secs", type=int, default=STALE_SECS,
                   help="do not act when the latest row is older than this (default: %(default)s)")
    p.add_argument("--cooldown-secs", type=int, default=COOLDOWN_SECS,
                   help="do not re-fire within this many seconds of an --apply fire (default: %(default)s)")
    args = p.parse_args(argv)
    if args.rate_window < 1:
        # evaluable[-0:] is the WHOLE tail and a negative value slices from the front, so a
        # value < 1 would silently widen the slope window instead of disabling it.
        p.error("--rate-window must be >= 1 (use 1 to disable the rate check)")
    return run(args)

if __name__ == "__main__":
    try:
        sys.exit(main())
    except Exception as e:
        # Last-resort boundary. SystemExit (from sys.exit or argparse) is not an Exception
        # and passes straight through. Report on stderr first: the log path itself may be
        # what failed, and the log write below is only best-effort.
        try:
            import traceback
            traceback.print_exc()
        except Exception:
            pass
        try:
            log(f"FATAL {type(e).__name__}: {e}")
        except Exception:
            pass
        sys.exit(2)
