# edited by glg
import logging
import sqlite3
import time
import uuid
from typing import Callable, Dict, List, Optional

from pypos.core.utils.config_utils import get_current_server_hash, read_config
from pypos.modules.penjualan.services.history_cleanup_service import HistoryCleanupService
from pypos.modules.sinkronisasi.config import get_export_tables
from pypos.modules.sinkronisasi.services.transaction_export_errors import (
    TransactionExportRetryableError,
)
from pypos.modules.sinkronisasi.services.transaction_export_table_iteration_use_case_service import (
    TransactionExportTableIterationUseCaseService,
)


class TransactionExportBatchUseCaseService:
    """
    Use-case service that splits the ``export_batch`` hotspot out of
    TransactionExportService.

    Focus:
    - orchestrating a batch export across all export tables
    - consistent observability (trace_id + per-table p95 timing metrics)

    The heavy lifting is delegated to the wrapped export service (through its
    ``_``-prefixed helpers) and to TransactionExportTableIterationUseCaseService
    for the per-table work.
    """

    def __init__(self, export_service, logger=None):
        """
        Args:
            export_service: service whose private helpers drive the batch
                (presumably a TransactionExportService instance).
            logger: optional logger; defaults to this module's logger.
        """
        self._svc = export_service
        self._logger = logger or logging.getLogger(__name__)
        self._table_iteration_use_case = TransactionExportTableIterationUseCaseService(
            export_service=export_service,
            logger=self._logger,
        )

    @staticmethod
    def _new_trace_id() -> str:
        """Return a short correlation id (``exp-`` + 12 hex chars) for one batch run."""
        return f"exp-{uuid.uuid4().hex[:12]}"

    @staticmethod
    def _elapsed_ms(started_at: float) -> float:
        """Return the milliseconds elapsed since a ``time.perf_counter()`` reading."""
        return (time.perf_counter() - started_at) * 1000.0

    @staticmethod
    def _safe_p95(values: List[float]) -> float:
        """
        Return the 95th percentile of ``values``, or 0.0 for an empty list.

        Picks the sorted sample at index ``round(0.95 * (n - 1))`` without
        interpolation. Python's ``round`` rounds exact halves to even.
        """
        if not values:
            return 0.0
        ordered = sorted(float(item) for item in values)
        index = int(round(0.95 * (len(ordered) - 1)))
        # Defensive clamp; the computed index already lies in [0, n - 1].
        index = max(0, min(index, len(ordered) - 1))
        return float(ordered[index])

    def _log_perf_summary(
        self,
        *,
        trace_id: str,
        total_rows: int,
        total_tables: int,
        total_elapsed_ms: float,
        table_elapsed_ms: List[float],
    ) -> None:
        """Emit one INFO line with the batch totals and the p95 per-table duration."""
        p95_table_ms = self._safe_p95(table_elapsed_ms)
        self._logger.info(
            "[EXPORT_BATCH_METRICS] trace_id=%s total_rows=%s total_tables=%s total_ms=%.1f p95_table_ms=%.1f",
            str(trace_id or "-"),
            int(total_rows or 0),
            int(total_tables or 0),
            float(total_elapsed_ms or 0.0),
            float(p95_table_ms or 0.0),
        )

    def _handle_retryable_error(
        self,
        *,
        trace_id: str,
        table_name: str,
        server_hash,
        export_exc,
    ) -> None:
        """
        Persist retry bookkeeping for a table whose export raised a retryable error.

        The error code is the first non-empty value among: the exception's own
        ``error_code`` (stripped), a code built from ``raw_error`` via
        ``svc._build_error_code``, and the fallback ``"EXPORT_TABLE_RETRY"``.
        The retry state is recorded for the table. If the exception carries a
        positive ``flux_id``, that flux row is also marked for retry. Finally a
        warning is logged.
        """
        svc = self._svc
        raw_error = export_exc.raw_error
        built_error_code = svc._build_error_code(raw_error) if raw_error is not None else ""
        error_code = str(export_exc.error_code or "").strip() or built_error_code or "EXPORT_TABLE_RETRY"
        # Guard against a None result: an AttributeError raised here would mask
        # the original export error and abort the remaining tables in the batch.
        retry_meta = svc._record_retry_state(
            table_name,
            server_hash,
            export_exc.error_message,
            error_code,
        ) or {}
        if int(export_exc.flux_id or 0) > 0:
            svc._mark_flux_retry(
                int(export_exc.flux_id),
                export_exc.error_message,
                attempt_count=retry_meta.get("attempt_count"),
                next_retry_at=retry_meta.get("next_retry_at"),
                error_code=retry_meta.get("error_code"),
            )
        self._logger.warning(
            "[DEBUG EXPORT][%s] table=%s error=%s",
            trace_id,
            table_name,
            export_exc.error_message,
        )

    def _run_history_cleanup(self, trace_id: str) -> None:
        """Run HistoryCleanupService.cleanup_if_due as best effort; log failures and continue."""
        try:
            deleted = HistoryCleanupService().cleanup_if_due()
        except (RuntimeError, sqlite3.Error) as e:
            self._logger.warning("[DEBUG EXPORT][%s] cleanup error=%s", trace_id, e)
            return
        if deleted:
            self._logger.debug("[DEBUG EXPORT][%s] cleanup deleted=%s", trace_id, deleted)

    # edited by glg
    def run_export_batch(self, progress_callback: Optional[Callable[[int], None]] = None) -> int:
        """
        Export every configured, exportable table once and return the total row count.

        Flow:
        1. Ensure tables exist, snapshot ``batch_end`` from the DB clock, and
           read ``export_batch_limit`` from config.
        2. Resolve the exportable tables. If there are none, report 100%
           progress, log metrics and return 0.
        3. Run each table through the table-iteration use case, threading the
           accumulated transaksi ids, registry rows and master index between
           tables. A ``TransactionExportRetryableError`` is recorded as retry
           state and the loop moves on to the next table. Any other exception
           propagates.
        4. Run history cleanup as best effort, report 100% progress, and log
           the batch metrics.

        Args:
            progress_callback: optional callable receiving integer percentages
                (0..100). It is forwarded to ``svc._emit_progress``, which is
                presumably responsible for tolerating ``None``.

        Returns:
            Sum of the ``row_count`` values reported by the per-table iterations.
        """
        svc = self._svc
        trace_id = self._new_trace_id()
        started_at = time.perf_counter()
        table_elapsed_samples: List[float] = []

        svc._ensure_tables()
        batch_end = svc._now_db()
        config = read_config()
        batch_limit = int(config.get("export_batch_limit") or 0)
        total_rows = 0
        svc._emit_progress(progress_callback, 0)

        configured_tables = get_export_tables()
        export_tables = svc._get_exportable_tables(configured_tables)
        if not export_tables:
            self._logger.warning("[DEBUG EXPORT][%s] tidak ada tabel valid untuk export", trace_id)
            svc._emit_progress(progress_callback, 100)
            self._log_perf_summary(
                trace_id=trace_id,
                total_rows=0,
                total_tables=0,
                total_elapsed_ms=self._elapsed_ms(started_at),
                table_elapsed_ms=[],
            )
            return 0

        max_ids = svc._get_max_ids(tuple(export_tables))
        transaksi_ids: List[int] = []
        transaksi_rows_for_registry: List[Dict] = []
        master_index: Dict[str, Dict] = {}

        self._logger.debug(
            "[DEBUG EXPORT][%s] start batch_end=%s limit=%s",
            trace_id,
            batch_end,
            batch_limit,
        )
        server_hash = get_current_server_hash()
        # Divisor for the progress percentages; export_tables is non-empty here.
        progress_divisor = max(1, len(export_tables))
        for table_index, table_name in enumerate(export_tables, start=1):
            table_started_at = time.perf_counter()
            svc._emit_progress(
                progress_callback,
                int(((table_index - 1) * 100) / progress_divisor),
            )
            try:
                iteration_result = self._table_iteration_use_case.run_table_iteration(
                    trace_id=trace_id,
                    table_name=table_name,
                    batch_end=batch_end,
                    batch_limit=batch_limit,
                    config=config,
                    max_id=max_ids.get(table_name, 0),
                    server_hash=server_hash,
                    transaksi_ids=transaksi_ids,
                    transaksi_rows_for_registry=transaksi_rows_for_registry,
                    master_index=master_index,
                )
            except TransactionExportRetryableError as export_exc:
                self._handle_retryable_error(
                    trace_id=trace_id,
                    table_name=table_name,
                    server_hash=server_hash,
                    export_exc=export_exc,
                )
            else:
                total_rows += int(iteration_result.get("row_count") or 0)
                # An empty or missing value keeps the previous accumulator (`or` fallback).
                # NOTE(review): confirm the iteration never returns an empty
                # container on purpose to reset state.
                transaksi_ids = list(iteration_result.get("transaksi_ids") or transaksi_ids)
                transaksi_rows_for_registry = list(
                    iteration_result.get("transaksi_rows_for_registry") or transaksi_rows_for_registry
                )
                master_index = dict(iteration_result.get("master_index") or master_index)
            # Timing is sampled for failed (retryable) tables as well.
            table_elapsed_samples.append(self._elapsed_ms(table_started_at))
            svc._emit_progress(
                progress_callback,
                int((table_index * 100) / progress_divisor),
            )

        self._logger.debug("[DEBUG EXPORT][%s] done total_rows=%s", trace_id, total_rows)
        self._run_history_cleanup(trace_id)
        svc._emit_progress(progress_callback, 100)

        self._log_perf_summary(
            trace_id=trace_id,
            total_rows=total_rows,
            total_tables=len(export_tables),
            total_elapsed_ms=self._elapsed_ms(started_at),
            table_elapsed_ms=table_elapsed_samples,
        )
        return total_rows
