from typing import Dict, List, Optional

# edited by glg


class FleetRolloutService:
    """Plan staged config rollouts (canary + cumulative waves) and gate waves on health metrics."""

    @staticmethod
    def _normalize_ids(branch_ids: List[int]) -> List[int]:
        """Return the unique, positive integer IDs from ``branch_ids`` in ascending order.

        Each entry is coerced with ``int()`` (so floats are truncated toward
        zero and numeric strings are parsed). Entries ``int()`` cannot convert
        (``None``, non-numeric strings, NaN or infinite floats) are skipped, as
        are non-positive values and duplicates. ``None`` for the whole argument
        is treated as an empty list.
        """
        unique = set()
        for raw in branch_ids or []:
            try:
                val = int(raw)
            # Only the errors int() raises for an unconvertible value are
            # swallowed; anything else (e.g. a RuntimeError from a broken
            # __int__) is a genuine bug and must propagate.
            except (TypeError, ValueError, OverflowError):
                continue
            if val > 0:
                unique.add(val)
        return sorted(unique)

    @staticmethod
    def _normalize_stages(wave_percents: Optional[List[int]]) -> List[int]:
        """Return sorted, de-duplicated wave percents clamped to 1..100, always ending with 100."""
        raw_stages = list(wave_percents or [5, 20, 50, 100])
        stages = sorted({max(1, min(int(x or 0), 100)) for x in raw_stages})
        # Guarantee a final 100% stage so every branch ends up in some wave.
        if stages[-1] != 100:
            stages.append(100)
        return stages

    def build_rollout_plan(
        self,
        *,
        branch_ids: List[int],
        config_version: str,
        canary_percent: int = 1,
        wave_percents: Optional[List[int]] = None,
    ) -> Dict:
        """Split ``branch_ids`` into a canary group followed by cumulative rollout waves.

        Args:
            branch_ids: Branch IDs to roll out to. Invalid, non-positive and
                duplicate entries are dropped and the rest sorted ascending.
            config_version: Version label copied into the plan, stripped of
                surrounding whitespace (``None`` becomes ``""``).
            canary_percent: Share of branches in the canary, clamped to 1..20.
                Falsy values (0/None) mean 1. When there is at least one branch,
                the canary always contains at least one.
            wave_percents: Cumulative percent targets (share of *all* branches,
                canary included) for the waves after the canary. Each is clamped
                to 1..100, duplicates collapse, and 100 is appended if missing.
                ``None`` or empty means ``[5, 20, 50, 100]``.

        Returns:
            Dict with ``config_version``, ``total_branches``, ``canary_percent``,
            ``canary`` (``branch_ids``, ``size``), ``waves`` (each with ``name``,
            ``target_percent``, ``branch_ids``, ``size``; stages that would add no
            branches are omitted) and ``unassigned`` (branches left after the
            last wave; empty because the final stage is always 100%).

        Raises:
            TypeError, ValueError: if ``canary_percent`` or a wave percent cannot
                be converted with ``int()``.
        """
        ids = self._normalize_ids(branch_ids)
        total = len(ids)
        canary_pct = max(1, min(int(canary_percent or 1), 20))
        stages = self._normalize_stages(wave_percents)

        # Note: round() rounds halves to even (2.5 -> 2); kept for plan stability.
        canary_target = max(1, int(round(total * (canary_pct / 100.0)))) if total > 0 else 0
        canary_ids = ids[:canary_target]
        remaining = ids[canary_target:]

        waves = []
        covered = canary_target
        for pct in stages:
            # The 100% stage is pinned to ``total`` so rounding never leaves stragglers.
            target = total if pct == 100 else int(round(total * (pct / 100.0)))
            # covered >= canary_target, so stages at or below what is already
            # rolled out (including the canary) yield a non-positive size and are skipped.
            wave_size = target - covered
            if wave_size <= 0:
                continue
            wave_ids = remaining[:wave_size]
            remaining = remaining[wave_size:]
            covered += len(wave_ids)
            waves.append(
                {
                    "name": f"wave_{pct}",
                    "target_percent": int(pct),
                    "branch_ids": wave_ids,
                    "size": len(wave_ids),
                }
            )

        return {
            "config_version": str(config_version or "").strip(),
            "total_branches": total,
            "canary_percent": canary_pct,
            "canary": {
                "branch_ids": canary_ids,
                "size": len(canary_ids),
            },
            "waves": waves,
            "unassigned": remaining,
        }

    @staticmethod
    def evaluate_wave_health(
        metrics: Optional[Dict],
        fail_threshold_pct: float = 2.0,
        *,
        max_p95_latency_ms: float = 2000.0,
    ) -> Dict:
        """Decide whether a rollout wave may continue, based on its error rate and p95 latency.

        Args:
            metrics: Mapping read for ``"error_rate_pct"`` and ``"p95_latency_ms"``.
                Missing or falsy values count as 0.0. ``None`` is treated as ``{}``.
            fail_threshold_pct: Maximum acceptable error rate in percent. Falsy
                values fall back to 2.0, and the result is floored at 0.1.
            max_p95_latency_ms: Maximum acceptable p95 latency. ``None`` falls back
                to 2000.0, the value previously hard-coded.

        Returns:
            Dict with ``healthy`` (bool), the parsed ``error_rate_pct`` and
            ``p95_latency_ms`` floats, and ``action`` (``"continue"`` or
            ``"halt_and_rollback"``).

        Raises:
            TypeError, ValueError: if a metric value cannot be converted with ``float()``.
        """
        data = dict(metrics or {})
        # NOTE(review): absent metrics default to 0.0 and therefore read as healthy,
        # so a wave with no telemetry continues. Confirm this is the intended policy.
        error_pct = float(data.get("error_rate_pct", 0.0) or 0.0)
        p95_ms = float(data.get("p95_latency_ms", 0.0) or 0.0)
        fail_threshold = max(0.1, float(fail_threshold_pct or 2.0))
        latency_limit = float(2000.0 if max_p95_latency_ms is None else max_p95_latency_ms)
        # Comparisons against NaN are False, so a NaN metric fails the gate (fail-safe).
        healthy = error_pct <= fail_threshold and p95_ms <= latency_limit
        return {
            "healthy": healthy,
            "error_rate_pct": error_pct,
            "p95_latency_ms": p95_ms,
            "action": "continue" if healthy else "halt_and_rollback",
        }
