import sqlite3
from pathlib import Path

import pytest

from pypos.modules.platform_ops.services.dr_backup_policy_service import DrBackupPolicyService
from pypos.modules.platform_ops.services.dr_restore_drill_service import DrRestoreDrillService

# edited by glg
# Apply the ``unit`` marker to every test in this module (select with ``-m unit``).
pytestmark = [pytest.mark.unit]


def _create_sqlite_db(path: Path) -> None:
    """Create a small SQLite database at *path* for restore-drill tests.

    Ensures table ``t`` (``id`` INTEGER PRIMARY KEY, ``nama`` TEXT) exists,
    inserts one row with ``nama = "uji"`` and commits.

    The connection is closed in a ``finally`` block, so a failing statement
    does not leave an open connection (and file handle) on the database.
    """
    conn = sqlite3.connect(str(path))
    try:
        conn.execute("CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY, nama TEXT)")
        conn.execute("INSERT INTO t (nama) VALUES (?)", ("uji",))
        conn.commit()
    finally:
        conn.close()


def test_run_restore_drill_ok_dengan_sqlite_valid(tmp_path):
    """A drill restored from an untampered SQLite backup passes every check."""
    db_file = tmp_path / "state.db"
    _create_sqlite_db(db_file)

    policy = DrBackupPolicyService(str(tmp_path / "backups"))
    snapshot = policy.run_backup(source_paths=[str(db_file)], tag="drill")

    drill = DrRestoreDrillService().run_restore_drill(
        manifest_path=str(snapshot["manifest_path"]),
        restore_root=str(tmp_path / "restore"),
    )

    # Overall verdict, then each individual integrity check.
    assert drill["ok"] is True
    assert int(drill["checked_files"]) == 1
    assert drill["missing_files"] == []
    assert drill["checksum_mismatches"] == []

    # Exactly one SQLite file was backed up, so exactly one smoke result.
    smoke = drill["sqlite_smoke"]
    assert len(smoke) == 1
    assert bool(smoke[0]["ok"]) is True

    # The drill must leave a report behind on disk.
    assert Path(drill["report_path"]).exists()


def test_run_restore_drill_deteksi_checksum_mismatch(tmp_path):
    """Corrupting a backed-up file after the backup makes the drill fail."""
    db_file = tmp_path / "state2.db"
    _create_sqlite_db(db_file)

    policy = DrBackupPolicyService(str(tmp_path / "backups"))
    snapshot = policy.run_backup(source_paths=[str(db_file)], tag="tamper")

    # Append junk bytes to the stored copy so its contents no longer match
    # the checksum captured when the backup was taken.
    stored = Path(snapshot["copied_files"][0]["target"])
    original_bytes = stored.read_bytes()
    stored.write_bytes(original_bytes + b"tampered")

    drill = DrRestoreDrillService().run_restore_drill(
        manifest_path=str(snapshot["manifest_path"]),
        restore_root=str(tmp_path / "restore"),
    )

    assert drill["ok"] is False
    assert len(drill["checksum_mismatches"]) >= 1

