# edited by glg
import json
import sqlite3

from pypos.core.utils.db_helper import connect_sqlite_read_fast


class SettlementMutationLockService:
    """
    Shared settlement-lock validator for mutation flows (returns + cancellations).

    A transaction id counts as "locked" once it is referenced by a settlement
    record. ``get_lock_map`` resolves ids through three tiers, cheapest first:

    1. the ``settlement_history_transaksi_map`` table, when it exists;
    2. an in-memory set of ids parsed from the most recent ``settlement_history``
       rows, rebuilt whenever ``MAX(settlement_history.id)`` changes;
    3. a full scan of the ``settlement_history.data_transaksi_id`` JSON blobs,
       only for ids that are still unresolved.

    If the lookup raises ``sqlite3.Error``, ``TypeError`` or ``ValueError``, every
    requested id is reported as locked (fail-safe) and a warning is logged.
    """

    # Upper bound on "?" placeholders per IN (...) query. Older SQLite builds cap
    # bound parameters at 999 (SQLITE_MAX_VARIABLE_NUMBER); exceeding the cap
    # raises "too many SQL variables", which would fail-safe lock a whole batch.
    _SQL_IN_CHUNK_SIZE = 500
    # How many of the newest settlement_history rows feed the in-memory cache.
    _CACHE_RECENT_ROW_LIMIT = 5000

    def __init__(self, db_path: str, log_warning=None):
        """
        :param db_path: SQLite database path handed to the read-connection helper.
        :param log_warning: optional callable that receives warning messages;
            when it is not callable, warnings are discarded.
        """
        self.db_path = db_path
        self.log_warning = log_warning if callable(log_warning) else (lambda *_args, **_kwargs: None)
        # Ids parsed from the newest settlement_history rows (tier 2).
        self._settlement_txn_cache = set()
        # MAX(settlement_history.id) observed when the cache was last rebuilt.
        self._settlement_txn_cache_last_max_id = 0
        self._settlement_txn_cache_ready = False
        # Memoized existence probe for settlement_history_transaksi_map.
        self._settlement_map_table_checked = False
        self._settlement_map_table_available = False

    def _connect_read(self):
        """Open a read connection via ``connect_sqlite_read_fast`` with ``sqlite3.Row`` rows."""
        conn = connect_sqlite_read_fast(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @staticmethod
    def _to_positive_int(value, default=0):
        """Return ``int(value)`` when it is > 0, otherwise ``int(default or 0)``.

        Values that ``int()`` rejects (TypeError/ValueError) also fall back to
        ``int(default or 0)``.
        """
        try:
            parsed = int(value)
        except (TypeError, ValueError):
            parsed = int(default or 0)
        return parsed if parsed > 0 else int(default or 0)

    @classmethod
    def _normalize_transaksi_id(cls, transaksi_id):
        """Parse a single transaction id.

        Returns 0 for falsy input, blank strings, non-positive values, and values
        whose string form is not an integer literal.
        """
        text = str(transaksi_id or "").strip()
        if not text:
            return 0
        return cls._to_positive_int(text, 0)

    # edited by glg
    @classmethod
    def _normalize_transaksi_ids(cls, transaksi_ids):
        """Return the valid, de-duplicated ids from ``transaksi_ids`` in input order.

        A bare ``str``/``int`` is treated as one id rather than iterated;
        iterating ``"123"`` would otherwise check the ids 1, 2 and 3.
        """
        if isinstance(transaksi_ids, (str, int)):
            transaksi_ids = [transaksi_ids]
        normalized = []
        seen = set()
        for raw_id in transaksi_ids or []:
            parsed = cls._normalize_transaksi_id(raw_id)
            if parsed <= 0 or parsed in seen:
                continue
            seen.add(parsed)
            normalized.append(int(parsed))
        return normalized

    def _has_settlement_history_map_table(self, cursor):
        """Return True if ``settlement_history_transaksi_map`` exists.

        A successful probe is memoized on the instance. A ``sqlite3.Error`` during
        the probe returns False for this call only and is NOT memoized, so a
        transient failure does not disable the map-table tier permanently.
        """
        if self._settlement_map_table_checked:
            return bool(self._settlement_map_table_available)
        try:
            cursor.execute(
                """
                SELECT name
                FROM sqlite_master
                WHERE type = 'table' AND name = 'settlement_history_transaksi_map'
                LIMIT 1
                """
            )
            available = cursor.fetchone() is not None
        except sqlite3.Error:
            return False
        self._settlement_map_table_available = available
        self._settlement_map_table_checked = True
        return available

    @classmethod
    def _extract_transaksi_ids_blob(cls, raw_blob):
        """Parse a ``data_transaksi_id`` JSON blob into positive, unique ids (input order).

        Returns [] for empty input, invalid JSON, or a JSON value that is not a
        list. List items that are not positive integers are skipped.
        """
        text = str(raw_blob or "").strip()
        if not text:
            return []
        try:
            payload = json.loads(text)
        except (TypeError, ValueError):
            return []
        if not isinstance(payload, list):
            return []
        out = []
        seen = set()
        for item in payload:
            parsed = cls._to_positive_int(item, 0)
            if parsed <= 0 or parsed in seen:
                continue
            seen.add(parsed)
            out.append(parsed)
        return out

    def _refresh_settlement_txn_cache(self, cursor):
        """Rebuild the tier-2 cache if ``MAX(settlement_history.id)`` has changed.

        The cache holds every id referenced by the newest
        ``_CACHE_RECENT_ROW_LIMIT`` non-empty rows.
        NOTE(review): staleness is detected only through MAX(id), so edits or
        deletions of existing rows are not picked up until a new row is inserted.
        """
        cursor.execute("SELECT COALESCE(MAX(id), 0) FROM settlement_history")
        row = cursor.fetchone()
        latest_max_id = int((row[0] if row else 0) or 0)
        if self._settlement_txn_cache_ready and latest_max_id == self._settlement_txn_cache_last_max_id:
            return
        if latest_max_id <= 0:
            self._settlement_txn_cache.clear()
            self._settlement_txn_cache_last_max_id = 0
            self._settlement_txn_cache_ready = True
            return

        cursor.execute(
            """
            SELECT id, data_transaksi_id
            FROM settlement_history
            WHERE COALESCE(data_transaksi_id, '') <> ''
            ORDER BY id DESC
            LIMIT ?
            """,
            (int(self._CACHE_RECENT_ROW_LIMIT),),
        )
        cached_ids = set()
        for row_item in cursor.fetchall() or []:
            raw_blob = row_item[1] if row_item and len(row_item) > 1 else None
            cached_ids.update(self._extract_transaksi_ids_blob(raw_blob))

        # Swap in the new state only after the rebuild succeeded.
        self._settlement_txn_cache = cached_ids
        self._settlement_txn_cache_last_max_id = latest_max_id
        self._settlement_txn_cache_ready = True

    def _collect_locked_ids_from_map_table(self, cursor, transaksi_ids):
        """Tier 1: return the ids from ``transaksi_ids`` present in the map table.

        Ids are queried in chunks of ``_SQL_IN_CHUNK_SIZE`` so the number of bound
        parameters stays under SQLite's host-parameter limit.
        """
        locked = set()
        chunk_size = max(1, int(self._SQL_IN_CHUNK_SIZE))
        for start in range(0, len(transaksi_ids), chunk_size):
            chunk = transaksi_ids[start:start + chunk_size]
            placeholders = ",".join(["?"] * len(chunk))
            cursor.execute(
                f"""
                SELECT DISTINCT transaksi_id
                FROM settlement_history_transaksi_map
                WHERE transaksi_id IN ({placeholders})
                """,
                chunk,
            )
            for row in cursor.fetchall() or []:
                try:
                    locked_id = int(row[0] if row else 0)
                except (TypeError, ValueError):
                    continue
                if locked_id > 0:
                    locked.add(locked_id)
        return locked

    def _scan_settlement_history_for_ids(self, cursor, pending_ids):
        """Tier 3: full fail-safe scan; return the ``pending_ids`` found in any settlement blob.

        Runs one query per batch (no N+1). The cursor is iterated lazily, so
        reading stops as soon as every pending id has been found.
        """
        remaining = set(pending_ids)
        found = set()
        cursor.execute(
            """
            SELECT data_transaksi_id
            FROM settlement_history
            WHERE COALESCE(data_transaksi_id, '') <> ''
            ORDER BY id DESC
            """
        )
        for row in cursor:
            raw_blob = row[0] if row and len(row) > 0 else None
            hits = remaining.intersection(self._extract_transaksi_ids_blob(raw_blob))
            if hits:
                found |= hits
                remaining -= hits
            if not remaining:
                break
        return found

    def is_transaksi_locked(self, transaksi_id) -> bool:
        """Return True if ``transaksi_id`` is settlement-locked.

        Invalid ids return False. Lookup failures return True (fail-safe).
        """
        target_id = self._normalize_transaksi_id(transaksi_id)
        if target_id <= 0:
            return False
        return bool(self.get_lock_map([target_id]).get(int(target_id), False))

    # edited by glg
    def get_lock_map(self, transaksi_ids):
        """Return ``{transaksi_id: is_locked}`` for every valid id in ``transaksi_ids``.

        Invalid or non-positive ids are dropped; returns {} when none remain.
        On ``sqlite3.Error``, ``TypeError`` or ``ValueError`` during the lookup, a
        warning is logged and every requested id is reported as locked.
        """
        normalized_ids = self._normalize_transaksi_ids(transaksi_ids)
        if not normalized_ids:
            return {}

        lock_map = dict.fromkeys(normalized_ids, False)
        unresolved_ids = set(normalized_ids)
        conn = None
        try:
            conn = self._connect_read()
            cursor = conn.cursor()

            # Tier 1: dedicated id map table (when present).
            if self._has_settlement_history_map_table(cursor):
                for locked_id in self._collect_locked_ids_from_map_table(cursor, normalized_ids):
                    if locked_id in lock_map:
                        lock_map[locked_id] = True
                        unresolved_ids.discard(locked_id)

            # Tier 2: cache built from the newest settlement rows.
            if unresolved_ids:
                self._refresh_settlement_txn_cache(cursor)
                cached_hits = unresolved_ids & self._settlement_txn_cache
                for trx_id in cached_hits:
                    lock_map[trx_id] = True
                unresolved_ids -= cached_hits

            # Tier 3: full scan, once per batch, only for ids still unresolved.
            if unresolved_ids:
                for trx_id in self._scan_settlement_history_for_ids(cursor, unresolved_ids):
                    lock_map[trx_id] = True

            return lock_map
        except sqlite3.Error as exc:
            ids_text = ",".join(str(item) for item in normalized_ids[:20])
            self.log_warning(
                f"[SETTLEMENT_LOCK_DB_ERROR] Gagal validasi settlement lock batch ids={ids_text}: {exc}"
            )
            # Fail-safe: lock everything when validation fails.
            return {int(trx_id): True for trx_id in normalized_ids}
        except (TypeError, ValueError) as exc:
            ids_text = ",".join(str(item) for item in normalized_ids[:20])
            self.log_warning(
                f"[SETTLEMENT_LOCK_DATA_ERROR] Data settlement lock batch tidak valid ids={ids_text}: {exc}"
            )
            return {int(trx_id): True for trx_id in normalized_ids}
        finally:
            if conn is not None:
                conn.close()
