# edited by glg
import sqlite3
import tempfile
import unittest
from pathlib import Path

import pypos.modules.sinkronisasi.services.transaction_export_service as export_mod


def _seed_atomic_schema(db_path: str):
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()
    cur.execute(
        """
        CREATE TABLE export_flux (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            batch_start TEXT,
            batch_end TEXT,
            table_name TEXT,
            server_hash TEXT,
            file_seq INTEGER,
            row_count INTEGER,
            status TEXT,
            file_path TEXT,
            error_log TEXT,
            created_at TEXT,
            updated_at TEXT,
            file_hash TEXT,
            file_size INTEGER
        )
        """
    )
    cur.execute(
        """
        CREATE TABLE export_cursor_table_scoped (
            table_name TEXT NOT NULL,
            server_hash TEXT NOT NULL,
            last_id INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY(table_name, server_hash)
        )
        """
    )
    cur.execute(
        """
        CREATE TABLE export_retry_state (
            table_name TEXT NOT NULL,
            server_hash TEXT NOT NULL,
            attempt_count INTEGER NOT NULL DEFAULT 0,
            next_retry_at TEXT NULL,
            error_code TEXT NULL,
            last_error TEXT NULL,
            updated_at TEXT NOT NULL,
            PRIMARY KEY (table_name, server_hash)
        )
        """
    )
    cur.execute(
        """
        INSERT INTO export_flux (
            batch_start, batch_end, table_name, server_hash, file_seq, row_count, status, file_path, error_log, created_at, updated_at
        ) VALUES ('0-10', '2026-03-01 10:00:00', 'transaksi', 'srv', NULL, 0, 'PENDING', NULL, '', datetime('now'), datetime('now'))
        """
    )
    cur.execute(
        """
        INSERT INTO export_retry_state (table_name, server_hash, attempt_count, next_retry_at, error_code, last_error, updated_at)
        VALUES ('transaksi', 'srv', 3, datetime('now', '+5 minutes'), 'EXPORT_ERROR', 'old_error', datetime('now'))
        """
    )
    conn.commit()
    conn.close()


class TransactionExportServiceAtomicTest(unittest.TestCase):
    """Check that ``_finalize_flux_success`` commits or rolls back as one unit.

    Each test seeds a fresh SQLite file with ``_seed_atomic_schema``, redirects
    the export module's ``get_db_path`` / ``read_config`` to it, calls the
    finalize step, then inspects ``export_flux``, ``export_cursor_table_scoped``
    and ``export_retry_state`` for ``('transaksi', 'srv')``.
    """

    def _patch_runtime(self, db_path: str):
        """Point ``export_mod`` at *db_path* with an empty config.

        Returns the original ``(get_db_path, read_config)`` pair; pass it to
        ``_restore_runtime`` so later tests see the real functions again.
        """
        original_get_db_path = export_mod.get_db_path
        original_read_config = export_mod.read_config
        export_mod.get_db_path = lambda: db_path
        export_mod.read_config = lambda: {}
        return original_get_db_path, original_read_config

    def _restore_runtime(self, originals):
        """Reinstall the ``(get_db_path, read_config)`` pair saved by ``_patch_runtime``."""
        export_mod.get_db_path, export_mod.read_config = originals

    @staticmethod
    def _read_state(db_path: str):
        """Read back the rows the finalize step may touch for ``('transaksi', 'srv')``.

        Returns:
            ``(flux, cursor_last_ids, retry_count)`` where ``flux`` is the
            ``(status, row_count, file_path, file_seq)`` tuple of export_flux
            row ``id = 1`` (or ``None`` if missing), ``cursor_last_ids`` is the
            list of ``last_id`` values in ``export_cursor_table_scoped``, and
            ``retry_count`` is the number of ``export_retry_state`` rows.
        """
        conn = sqlite3.connect(db_path)
        # Close even if a query fails: an open handle would stop the enclosing
        # TemporaryDirectory from being removed on Windows.
        try:
            cur = conn.cursor()
            cur.execute("SELECT status, row_count, file_path, file_seq FROM export_flux WHERE id = 1")
            flux = cur.fetchone()
            cur.execute(
                "SELECT last_id FROM export_cursor_table_scoped WHERE table_name = 'transaksi' AND server_hash = 'srv'"
            )
            cursor_last_ids = [int(row[0]) for row in cur.fetchall()]
            cur.execute(
                "SELECT COUNT(1) FROM export_retry_state WHERE table_name = 'transaksi' AND server_hash = 'srv'"
            )
            # COUNT() always yields an integer, never NULL.
            retry_count = int(cur.fetchone()[0])
        finally:
            conn.close()
        return flux, cursor_last_ids, retry_count

    def test_finalize_success_updates_flux_cursor_and_clears_retry(self):
        """A valid finalize should mark the flux SUCCESS, set the cursor and drop retry state."""
        with tempfile.TemporaryDirectory() as td:
            db_path = str(Path(td) / "atomic_ok.db")
            _seed_atomic_schema(db_path)

            originals = self._patch_runtime(db_path)
            try:
                service = export_mod.TransactionExportService()
                service._finalize_flux_success(
                    table_name="transaksi",
                    server_hash="srv",
                    flux_id=1,
                    row_count=10,
                    file_path="C:/tmp/f_1.xz",
                    file_seq=1,
                    file_hash="a" * 64,
                    file_size=128,
                    new_last_id=10,
                )
            finally:
                self._restore_runtime(originals)

            flux, cursor_last_ids, retry_count = self._read_state(db_path)

            # Fail cleanly instead of with a TypeError when the row is gone.
            self.assertIsNotNone(flux, "export_flux row id=1 is missing")
            self.assertEqual(flux[0], "SUCCESS")
            self.assertEqual(int(flux[1]), 10)
            self.assertEqual(flux[2], "C:/tmp/f_1.xz")
            self.assertEqual(int(flux[3]), 1)
            self.assertEqual(cursor_last_ids, [10])
            self.assertEqual(retry_count, 0)

    def test_finalize_success_rolls_back_when_cursor_value_invalid(self):
        """An invalid ``new_last_id`` should raise ValueError and leave every table as seeded."""
        with tempfile.TemporaryDirectory() as td:
            db_path = str(Path(td) / "atomic_rollback.db")
            _seed_atomic_schema(db_path)

            originals = self._patch_runtime(db_path)
            try:
                service = export_mod.TransactionExportService()
                with self.assertRaises(ValueError):
                    service._finalize_flux_success(
                        table_name="transaksi",
                        server_hash="srv",
                        flux_id=1,
                        row_count=10,
                        file_path="C:/tmp/f_1.xz",
                        file_seq=1,
                        file_hash="a" * 64,
                        file_size=128,
                        new_last_id="invalid-int",
                    )
            finally:
                self._restore_runtime(originals)

            flux, cursor_last_ids, retry_count = self._read_state(db_path)

            self.assertIsNotNone(flux, "export_flux row id=1 is missing")
            self.assertEqual(flux[0], "PENDING")
            self.assertEqual(int(flux[1]), 0)
            self.assertIsNone(flux[2])
            self.assertIsNone(flux[3])
            self.assertEqual(cursor_last_ids, [])
            self.assertEqual(retry_count, 1)


if __name__ == "__main__":
    # Allow running this test module directly (python <file>.py) as well as via a test runner.
    unittest.main()
