#!/usr/bin/env python3
"""RFC-001 pilot drift checks (advisory only)."""

from __future__ import annotations

import argparse
import hashlib
import json
import re
import subprocess
import time
from pathlib import Path
from typing import Dict, List

from lifecycle_audit import append_entry, build_entry

# Root directory under which the registry and all advisory artifacts live.
WORKSPACE = Path("/Users/openclaw/.openclaw/workspace")
# Lane registry JSON; main() loads it and passes it to every check_* function.
REGISTRY_PATH = WORKSPACE / "docs/state/rfc001-canonical-lanes.json"
# Append-only advisory records, one JSON object per line (written by emit_advisory).
ADVISORY_JSONL = WORKSPACE / "tmp/ops-maintainer-advisories.jsonl"
# Append-only one-line-per-advisory text log (written by emit_advisory).
ADVISORY_LOG = WORKSPACE / "tmp/ops-maintainer-advisories.log"
# JSON object mapping advisory fingerprint -> last emit time in epoch seconds
# (read and rewritten by should_emit to enforce the cooldown).
DEDUPE_STATE = WORKSPACE / "tmp/ops-maintainer-advisory-state.json"


def load_json(path: Path):
    """Read *path* as UTF-8 text and return the decoded JSON value.

    Args:
        path: Location of the JSON document.

    Returns:
        Whatever ``json`` decodes the document to (dict, list, scalar, ...).

    Raises:
        OSError: If the file cannot be opened or read.
        ValueError: If the content is not valid UTF-8 or not valid JSON.
    """
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)


def load_session_keys() -> List[str]:
    """Collect the session keys found in the known agent session stores.

    Each store is a ``sessions.json`` file; its top-level object keys are
    taken as session keys. Stores that do not exist are skipped, and so are
    stores whose top-level JSON value is not an object (they have no keys to
    contribute, and calling ``.keys()`` on them would abort the whole run).

    NOTE(review): the result is de-duplicated, so a key present in several
    stores appears only once. ``check_multiple_canonical`` in this file looks
    for keys occurring more than once, which this output can never contain --
    confirm which of the two behaviors is intended.

    Returns:
        Sorted list of unique session keys across all readable stores.

    Raises:
        OSError: If an existing store cannot be read.
        ValueError: If an existing store is not valid JSON.
    """
    stores = [
        Path("/Users/openclaw/.openclaw/agents/main/sessions/sessions.json"),
        Path("/Users/openclaw/.openclaw/agents/worker/sessions/sessions.json"),
        Path("/Users/openclaw/.openclaw/agents/ops-maintainer/sessions/sessions.json"),
    ]
    unique_keys: set = set()
    for store in stores:
        if not store.exists():
            continue
        data = load_json(store)
        # Only a JSON object has keys; ignore lists/scalars instead of crashing.
        if isinstance(data, dict):
            unique_keys.update(data)
    return sorted(unique_keys)


def run_cron_list_json() -> tuple[List[dict], str | None]:
    """Run ``openclaw cron list --json`` and return the parsed job list.

    Only stdout is parsed. stderr is captured separately so that warnings or
    other diagnostics printed by the CLI cannot corrupt the JSON document.

    Returns:
        ``(jobs, None)`` on success, where ``jobs`` is the decoded JSON list.
        ``([], error)`` when the CLI cannot be run, exits non-zero, prints
        invalid JSON, or prints valid JSON that is not a list; ``error`` is a
        short description of the problem.

    This function never raises: every failure is reported through the error
    string so the caller can surface it as a finding.
    """
    try:
        proc = subprocess.run(
            ["openclaw", "cron", "list", "--json"],
            check=True,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        data = json.loads(proc.stdout)
    except Exception as exc:
        # Deliberate catch-all boundary: the caller relies on getting an
        # error string back rather than an exception.
        return ([], str(exc))

    if not isinstance(data, list):
        # Report the unexpected shape instead of pretending there are no jobs.
        return ([], f"unexpected cron list payload type: {type(data).__name__}")
    return (data, None)


def advisory_fingerprint(code: str, details: dict) -> str:
    """Derive a short, stable identifier for an advisory.

    The code and details are serialized as JSON with sorted keys, so the
    result does not depend on dict insertion order.

    Args:
        code: Advisory code.
        details: JSON-serializable advisory details.

    Returns:
        The first 24 hex characters of the SHA-256 digest of the serialized
        payload.
    """
    payload = {"code": code, "details": details}
    canonical = json.dumps(payload, sort_keys=True).encode("utf-8")
    digest = hashlib.sha256(canonical).hexdigest()
    return digest[:24]


def should_emit(fp: str, cooldown_seconds: int) -> bool:
    """Decide whether an advisory may be emitted now, and record it if so.

    The dedupe state file maps fingerprints to the epoch second of their last
    emission. An advisory is suppressed while fewer than ``cooldown_seconds``
    have elapsed since that time.

    An unreadable, corrupt, or wrongly-shaped state file is treated as empty
    rather than crashing every later run; the price is that cooldowns reset
    once, after which the file is rewritten in a valid form.

    Args:
        fp: Advisory fingerprint.
        cooldown_seconds: Minimum number of seconds between two emissions of
            the same fingerprint.

    Returns:
        ``True`` if the advisory should be emitted (the state file has then
        been updated with the current time), ``False`` if it is still cooling
        down (the state file is left untouched).

    Raises:
        OSError: If the state file cannot be written.
    """
    now = int(time.time())

    state: dict = {}
    if DEDUPE_STATE.exists():
        try:
            loaded = load_json(DEDUPE_STATE)
        except (OSError, ValueError):
            # Unreadable or truncated state must not take the watchdog down.
            loaded = None
        if isinstance(loaded, dict):
            state = loaded

    try:
        last = int(state.get(fp, 0))
    except (TypeError, ValueError):
        # A non-numeric timestamp is as good as no timestamp.
        last = 0

    if now - last < cooldown_seconds:
        return False

    state[fp] = now
    DEDUPE_STATE.parent.mkdir(parents=True, exist_ok=True)
    # Write to a sibling temp file and rename it over the target so that an
    # interrupted write cannot leave a half-written state file behind.
    tmp_path = DEDUPE_STATE.with_name(DEDUPE_STATE.name + ".tmp")
    tmp_path.write_text(json.dumps(state, indent=2), encoding="utf-8")
    tmp_path.replace(DEDUPE_STATE)
    return True


def emit_advisory(code: str, message: str, details: dict, cooldown_seconds: int, lane: str = "agent:ops-maintainer:main"):
    """Record an advisory unless the same one was emitted within the cooldown.

    On emission the advisory is appended as a JSON line to ADVISORY_JSONL, as
    a one-line summary to ADVISORY_LOG, and handed to the lifecycle audit
    helpers (``build_entry`` / ``append_entry``).

    Args:
        code: Advisory code; also passed to the audit entry as its trigger
            reason.
        message: Human-readable summary.
        details: JSON-serializable payload; together with ``code`` it
            determines the dedupe fingerprint.
        cooldown_seconds: Minimum spacing between identical advisories.
        lane: Route recorded in the advisory; also passed to the audit entry
            as both lane and session key.

    Returns:
        ``True`` if the advisory was written, ``False`` if it was suppressed
        by the cooldown.
    """
    fingerprint = advisory_fingerprint(code, details)
    if not should_emit(fingerprint, cooldown_seconds):
        return False

    stamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    record = {
        "timestamp": stamp,
        "mode": "advisory",
        "route": lane,
        "code": code,
        "message": message,
        "fingerprint": fingerprint,
        "details": details,
    }

    ADVISORY_JSONL.parent.mkdir(parents=True, exist_ok=True)
    with ADVISORY_JSONL.open("a", encoding="utf-8") as jsonl_sink:
        jsonl_sink.write(json.dumps(record) + "\n")
    with ADVISORY_LOG.open("a", encoding="utf-8") as log_sink:
        log_sink.write(f"[{record['timestamp']}] {code} {message} fp={fingerprint}\n")

    audit_entry = build_entry(
        lane=lane,
        session_key=lane,
        action_type="advisory",
        trigger_reason=code,
        actor="watchdog",
        details=details,
    )
    append_entry(audit_entry)
    return True


def check_multiple_canonical(registry: dict, keys: List[str]) -> List[dict]:
    """Flag roles whose canonical session key occurs more than once in *keys*.

    A registry without a ``roles`` section yields no findings, matching how
    ``check_noncanonical_anomalies`` treats the same situation.

    NOTE(review): a finding requires duplicates in ``keys``. The list built by
    ``load_session_keys`` in this file is de-duplicated, so with that input
    this check cannot fire -- confirm the intended source of ``keys``.

    Args:
        registry: Lane registry; ``registry["roles"]`` maps a role name to a
            config dict that may carry ``canonicalSessionKey``.
        keys: Observed session keys, duplicates included.

    Returns:
        One ``MULTIPLE_CANONICAL_CANDIDATES`` finding per affected role, with
        the role, its canonical key, and the occurrence count.
    """
    findings: List[dict] = []
    for role, cfg in registry.get("roles", {}).items():
        canonical = cfg.get("canonicalSessionKey")
        if not canonical:
            continue
        count = keys.count(canonical)
        if count > 1:
            findings.append({"code": "MULTIPLE_CANONICAL_CANDIDATES", "role": role, "canonical": canonical, "count": count})
    return findings


def check_noncanonical_anomalies(registry: dict, keys: List[str]) -> List[dict]:
    """Flag main-lane session keys that no registry pattern allows.

    Only string patterns starting with ``^agent:main:`` are considered,
    gathered from the ``allowedPatterns`` of every role. Keys under
    ``agent:main:cron:`` are exempt from this check.

    Args:
        registry: Lane registry with an optional ``roles`` mapping.
        keys: Observed session keys.

    Returns:
        One ``NON_CANONICAL_LANE_ANOMALY`` finding per offending key, in the
        order the keys were given.
    """
    main_patterns: List[str] = [
        pattern
        for role_cfg in registry.get("roles", {}).values()
        for pattern in role_cfg.get("allowedPatterns", [])
        if isinstance(pattern, str) and pattern.startswith("^agent:main:")
    ]

    anomalies = []
    for session_key in keys:
        # Only non-cron main-lane keys are subject to the allowlist.
        if not session_key.startswith("agent:main:"):
            continue
        if session_key.startswith("agent:main:cron:"):
            continue
        if any(re.match(pattern, session_key) for pattern in main_patterns):
            continue
        anomalies.append({"code": "NON_CANONICAL_LANE_ANOMALY", "sessionKey": session_key, "reason": "main lane key outside canonical allowlist"})
    return anomalies


def check_cron_targeting(registry: dict, cron_jobs: List[dict]) -> List[dict]:
    """Flag cron jobs whose session key targets a lane that is not approved.

    A job is flagged when its ``sessionKey`` matches any disallowed pattern,
    or when approved patterns are configured and none of them matches. Jobs
    with a missing or empty ``sessionKey`` are ignored.

    Args:
        registry: Lane registry; ``registry["cron"]`` may define
            ``approvedSessionPatterns`` and ``disallowedSessionPatterns``
            (regular expressions applied with ``search``).
        cron_jobs: Cron job dicts carrying ``id``, ``name`` and ``sessionKey``.

    Returns:
        One ``CRON_TARGETS_NON_APPROVED_LANE`` finding per flagged job, in
        input order.
    """
    cron_cfg = registry.get("cron", {})
    allow = [re.compile(src) for src in cron_cfg.get("approvedSessionPatterns", [])]
    deny = [re.compile(src) for src in cron_cfg.get("disallowedSessionPatterns", [])]

    def targets_bad_lane(session_key: str) -> bool:
        # An explicit deny always wins; otherwise a configured allowlist must match.
        if any(rx.search(session_key) for rx in deny):
            return True
        return bool(allow) and not any(rx.search(session_key) for rx in allow)

    flagged = []
    for job in cron_jobs:
        session_key = job.get("sessionKey") or ""
        if session_key and targets_bad_lane(session_key):
            flagged.append({"code": "CRON_TARGETS_NON_APPROVED_LANE", "jobId": job.get("id"), "name": job.get("name"), "sessionKey": session_key})
    return flagged


def main() -> int:
    """Run every drift check, print a JSON report, and emit advisories.

    Command-line options:
        --cooldown-minutes: Advisory cooldown in minutes (default 45); the
            value is clamped to the range 30..60 before use.
        --no-emit: Print the report only; do not record advisories.

    Returns:
        Always 0; findings do not change the exit status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cooldown-minutes", type=int, default=45)
    parser.add_argument("--no-emit", action="store_true")
    opts = parser.parse_args()

    registry = load_json(REGISTRY_PATH)
    session_keys = load_session_keys()
    cron_jobs, cron_error = run_cron_list_json()

    findings: List[dict] = [
        *check_multiple_canonical(registry, session_keys),
        *check_noncanonical_anomalies(registry, session_keys),
        *check_cron_targeting(registry, cron_jobs),
    ]
    if cron_error:
        # The cron check ran against an empty list; make that visible.
        findings.append({"code": "CRON_VISIBILITY_GAP", "reason": "unable to fetch cron list via openclaw CLI", "error": cron_error})

    report = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "mode": "advisory",
        "registry": str(REGISTRY_PATH),
        "sessionKeyCount": len(session_keys),
        "cronJobCount": len(cron_jobs),
        "findingCount": len(findings),
        "findings": findings,
    }
    print(json.dumps(report, indent=2))

    if opts.no_emit:
        return 0

    # Clamp the requested cooldown to 30..60 minutes, then convert to seconds.
    clamped_minutes = max(30, min(opts.cooldown_minutes, 60))
    cooldown_seconds = clamped_minutes * 60
    for finding in findings:
        emit_advisory(
            code=finding["code"],
            message=f"RFC-001 drift check advisory: {finding['code']}",
            details=finding,
            cooldown_seconds=cooldown_seconds,
        )

    return 0


# Script entry point: run the checks and exit with main()'s return code.
if __name__ == "__main__":
    raise SystemExit(main())
