# Tambahan untuk cek update ke server dengan API yang tersedia di server
import sqlite3, os
from datetime import datetime, timedelta
from decimal import Decimal
from pypos.core.base_model import BaseModel
from pypos.core.utils.db_helper import connect_sqlite
from pypos.core.utils.path_utils import get_app_data_dir, get_db_path
from pypos.core.utils.config_utils import read_config, get_current_server_hash
from pypos.core.utils.sql_identifier_utils import (
    is_sql_identifier_valid,
    quote_sql_identifier,
)
from pypos.core.utils.sql_query_builder import (
    build_sql_with_identifier_in_clause,
    render_sql_template,
)
from pypos.core.database.schema_migrator import run_schema_migrations_once

# edited by glg
# Sentinel timestamp retained so that other services importing this name
# from this module keep working.
SENTINEL_TS: str = "1990-01-01 23:59:59"

class SinkronModel(BaseModel):
    """SQLite-side model for synchronising local tables with the server.

    Holds its own SQLite connection, maintains the ``sync_tracking`` table,
    appends to ``logs/sync.log`` under the app data dir, and applies row
    payloads to local tables. The update-check / sync calls themselves are
    delegated to ``api_service``.
    """

    def __init__(self, db_path=None, api_service=None):
        """Open and configure the SQLite connection.

        Args:
            db_path: SQLite database path. ``None`` (the default) resolves
                ``get_db_path()`` when the instance is created.
            api_service: Object exposing ``check_update_server``,
                ``sync_data_server`` and ``sync_data_server_table``. When omitted
                (or falsy) a ``SyncApiService`` is built.
        """
        super().__init__()
        # Resolved per instance: a ``get_db_path()`` default argument would be
        # evaluated only once, when the module is imported.
        self.db_path = get_db_path() if db_path is None else db_path
        self.api_service = api_service or self._build_default_api_service()

        self.sqlite_conn = connect_sqlite(self.db_path, timeout=30)
        self.sqlite_conn.row_factory = sqlite3.Row
        # Best effort: PRAGMA failures are logged, not raised.
        try:
            cur = self.sqlite_conn.cursor()
            try:
                for pragma in (
                    "PRAGMA journal_mode=WAL",
                    "PRAGMA synchronous=NORMAL",
                    "PRAGMA busy_timeout=5000",
                    "PRAGMA foreign_keys=ON",
                ):
                    cur.execute(pragma)
            finally:
                cur.close()
        except (sqlite3.Error, RuntimeError, OSError, ValueError, TypeError, AttributeError) as exc:
            self.log_warning(f"Gagal set PRAGMA koneksi SinkronModel: {exc}")

    def _build_default_api_service(self):
        """Import and instantiate ``SyncApiService`` on demand.

        The import is deferred to call time (presumably to avoid an import
        cycle with the services package — confirm before moving it to the top).
        """
        module = __import__(
            "pypos.modules.sinkronisasi.services.sync_api_service",
            fromlist=["SyncApiService"],
        )
        return module.SyncApiService()

    def _quote_identifier(self, name: str) -> str:
        """Quote ``name`` as an SQL identifier via ``quote_sql_identifier``."""
        return quote_sql_identifier(name)

    def _sanitize_server_columns(self, cols):
        """Return the usable column names from ``cols`` in first-seen order.

        Names are stringified and stripped; empty names and duplicates are
        dropped, and names failing the strict identifier check are logged and
        skipped.
        """
        valid_cols = []
        seen = set()
        for col in cols or []:
            ident = str(col or "").strip()
            if not ident or ident in seen:
                continue
            if not is_sql_identifier_valid(ident, strict=True):
                self.log_warning(f"Kolom server di-skip karena tidak valid: {ident}")
                continue
            seen.add(ident)
            valid_cols.append(ident)
        return valid_cols

    def _sync_log_path(self):
        """Return ``<app data dir>/logs/sync.log``, creating ``logs`` if needed."""
        base_dir = get_app_data_dir()
        log_dir = os.path.join(base_dir, "logs")
        os.makedirs(log_dir, exist_ok=True)
        return os.path.join(log_dir, "sync.log")

    def _log_sync(self, message, level="INFO", context=None):
        """Append ``[timestamp] LEVEL message | k=v ...`` to the sync log.

        Write failures (including creating the log directory) are reported via
        ``log_warning`` and never raised.
        """
        ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        ctx = ""
        if context:
            pairs = [f"{k}={v}" for k, v in context.items()]
            ctx = " | " + " ".join(pairs)
        line = f"[{ts}] {level} {message}{ctx}\n"
        try:
            with open(self._sync_log_path(), "a", encoding="utf-8") as f:
                f.write(line)
        except (OSError, ValueError, TypeError) as e:
            self.log_warning(f"Gagal tulis sync log: {e}")

    def _normalize_cabang_id(self, value):
        """Convert a branch id to ``int``; ``None`` for None/blank/unparseable values."""
        try:
            if value is None:
                return None
            if isinstance(value, str) and value.strip() == "":
                return None
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: e.g. int(float("inf")) from a lenient JSON payload.
            return None

    def _get_sync_cabang_settings(self):
        """Read the branch-filter settings from the app config.

        Returns:
            ``(global_ids, allow_null)``: the set of branch ids treated as global
            (config ``sync_cabang_global_ids``, default ``[0, -1]``; non-list
            values are ignored) and whether rows with a NULL/unparseable
            ``cabang_id`` are kept (config ``sync_cabang_allow_null_global == 1``,
            default 1).
        """
        cfg = read_config()
        raw_ids = cfg.get("sync_cabang_global_ids", [0, -1])
        try:
            allow_null = int(cfg.get("sync_cabang_allow_null_global", 1)) == 1
        except (TypeError, ValueError) as exc:
            # A malformed value must not abort the whole sync; use the same
            # default as when the key is missing.
            self.log_warning(f"Config sync_cabang_allow_null_global tidak valid, pakai default 1: {exc}")
            allow_null = True
        global_ids = set()
        if isinstance(raw_ids, (list, tuple, set)):
            for v in raw_ids:
                n = self._normalize_cabang_id(v)
                if n is not None:
                    global_ids.add(n)
        return global_ids, allow_null

    def _filter_rows_by_cabang(self, rows, cabang_id):
        """Keep rows belonging to ``cabang_id`` or to a configured global branch.

        Returns ``(rows, skipped_count)``. When ``cabang_id`` does not normalise
        to an int, ``rows`` is returned untouched with 0 skipped. Non-dict rows
        and dicts without a ``cabang_id`` key are always kept; rows whose
        ``cabang_id`` is NULL/unparseable are kept only if ``allow_null``.
        """
        target_id = self._normalize_cabang_id(cabang_id)
        if target_id is None:
            return rows, 0
        global_ids, allow_null = self._get_sync_cabang_settings()
        allowed_ids = set(global_ids)
        allowed_ids.add(target_id)
        filtered = []
        skipped = 0
        for row in rows or []:
            if not isinstance(row, dict) or "cabang_id" not in row:
                filtered.append(row)
                continue
            row_id = self._normalize_cabang_id(row.get("cabang_id"))
            keep = allow_null if row_id is None else row_id in allowed_ids
            if keep:
                filtered.append(row)
            else:
                skipped += 1
        return filtered, skipped

    def close_connections(self):
        """Close the SQLite connection if one is set."""
        if self.sqlite_conn:
            self.sqlite_conn.close()

    def get_last_sync_info(self):
        """Return all ``sync_tracking`` rows as dicts keyed by their ``tabel`` value.

        Runs pending schema migrations (``strict=False``) before querying.
        """
        run_schema_migrations_once(self.db_path, strict=False)
        cur = self.sqlite_conn.cursor()
        cur.execute("SELECT * FROM sync_tracking")
        data = {row["tabel"]: dict(row) for row in cur.fetchall()}
        return data

    def _get_local_max_datetime(self, table_name, column_name):
        """Return ``MAX(column_name)`` of ``table_name``.

        Returns ``None`` when the result is empty/falsy or the query fails.
        """
        try:
            cur = self.sqlite_conn.cursor()
            table_sql = self._quote_identifier(table_name)
            column_sql = self._quote_identifier(column_name)
            query = render_sql_template(
                "SELECT MAX({column_sql}) FROM {table_sql}",
                column_sql=column_sql,
                table_sql=table_sql,
            )
            cur.execute(query)
            row = cur.fetchone()
            return row[0] if row and row[0] else None
        except (sqlite3.Error, TypeError, ValueError):
            return None

    def save_last_sync_info(self, table_name, last_update, last_id, commit=True):
        """Insert or update the ``sync_tracking`` row for ``table_name``.

        Args:
            table_name: Value stored in ``sync_tracking.tabel``.
            last_update: Value stored in ``last_update``.
            last_id: Value stored in ``last_id``.
            commit: Commit immediately (default, previous behaviour). Pass
                ``False`` to keep the write inside the caller's transaction.
        """
        cur = self.sqlite_conn.cursor()
        cur.execute("""
            INSERT INTO sync_tracking (tabel, last_update, last_id)
            VALUES (?, ?, ?)
            ON CONFLICT(tabel) DO UPDATE SET
                last_update=excluded.last_update,
                last_id=excluded.last_id
        """, (table_name, last_update, last_id))
        if commit:
            self.sqlite_conn.commit()

    def _resolve_cabang_id(self, machine_id, cabang_id):
        """Resolve a missing branch id from the local device registration.

        When ``cabang_id`` is a "no branch" marker (None, "", 0, -1 or their
        string forms, "none"/"None"), return the ``cabang_id`` of the most
        recently updated active, non-trashed ``per_cabang_device`` row for
        ``machine_id`` and the current server hash. Otherwise, or when no such
        row exists or the lookup fails, return ``cabang_id`` unchanged.
        """
        # int -1 added for consistency with the "-1" string marker.
        if cabang_id in (None, "", 0, "0", "none", "None", "-1", -1):
            try:
                cur = self.sqlite_conn.cursor()
                self._ensure_device_server_hash_column(cur)
                server_hash = get_current_server_hash()
                cur.execute(
                    """
                    SELECT cabang_id FROM per_cabang_device
                    WHERE machine_id = ? AND status = 1 AND trash = 0 AND server_hash = ?
                    ORDER BY last_update DESC LIMIT 1
                    """,
                    (machine_id, server_hash)
                )
                row = cur.fetchone()
                if row and row[0]:
                    return row[0]
            except (sqlite3.Error, TypeError, ValueError, RuntimeError, OSError):
                return cabang_id
        return cabang_id

    def _ensure_device_server_hash_column(self, cursor):
        """Add ``per_cabang_device.server_hash TEXT`` if it is missing (best effort)."""
        try:
            cursor.execute("PRAGMA table_info(per_cabang_device)")
            cols = [row[1] for row in cursor.fetchall()]
            if "server_hash" not in cols:
                cursor.execute("ALTER TABLE per_cabang_device ADD COLUMN server_hash TEXT")
                self.sqlite_conn.commit()
        except (sqlite3.Error, RuntimeError, OSError, ValueError, TypeError) as exc:
            self.log_debug(f"Skip ensure kolom server_hash per_cabang_device: {exc}")

    def check_update_server(
        self,
        machine_id,
        cabang_id,
        tables=None,
        force_full_tables=None,
    ):
        """Delegate the server update check to ``api_service``.

        ``tables`` defaults to the core master tables when empty/None.
        """
        if not tables:
            tables = [
                "produk", "diskon", "per_cabang", "per_cabang_device",
                "per_customers", "per_employee", "company_profile"
            ]
        return self.api_service.check_update_server(
            model=self,
            machine_id=machine_id,
            cabang_id=cabang_id,
            tables=tables,
            force_full_tables=force_full_tables,
        )

    def sync_data_server(self, machine_id, cabang_id, tables, force_full_tables=None):
        """Delegate syncing of ``tables`` to ``api_service.sync_data_server``."""
        return self.api_service.sync_data_server(
            model=self,
            machine_id=machine_id,
            cabang_id=cabang_id,
            tables=tables,
            force_full_tables=force_full_tables,
        )

    def sync_data_server_table(self, machine_id, cabang_id, table_name, dtime, last_id=0, partial=0, page_size=500):
        """Delegate syncing of a single table to ``api_service.sync_data_server_table``."""
        return self.api_service.sync_data_server_table(
            model=self,
            machine_id=machine_id,
            cabang_id=cabang_id,
            table_name=table_name,
            dtime=dtime,
            last_id=last_id,
            partial=partial,
            page_size=page_size,
        )

    def _collect_server_columns(self, rows):
        """Return the sanitized column names used by the dict rows in ``rows``.

        Keys are gathered in first-seen order (not via a ``set``) so that the
        CREATE/ALTER column order is stable between runs. Non-dict rows
        contribute nothing, so a payload without dict rows yields ``[]``.
        """
        ordered = {}
        for row in rows or []:
            if isinstance(row, dict):
                for key in row:
                    ordered.setdefault(key, None)
        server_cols = self._sanitize_server_columns(list(ordered))
        if "id" in ordered and "id" not in server_cols:
            server_cols.insert(0, "id")
        return server_cols

    def _ensure_local_table(self, cur, table_name, table_sql, server_cols):
        """Create ``table_name`` or add missing columns so it can hold ``server_cols``.

        New tables get every column as ``TEXT`` with ``id`` as ``PRIMARY KEY``;
        columns missing from an existing table are added as ``TEXT``.

        Returns:
            ``(local_cols, id_is_pk)``: the table's column names as they were
            before any ``ALTER TABLE`` in this call (or the freshly created
            columns), and whether the ``id`` column belongs to the primary key.
        """
        cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            (table_name,),
        )
        if cur.fetchone() is not None:
            cur.execute(f"PRAGMA table_info({table_sql})")
            info_rows = cur.fetchall()
            local_cols = [row[1] for row in info_rows]
            # PRAGMA table_info column 5 is the 1-based position in the PK (0 = not in PK).
            id_is_pk = any(str(row[1]) == "id" and int(row[5] or 0) > 0 for row in info_rows)
        else:
            cols_defs = []
            for col in server_cols:
                col_sql = self._quote_identifier(col)
                if col == "id":
                    cols_defs.append(f"{col_sql} TEXT PRIMARY KEY")
                else:
                    cols_defs.append(f"{col_sql} TEXT")
            cols_sql = ", ".join(cols_defs) if cols_defs else '"id" TEXT PRIMARY KEY'
            cur.execute(f"CREATE TABLE IF NOT EXISTS {table_sql} ({cols_sql})")
            local_cols = list(server_cols)
            id_is_pk = "id" in server_cols

        for col in server_cols:
            if col not in local_cols:
                self.log_info(f"Menambahkan kolom baru '{col}' ke tabel '{table_name}'")
                cur.execute(f"ALTER TABLE {table_sql} ADD COLUMN {self._quote_identifier(col)} TEXT")
        return local_cols, id_is_pk

    def _upsert_rows_by_price_key(self, cur, table_sql, server_cols, rows):
        """Upsert ``price`` rows matched on ``produk_id`` + ``cabang_id`` + ``jenis_value``.

        Rows missing a part of that key (falsy ``produk_id``/``jenis_value`` or
        ``cabang_id is None``) are logged and skipped.

        Returns:
            ``(updated, inserted)`` row counts.
        """
        set_clause = ", ".join(f"{self._quote_identifier(col)} = ?" for col in server_cols)
        col_names = ",".join(self._quote_identifier(c) for c in server_cols)
        placeholders = ",".join("?" for _ in server_cols)
        price_exists_sql = render_sql_template(
            """
            SELECT id FROM {table_sql}
            WHERE produk_id = ? AND cabang_id = ? AND jenis_value = ?
            """,
            table_sql=table_sql,
        )
        price_update_sql = render_sql_template(
            """
            UPDATE {table_sql}
            SET {set_clause}
            WHERE produk_id = ? AND cabang_id = ? AND jenis_value = ?
            """,
            table_sql=table_sql,
            set_clause=set_clause,
        )
        price_insert_sql = render_sql_template(
            """
            INSERT INTO {table_sql} ({col_names})
            VALUES ({placeholders})
            """,
            table_sql=table_sql,
            col_names=col_names,
            placeholders=placeholders,
        )

        updated = inserted = 0
        for row in rows:
            produk_id = row.get("produk_id")
            row_cabang_id = row.get("cabang_id")
            jenis_value = row.get("jenis_value")
            if not all([produk_id, row_cabang_id is not None, jenis_value]):
                self.log_warning(f"Price row tanpa composite key lengkap, skip: {row}")
                continue
            key = [produk_id, row_cabang_id, jenis_value]
            cur.execute(price_exists_sql, key)
            existing = cur.fetchone()
            values = [row.get(c) for c in server_cols]
            if existing:
                cur.execute(price_update_sql, values + key)
                updated += 1
            else:
                cur.execute(price_insert_sql, values)
                inserted += 1
        return updated, inserted

    def _upsert_rows_by_id_legacy(self, cur, table_sql, server_cols, rows):
        """Upsert rows one at a time: ``SELECT`` by ``id``, then ``UPDATE`` or ``INSERT``.

        Needs no unique index on ``id``.

        Returns:
            ``(updated, inserted)`` row counts.
        """
        set_clause = ", ".join(f"{self._quote_identifier(col)} = ?" for col in server_cols)
        col_names = ",".join(self._quote_identifier(c) for c in server_cols)
        placeholders = ",".join("?" for _ in server_cols)
        query_existing = render_sql_template(
            "SELECT id FROM {table_sql} WHERE id = ?",
            table_sql=table_sql,
        )
        update_sql = render_sql_template(
            """
            UPDATE {table_sql}
            SET {set_clause}
            WHERE id = ?
            """,
            table_sql=table_sql,
            set_clause=set_clause,
        )
        insert_sql = render_sql_template(
            """
            INSERT INTO {table_sql} ({col_names})
            VALUES ({placeholders})
            """,
            table_sql=table_sql,
            col_names=col_names,
            placeholders=placeholders,
        )

        updated = inserted = 0
        for row in rows:
            row_id = row.get("id")
            cur.execute(query_existing, (row_id,))
            existing = cur.fetchone()
            values = [row.get(c) for c in server_cols]
            if existing:
                cur.execute(update_sql, values + [row_id])
                updated += 1
            else:
                cur.execute(insert_sql, values)
                inserted += 1
        return updated, inserted

    def _upsert_rows_by_id_fast(self, cur, table_name, table_sql, server_cols, rows, chunk_size=300):
        """Bulk upsert rows on the ``id`` primary key with ``INSERT ... ON CONFLICT``.

        Rows sharing an id collapse to the last occurrence. Existing ids are
        looked up in chunks of ``chunk_size`` (presumably to stay under SQLite's
        bound-parameter limit) only to report updated vs inserted. If SQLite
        rejects the upsert with ``OperationalError`` the per-row strategy is
        applied to the de-duplicated rows instead.

        Returns:
            ``(updated, inserted)`` row counts.
        """
        id_rows = list({str(row.get("id")): row for row in rows}.values())

        existing_ids = set()
        row_id_keys = [str(row.get("id")) for row in id_rows if row.get("id")]
        select_prefix = render_sql_template("SELECT id FROM {table_sql} WHERE", table_sql=table_sql)
        for offset in range(0, len(row_id_keys), chunk_size):
            chunk = row_id_keys[offset: offset + chunk_size]
            query_existing_ids, query_existing_params = build_sql_with_identifier_in_clause(
                select_prefix,
                "id",
                chunk,
                unique=True,
            )
            cur.execute(query_existing_ids, query_existing_params)
            existing_ids.update(str(found[0]) for found in cur.fetchall() or [])

        col_names = ",".join(self._quote_identifier(c) for c in server_cols)
        value_marks = ",".join("?" for _ in server_cols)
        id_col = self._quote_identifier("id")
        update_cols = [c for c in server_cols if c != "id"]
        if update_cols:
            update_set_clause = ", ".join(
                f"{self._quote_identifier(col)} = excluded.{self._quote_identifier(col)}"
                for col in update_cols
            )
            upsert_sql = render_sql_template(
                "INSERT INTO {table_sql} ({col_names}) VALUES ({value_marks}) "
                "ON CONFLICT({id_col}) DO UPDATE SET {update_set_clause}",
                table_sql=table_sql,
                col_names=col_names,
                value_marks=value_marks,
                id_col=id_col,
                update_set_clause=update_set_clause,
            )
        else:
            upsert_sql = render_sql_template(
                "INSERT INTO {table_sql} ({col_names}) VALUES ({value_marks}) "
                "ON CONFLICT({id_col}) DO NOTHING",
                table_sql=table_sql,
                col_names=col_names,
                value_marks=value_marks,
                id_col=id_col,
            )

        updated = inserted = 0
        value_batch = []
        for row in id_rows:
            if str(row.get("id")) in existing_ids:
                updated += 1
            else:
                inserted += 1
            value_batch.append([row.get(c) for c in server_cols])

        try:
            cur.executemany(upsert_sql, value_batch)
        except sqlite3.OperationalError as exc:
            self.log_warning(
                f"Fast upsert {table_name} gagal, fallback legacy strategy: {exc}"
            )
            return self._upsert_rows_by_id_legacy(cur, table_sql, server_cols, id_rows)
        return updated, inserted

    def apply_sync_result(self, table_name, data_new, full_refresh=False, cabang_id=None):
        """Write rows received from the sync API into the local SQLite table.

        Steps: filter rows by branch, optionally clear the table
        (``full_refresh``), create the table / add missing ``TEXT`` columns,
        upsert the rows (``price`` on ``produk_id``+``cabang_id``+``jenis_value``,
        every other table on ``id``), update ``sync_tracking`` and commit.

        ``sync_tracking`` receives the ``last_update`` (or ``last_dtime``) and
        integer ``id`` of the last valid row carrying them; this assumes the
        server sends rows in ascending order (TODO confirm with the API).

        A full refresh of ``per_employee`` runs inside a SAVEPOINT. Any early
        exit or exception rolls it back, so the local account table is never
        left empty.

        Args:
            table_name: Local table to write.
            data_new: Row dicts from the server (rows without ``id`` are skipped).
                ``per_employee`` rows may get ``last_dtime`` filled in place.
            full_refresh: Delete all local rows before applying ``data_new``.
            cabang_id: Active branch used to filter rows; ``None`` disables it.

        Returns:
            Number of rows counted as updated plus inserted (0 when nothing applied).

        Raises:
            RuntimeError: The per_employee savepoint could not be started.
            sqlite3.Error and others from the write phase propagate after the
            per_employee savepoint (if any) has been rolled back.
        """
        table_sql = self._quote_identifier(table_name)
        # edited by glg
        # Atomic protection for the per_employee full refresh so the table is never
        # left empty if the sync is cancelled/interrupted midway.
        per_employee_atomic_refresh = bool(full_refresh and table_name == "per_employee")
        savepoint_active = False

        def _rollback_employee_savepoint():
            # Undo everything since the savepoint and drop it; no-op when inactive.
            nonlocal savepoint_active
            if not savepoint_active:
                return
            try:
                self.sqlite_conn.execute("ROLLBACK TO SAVEPOINT sp_per_employee_full_refresh")
            except (sqlite3.Error, RuntimeError, OSError, ValueError, TypeError) as exc:
                self.log_debug(f"Rollback savepoint per_employee diabaikan: {exc}")
            try:
                self.sqlite_conn.execute("RELEASE SAVEPOINT sp_per_employee_full_refresh")
            except (sqlite3.Error, RuntimeError, OSError, ValueError, TypeError) as exc:
                self.log_debug(f"Release savepoint rollback per_employee diabaikan: {exc}")
            savepoint_active = False

        def _release_employee_savepoint():
            # Keep the work done since the savepoint; no-op when inactive.
            nonlocal savepoint_active
            if not savepoint_active:
                return
            try:
                self.sqlite_conn.execute("RELEASE SAVEPOINT sp_per_employee_full_refresh")
            except (sqlite3.Error, RuntimeError, OSError, ValueError, TypeError) as exc:
                self.log_debug(f"Release savepoint per_employee diabaikan: {exc}")
            savepoint_active = False

        if per_employee_atomic_refresh:
            try:
                self.sqlite_conn.execute("SAVEPOINT sp_per_employee_full_refresh")
                savepoint_active = True
            except (sqlite3.Error, RuntimeError, OSError, ValueError, TypeError) as exc:
                raise RuntimeError(f"Gagal memulai transaksi atomic per_employee: {exc}") from exc

        try:
            self._log_sync(
                "apply_sync_result.start",
                context={
                    "table": table_name,
                    "rows": len(data_new or []),
                    "full_refresh": int(bool(full_refresh)),
                    "cabang_id": cabang_id if cabang_id is not None else "",
                },
            )
            data_new, skipped = self._filter_rows_by_cabang(data_new, cabang_id)
            if skipped:
                self._log_sync(
                    "apply_sync_result.filtered",
                    context={
                        "table": table_name,
                        "skipped": skipped,
                        "cabang_id": cabang_id if cabang_id is not None else "",
                    },
                )
            if full_refresh:
                try:
                    cur_clear = self.sqlite_conn.cursor()
                    clear_query = render_sql_template("DELETE FROM {table_sql}", table_sql=table_sql)
                    cur_clear.execute(clear_query)
                    # edited by glg
                    # For per_employee the DELETE stays uncommitted until the end
                    # (atomic swap behaviour).
                    if not per_employee_atomic_refresh:
                        self.sqlite_conn.commit()
                except (sqlite3.Error, RuntimeError, OSError, ValueError, TypeError) as e:
                    # The clear failed, so continue as a plain upsert.
                    _rollback_employee_savepoint()
                    self.log_warning(f"Gagal clear tabel {table_name}: {e}")

            if not data_new:
                self.log_debug(f"Table {table_name}: No new data")
                self._log_sync(
                    "apply_sync_result.empty",
                    context={
                        "table": table_name,
                        "full_refresh": int(bool(full_refresh)),
                    },
                )
                # edited by glg
                # Empty per_employee payload during a full refresh: roll back so the
                # previous local account data stays intact.
                _rollback_employee_savepoint()
                return 0
            self.log_debug(f"apply_sync_result untuk {table_name}: {len(data_new)} rows")

            server_cols = self._collect_server_columns(data_new)
            if not server_cols:
                self.log_warning(f"Tidak ada kolom server valid untuk tabel {table_name}.")
                # Without this the pending per_employee DELETE would be committed by
                # whichever commit() runs next.
                _rollback_employee_savepoint()
                return 0

            cur = self.sqlite_conn.cursor()
            local_cols, id_is_pk = self._ensure_local_table(cur, table_name, table_sql, server_cols)

            fill_last_dtime = table_name == "per_employee" and "last_dtime" in local_cols
            if fill_last_dtime and "last_dtime" not in server_cols:
                server_cols.append("last_dtime")

            last_update_val = None
            last_id_val = None
            valid_rows = []
            for row in data_new:
                if not isinstance(row, dict):
                    continue
                if fill_last_dtime and not row.get("last_dtime"):
                    row["last_dtime"] = row.get("last_update") or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                if not row.get("id"):
                    self.log_warning(f"Row tanpa ID, skip: {row}")
                    continue
                valid_rows.append(row)
                if "last_update" in row:
                    last_update_val = row["last_update"]
                elif "last_dtime" in row:
                    last_update_val = row["last_dtime"]
                try:
                    parsed_last_id = int(str(row["id"]).strip())
                except (TypeError, ValueError, AttributeError):
                    parsed_last_id = None
                if parsed_last_id is not None:
                    last_id_val = parsed_last_id

            if not valid_rows:
                _rollback_employee_savepoint()
                return 0

            # price: composite key; other tables: id (bulk upsert when id is the PK).
            if table_name == "price":
                updated_count, inserted_count = self._upsert_rows_by_price_key(
                    cur, table_sql, server_cols, valid_rows
                )
            elif id_is_pk and "id" in server_cols:
                updated_count, inserted_count = self._upsert_rows_by_id_fast(
                    cur, table_name, table_sql, server_cols, valid_rows
                )
            else:
                updated_count, inserted_count = self._upsert_rows_by_id_legacy(
                    cur, table_sql, server_cols, valid_rows
                )

            # Write tracking in the same transaction as the data. Committing here
            # would end the per_employee savepoint early and make RELEASE fail.
            if last_update_val and last_id_val:
                self.save_last_sync_info(table_name, last_update_val, last_id_val, commit=False)

            _release_employee_savepoint()
            self.sqlite_conn.commit()
        finally:
            # Safety net: any exception or early exit that left the savepoint open
            # undoes the per_employee refresh instead of leaking it into a later
            # commit(). No-op after a successful release.
            _rollback_employee_savepoint()

        self.log_debug(f"Commit selesai untuk {table_name}: {updated_count} updated, {inserted_count} inserted")
        self._log_sync(
            "apply_sync_result.done",
            context={
                "table": table_name,
                "updated": updated_count,
                "inserted": inserted_count,
                "total": updated_count + inserted_count,
                "last_update": last_update_val or "",
                "last_id": last_id_val or "",
            },
        )
        return updated_count + inserted_count
