# edited by glg
import pytest

from pypos.modules.platform_ops.services.fleet_rollout_orchestrator_service import (
    FleetRolloutOrchestratorError,
    FleetRolloutOrchestratorService,
)


# Module-level pytest marks: pytest applies every mark in this list to all
# tests collected from this file, so each test here carries the `unit` marker.
pytestmark = [pytest.mark.unit]


def test_rollout_orchestrator_continue_for_healthy_canary():
    """Canary metrics under both thresholds yield a "continue" decision.

    The payload reports a 1.1% error rate against a 2.0% limit and a
    1200 ms p95 latency against a 2000 ms limit, so the evaluation is
    expected to be healthy and point at the next wave.
    """
    orchestrator = FleetRolloutOrchestratorService()
    rollout_plan = orchestrator.build_plan(
        branch_ids=[1, 2, 3, 4, 5],
        config_version="cfg-2026.04.15",
        canary_percent=20,
        wave_percents=[50, 100],
    )
    canary_metrics = {"error_rate_pct": 1.1, "p95_latency_ms": 1200}
    thresholds = {"fail_threshold_pct": 2.0, "latency_threshold_ms": 2000.0}

    outcome = orchestrator.evaluate_wave(
        plan=rollout_plan,
        wave_name="canary",
        metrics_payload=canary_metrics,
        **thresholds,
    )

    assert outcome["decision"] == "continue"
    assert outcome["healthy"] is True
    assert outcome["next_action"] == "continue_to_next_wave"
    # The result is expected to echo back the wave that was evaluated.
    evaluated_wave = outcome["wave"]
    assert evaluated_wave["name"] == "canary"


def test_rollout_orchestrator_halt_when_wave_unhealthy():
    """Per-branch metrics that breach the thresholds yield a halt-and-rollback.

    The payload is a list of per-branch samples: branch 1 reports a 3.5%
    error rate (limit 2.0%) and branch 2 reports a 2600 ms p95 latency
    (limit 2000 ms). How the service combines the samples is not visible
    from this test; only the resulting decision is pinned.
    """
    orchestrator = FleetRolloutOrchestratorService()
    rollout_plan = orchestrator.build_plan(
        branch_ids=[1, 2, 3, 4, 5, 6],
        config_version="cfg-2026.04.15",
        canary_percent=10,
        wave_percents=[50, 100],
    )
    high_error_branch = {"branch_id": 1, "error_rate_pct": 3.5, "p95_latency_ms": 1900}
    high_latency_branch = {"branch_id": 2, "error_rate_pct": 1.5, "p95_latency_ms": 2600}
    thresholds = {"fail_threshold_pct": 2.0, "latency_threshold_ms": 2000.0}

    outcome = orchestrator.evaluate_wave(
        plan=rollout_plan,
        wave_name="wave_50",
        metrics_payload=[high_error_branch, high_latency_branch],
        **thresholds,
    )

    assert outcome["decision"] == "halt_and_rollback"
    assert outcome["healthy"] is False
    assert outcome["next_action"] == "halt_and_rollback"
    # The result is expected to echo back the wave that was evaluated.
    evaluated_wave = outcome["wave"]
    assert evaluated_wave["name"] == "wave_50"


def test_rollout_orchestrator_raise_when_wave_not_found():
    """Evaluating the wave name "wave_999" raises FleetRolloutOrchestratorError.

    The plan is built with a 10% canary and a single 100% wave, and the
    call passes no explicit thresholds. The test name implies "wave_999"
    is absent from that plan; the service's wave naming is not visible here.
    """
    orchestrator = FleetRolloutOrchestratorService()
    rollout_plan = orchestrator.build_plan(
        branch_ids=[1, 2, 3],
        config_version="cfg-2026.04.15",
        canary_percent=10,
        wave_percents=[100],
    )
    # Arguments are prepared up front so only the call under test sits
    # inside the raises-context.
    healthy_metrics = {"error_rate_pct": 1.0, "p95_latency_ms": 1000}
    unknown_wave = "wave_999"

    with pytest.raises(FleetRolloutOrchestratorError):
        orchestrator.evaluate_wave(
            plan=rollout_plan,
            wave_name=unknown_wave,
            metrics_payload=healthy_metrics,
        )
