# edited by glg
from typing import Any, Dict, List

from pypos.core.utils.config_utils import read_config
from pypos.modules.platform_ops.services.fleet_rollout_orchestrator_service import (
    FleetRolloutOrchestratorService,
)


class FleetRolloutRuntimeGuardService:
    """Gate branch exports on the health of the configured fleet-rollout wave.

    On every call the guard re-reads its policy through ``config_reader``,
    asks the orchestrator service to build a rollout plan and evaluate the
    configured wave against the supplied metrics, and folds the outcome into
    an allow / halt decision. Non-fatal errors raised while parsing metrics,
    building the plan or evaluating the wave are converted into an allow or a
    halt according to the ``fail_open`` policy flag instead of propagating.
    """

    # Exceptions that are turned into a fail-open / fail-closed result rather
    # than escaping ``evaluate_export_guard`` (see the ``try`` block there).
    _NON_FATAL_EXCEPTIONS = (
        TypeError,
        ValueError,
        KeyError,
        AttributeError,
        RuntimeError,
        OSError,
        LookupError,
        ArithmeticError,
        ImportError,
    )

    def __init__(self, config_reader=read_config, orchestrator_service=None):
        """Create the guard.

        Args:
            config_reader: Zero-argument callable returning the config mapping.
                If it is not callable, or returns a non-dict, every policy key
                falls back to its default.
            orchestrator_service: Object providing ``build_plan`` and
                ``evaluate_wave``. A new ``FleetRolloutOrchestratorService`` is
                created when this is omitted or falsy.
        """
        self.config_reader = config_reader
        self.orchestrator_service = orchestrator_service or FleetRolloutOrchestratorService()

    @staticmethod
    def _to_bool(value: Any, default: bool = False) -> bool:
        """Interpret ``value`` as a boolean flag.

        Bools pass through, ``None`` yields ``default``, numbers are true when
        strictly positive, and the strings 1/true/yes/on and 0/false/no/off
        (case-insensitive, whitespace-trimmed) are recognised. Anything else
        yields ``default``.
        """
        if isinstance(value, bool):
            return bool(value)
        if value is None:
            return bool(default)
        if isinstance(value, (int, float)):
            return float(value) > 0.0
        text = str(value).strip().lower()
        if text in {"1", "true", "yes", "on"}:
            return True
        if text in {"0", "false", "no", "off"}:
            return False
        return bool(default)

    @staticmethod
    def _to_int(value: Any, default: int = 0, minimum: int = 0) -> int:
        """Coerce ``value`` to an int that is at least ``minimum``.

        Input ``int()`` cannot convert (``None``, non-numeric strings, NaN,
        infinities) falls back to ``default``. Floats are truncated by ``int()``.
        """
        try:
            parsed = int(value)
        except (TypeError, ValueError, OverflowError):
            # int(float("inf")) raises OverflowError, not ValueError.
            parsed = int(default)
        return max(int(minimum), parsed)

    @staticmethod
    def _to_float(value: Any, default: float = 0.0, minimum: float = 0.0) -> float:
        """Coerce ``value`` to a float that is at least ``minimum``.

        Input ``float()`` cannot convert (``None``, non-numeric strings, ints
        too large for a float) falls back to ``default``.
        """
        try:
            parsed = float(value)
        except (TypeError, ValueError, OverflowError):
            # float(10 ** 400) raises OverflowError, not ValueError.
            parsed = float(default)
        return max(float(minimum), parsed)

    @staticmethod
    def _to_text(value: Any, default: str) -> str:
        """Return ``value`` as a stripped string, or ``default`` if it is falsy or blank."""
        return str(value or "").strip() or default

    @staticmethod
    def _split_candidates(raw: Any) -> List[Any]:
        """Split a comma-separated string, or copy a list/tuple/set, into raw items.

        Any other input type yields an empty list.
        """
        if isinstance(raw, str):
            return [part.strip() for part in raw.split(",")]
        if isinstance(raw, (list, tuple, set)):
            return list(raw)
        return []

    @classmethod
    def _normalize_wave_percents(cls, raw: Any) -> List[int]:
        """Parse wave percentages into a sorted, de-duplicated list clamped to [1, 100].

        Items that cannot be converted with ``int()`` are skipped. If nothing
        usable remains, the fallback ``[5, 20, 50, 100]`` is returned.
        """
        percents = set()
        for item in cls._split_candidates(raw):
            try:
                value = int(item)
            except (TypeError, ValueError, OverflowError):
                continue
            percents.add(max(1, min(100, value)))
        return sorted(percents) if percents else [5, 20, 50, 100]

    @classmethod
    def _normalize_branch_ids(cls, raw: Any) -> List[int]:
        """Parse branch ids into a sorted, de-duplicated list of positive ints.

        Unparseable and non-positive items are dropped.
        """
        branch_ids = set()
        for item in cls._split_candidates(raw):
            try:
                value = int(item)
            except (TypeError, ValueError, OverflowError):
                continue
            if value > 0:
                branch_ids.add(value)
        return sorted(branch_ids)

    def _read_policy(self) -> Dict[str, Any]:
        """Read and normalise the runtime-guard policy from ``config_reader``.

        Every key has a default, so a missing or malformed config still yields
        a complete policy dict. Exceptions raised by ``config_reader`` itself
        are not caught here.
        """
        cfg = self.config_reader() if callable(self.config_reader) else {}
        source = cfg if isinstance(cfg, dict) else {}
        return {
            "enabled": self._to_bool(source.get("fleet_rollout_runtime_guard_enabled"), default=True),
            "fail_open": self._to_bool(source.get("fleet_rollout_runtime_fail_open"), default=True),
            "halt_on_unhealthy": self._to_bool(source.get("fleet_rollout_runtime_halt_on_unhealthy"), default=True),
            "config_version": self._to_text(
                source.get("fleet_rollout_runtime_config_version"), "runtime-local"
            ),
            "wave_name": self._to_text(source.get("fleet_rollout_runtime_wave_name"), "canary"),
            # Clamped to 100 to match the [1, 100] range enforced on wave_percents.
            "canary_percent": min(
                100,
                self._to_int(source.get("fleet_rollout_runtime_canary_percent"), default=1, minimum=1),
            ),
            "wave_percents": self._normalize_wave_percents(source.get("fleet_rollout_runtime_wave_percents")),
            "branch_ids": self._normalize_branch_ids(source.get("fleet_rollout_runtime_branch_ids")),
            "fail_threshold_pct": self._to_float(
                source.get("fleet_rollout_runtime_fail_threshold_pct"), default=2.0, minimum=0.1
            ),
            "latency_threshold_ms": self._to_float(
                source.get("fleet_rollout_runtime_latency_threshold_ms"),
                default=2000.0,
                minimum=100.0,
            ),
            "min_sample_count": self._to_int(
                source.get("fleet_rollout_runtime_min_sample_count"),
                default=0,
                minimum=0,
            ),
            "fail_when_sample_insufficient": self._to_bool(
                source.get("fleet_rollout_runtime_fail_when_sample_insufficient"),
                default=False,
            ),
        }

    @staticmethod
    def _normalize_metrics_payload(payload: Any) -> Dict[str, Any]:
        """Extract the metric fields used by wave evaluation from ``payload``.

        A non-dict payload is treated as empty. Missing or falsy numeric fields
        become 0. ``sample_count`` falls back to ``attempt_sample_count`` only
        when the ``sample_count`` key is absent.

        Raises:
            ValueError / TypeError / OverflowError: if a present, truthy field
                cannot be converted. This is intentionally strict, so unusable
                metrics are never read as "0% errors". The caller converts the
                error into a fail-open / fail-closed result.
        """
        data = payload if isinstance(payload, dict) else {}
        return {
            "error_rate_pct": float(data.get("error_rate_pct", 0.0) or 0.0),
            "p95_latency_ms": float(data.get("p95_latency_ms", 0.0) or 0.0),
            "sample_count": int(data.get("sample_count", data.get("attempt_sample_count", 0)) or 0),
            "queue_pending": int(data.get("queue_pending", 0) or 0),
            "queue_inflight": int(data.get("queue_inflight", 0) or 0),
            "source": str(data.get("source") or "").strip(),
        }

    @staticmethod
    def _policy_fallback_result(
        policy: Dict[str, Any],
        branch_id: int,
        *,
        reason: str,
        in_scope: bool,
        error: Any = None,
    ) -> Dict[str, Any]:
        """Build a result whose verdict is decided purely by ``policy['fail_open']``.

        This is used when the wave could not be evaluated. An ``error`` key is
        included only when ``error`` is not ``None``.
        """
        fail_open = bool(policy.get("fail_open"))
        result = {
            "allow": fail_open,
            "reason": reason,
            "decision": "continue" if fail_open else "halt_and_rollback",
            "healthy": fail_open,
            "in_scope": bool(in_scope),
            "branch_id": branch_id,
            "policy": policy,
            "wave": {},
            "gate": {},
        }
        if error is not None:
            result["error"] = error
        return result

    def evaluate_export_guard(self, *, branch_id: Any, metrics_payload: Any) -> Dict[str, Any]:
        """Decide whether an export for ``branch_id`` may proceed.

        Args:
            branch_id: Branch identifier. Values that cannot be parsed become 0,
                which is treated as "no specific branch".
            metrics_payload: Dict of wave metrics (see ``_normalize_metrics_payload``).

        Returns:
            A dict that always contains ``allow``, ``reason``, ``decision``,
            ``healthy``, ``in_scope``, ``branch_id``, ``policy``, ``wave`` and
            ``gate``. After a completed evaluation it also contains
            ``sample_count`` and ``insufficient_sample``. After a non-fatal
            error it also contains ``error``.
        """
        policy = self._read_policy()
        branch_id_int = self._to_int(branch_id, default=0, minimum=0)
        if not bool(policy.get("enabled")):
            return {
                "allow": True,
                "reason": "runtime_guard_disabled",
                "decision": "continue",
                "healthy": True,
                "in_scope": False,
                "branch_id": branch_id_int,
                "policy": policy,
                "wave": {},
                "gate": {},
            }

        branch_ids = list(policy.get("branch_ids") or [])
        if branch_id_int > 0 and branch_id_int not in branch_ids:
            # Always include the calling branch so the plan covers it.
            branch_ids = sorted(set(branch_ids) | {branch_id_int})
        if not branch_ids:
            return self._policy_fallback_result(
                policy, branch_id_int, reason="missing_branch_scope", in_scope=False
            )

        try:
            # Parsed inside the try so a malformed payload honours fail_open
            # instead of raising out of the guard.
            metrics = self._normalize_metrics_payload(metrics_payload)
            plan = self.orchestrator_service.build_plan(
                branch_ids=branch_ids,
                config_version=policy.get("config_version"),
                canary_percent=policy.get("canary_percent"),
                wave_percents=policy.get("wave_percents"),
            )
            evaluation = self.orchestrator_service.evaluate_wave(
                plan=plan,
                wave_name=policy.get("wave_name"),
                metrics_payload=metrics,
                fail_threshold_pct=policy.get("fail_threshold_pct"),
                latency_threshold_ms=policy.get("latency_threshold_ms"),
            )
            wave = evaluation.get("wave") if isinstance(evaluation.get("wave"), dict) else {}
            gate = evaluation.get("gate") if isinstance(evaluation.get("gate"), dict) else {}

            wave_branch_ids = {
                self._to_int(raw_id, default=0, minimum=0) for raw_id in (wave.get("branch_ids") or [])
            }
            wave_branch_ids.discard(0)
            # An unknown branch (<= 0) or a wave with no explicit branch list
            # is treated as in scope, so the health verdict still applies.
            in_scope = bool(branch_id_int <= 0 or not wave_branch_ids or branch_id_int in wave_branch_ids)
            healthy = bool(evaluation.get("healthy"))
            decision = str(evaluation.get("decision") or "halt_and_rollback").strip()

            sample_count = self._to_int(
                gate.get("sample_count", metrics.get("sample_count", 0)),
                default=0,
                minimum=0,
            )
            min_samples = self._to_int(policy.get("min_sample_count"), default=0, minimum=0)
            insufficient_sample = sample_count < min_samples
            if insufficient_sample and bool(policy.get("fail_when_sample_insufficient")):
                healthy = False
                decision = "halt_and_rollback"

            allow = True
            # NOTE: when halt_on_unhealthy is disabled, an unhealthy in-scope
            # wave still reports reason "healthy"; callers should read ``healthy``.
            reason = "healthy"
            if in_scope and bool(policy.get("halt_on_unhealthy")) and not healthy:
                allow = False
                reason = "unhealthy_wave_metrics"
            elif in_scope and insufficient_sample:
                reason = "insufficient_sample"
            elif not in_scope:
                reason = "branch_not_in_wave_scope"

            return {
                "allow": bool(allow),
                "reason": reason,
                "decision": decision if decision else ("continue" if allow else "halt_and_rollback"),
                "healthy": bool(healthy),
                "in_scope": bool(in_scope),
                "branch_id": branch_id_int,
                "policy": policy,
                "wave": wave,
                "gate": gate,
                "sample_count": int(sample_count),
                "insufficient_sample": bool(insufficient_sample),
            }
        except self._NON_FATAL_EXCEPTIONS as exc:
            return self._policy_fallback_result(
                policy,
                branch_id_int,
                reason="runtime_guard_error",
                in_scope=True,
                error=str(exc),
            )
