﻿from datetime import datetime, time, timedelta
import sqlite3
import re
from typing import Dict, List, Optional

from pypos.core.base_model import BaseModel
from pypos.core.utils.db_helper import connect_sqlite
from pypos.core.utils.device_utils import get_active_device_info, get_device_id


class DiskonCustomerModel(BaseModel):
    """Resolve customer discount, cashback and point benefits for a purchase.

    Candidate rules are read from the ``diskon_customer`` table of the SQLite
    database at ``db_path``; quota usage is counted from the ``diskon_log``
    column of the ``transaksi`` table.
    """

    # Placeholder values that mean "no date set" (treated like an empty value).
    _ZERO_DATES = ("0000-00-00", "0000-00-00 00:00:00")
    # Rule markers searched for in ``transaksi.diskon_log`` after spaces are
    # removed; compiled once because they are applied to every candidate row.
    _RULE_ID_PATTERN = re.compile(r"rule_id=(\d+)")
    _RULE_IDS_PATTERN = re.compile(r"rule_ids=([0-9,]+)")
    # Timestamp format used when comparing against ``transaksi.dtime``.
    _DTIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, db_path: str):
        super().__init__()
        self.db_path = db_path

    # edited by glg
    def _log_reason(self, reason_code: str, exc):
        """Log a handled failure at debug level, tagged with ``reason_code``."""
        self.log_debug(f"[{str(reason_code or 'DISKON_REASON_UNKNOWN')}] {exc}")

    def get_customer_level_id(self, customer_id: int) -> int:
        """Return ``per_customers.level_id`` for ``customer_id``.

        Returns 0 when ``customer_id`` is falsy, the customer does not exist,
        the level is empty/unparseable, or the query fails.
        """
        if not customer_id:
            return 0
        conn = connect_sqlite(self.db_path)
        # Collapse each row to its first column so fetchone() yields the value.
        conn.row_factory = lambda cur, row: row[0]
        cursor = conn.cursor()
        try:
            cursor.execute(
                "SELECT level_id FROM per_customers WHERE id = ? LIMIT 1",
                (customer_id,),
            )
            level_id = cursor.fetchone()
            return int(level_id or 0)
        except (sqlite3.Error, TypeError, ValueError) as exc:
            self._log_reason("DISKON_CUSTOMER_LEVEL_FALLBACK", exc)
            return 0
        finally:
            cursor.close()
            conn.close()

    def get_active_cabang_id(self) -> int:
        """Return the branch (``cabang_id``) of the active device, or 0 if unknown."""
        try:
            device_info = get_active_device_info(get_device_id()) or {}
            return int(device_info.get("cabang_id") or 0)
        except (TypeError, ValueError, AttributeError, KeyError) as exc:
            self._log_reason("DISKON_CABANG_ID_FALLBACK", exc)
            return 0

    def _parse_date(self, value: Optional[str]) -> Optional[datetime]:
        """Parse ``YYYY-MM-DD`` or ``YYYY-MM-DD HH:MM:SS``.

        Returns None for empty values, zero-date placeholders and anything
        that matches neither format.
        """
        if not value:
            return None
        val = str(value).strip()
        if val in self._ZERO_DATES:
            return None
        try:
            if " " in val:
                return datetime.strptime(val, "%Y-%m-%d %H:%M:%S")
            return datetime.strptime(val, "%Y-%m-%d")
        except (TypeError, ValueError):
            return None

    def _has_date_value(self, value) -> bool:
        """Return True when ``value`` is neither blank nor a zero-date placeholder.

        Used to tell "no date configured" apart from "a date that failed to
        parse"; zero-dates must count as the former, matching ``_parse_date``.
        """
        text = str(value or "").strip()
        return bool(text) and text not in self._ZERO_DATES

    def _parse_time(self, value: Optional[str]) -> Optional[time]:
        """Parse ``HH:MM`` or ``HH:MM:SS``; return None when empty or malformed."""
        if not value:
            return None
        val = str(value).strip()
        try:
            if len(val.split(":")) == 2:
                return datetime.strptime(val, "%H:%M").time()
            return datetime.strptime(val, "%H:%M:%S").time()
        except (TypeError, ValueError):
            return None

    def _is_time_in_range(self, now_t: time, start_t: Optional[time], end_t: Optional[time]) -> bool:
        """Return True when ``now_t`` lies in the inclusive ``[start_t, end_t]`` window.

        A missing bound is open-ended. When ``start_t > end_t`` the window
        wraps past midnight (e.g. 22:00-02:00).
        """
        if not start_t and not end_t:
            return True
        if start_t and end_t:
            if start_t <= end_t:
                return start_t <= now_t <= end_t
            # Window crosses midnight.
            return now_t >= start_t or now_t <= end_t
        if start_t:
            return now_t >= start_t
        # edited by glg
        # MyPy guard: this branch is only reachable when end_t is set.
        return bool(end_t is not None and now_t <= end_t)

    def _parse_rule_ids(self, diskon_log: Optional[str]) -> set:
        """Extract every rule id referenced by a ``diskon_log`` string.

        Recognises ``rule_id=<n>`` and ``rule_ids=<n>,<m>,...`` markers; spaces
        anywhere in the log are ignored.
        """
        if not diskon_log:
            return set()
        text = str(diskon_log).replace(" ", "")
        result = {int(m.group(1)) for m in self._RULE_ID_PATTERN.finditer(text)}
        for m in self._RULE_IDS_PATTERN.finditer(text):
            # The pattern only admits digits and commas, so every non-empty
            # token is a valid integer.
            result.update(int(token) for token in m.group(1).split(",") if token)
        return result

    def _resolve_quota_window(
        self,
        now: datetime,
        periode: Optional[str],
        start_date: Optional[datetime],
        stop_date: Optional[datetime],
    ):
        """Return the ``(start_dt, end_dt)`` window in which quota usage is counted.

        For a recognised ``periode`` (daily/weekly/monthly/yearly, in
        Indonesian or English) this is the calendar period containing
        ``now``; weeks start on Monday. Otherwise the rule's own date range is
        used, expanded to whole days. Either bound may be None (unbounded).
        """
        mode = str(periode or "").strip().lower()
        start_dt = None
        end_dt = None

        if mode in ("harian", "daily"):
            start_dt = datetime(now.year, now.month, now.day, 0, 0, 0)
            end_dt = start_dt + timedelta(days=1) - timedelta(seconds=1)
        elif mode in ("mingguan", "weekly"):
            week_start = now - timedelta(days=now.weekday())
            start_dt = datetime(week_start.year, week_start.month, week_start.day, 0, 0, 0)
            end_dt = start_dt + timedelta(days=7) - timedelta(seconds=1)
        elif mode in ("bulanan", "monthly"):
            start_dt = datetime(now.year, now.month, 1, 0, 0, 0)
            if now.month == 12:
                next_month = datetime(now.year + 1, 1, 1, 0, 0, 0)
            else:
                next_month = datetime(now.year, now.month + 1, 1, 0, 0, 0)
            end_dt = next_month - timedelta(seconds=1)
        elif mode in ("tahunan", "yearly", "annual"):
            start_dt = datetime(now.year, 1, 1, 0, 0, 0)
            end_dt = datetime(now.year, 12, 31, 23, 59, 59)

        if start_dt is None and start_date is not None:
            start_dt = datetime(start_date.year, start_date.month, start_date.day, 0, 0, 0)
        if end_dt is None and stop_date is not None:
            end_dt = datetime(stop_date.year, stop_date.month, stop_date.day, 23, 59, 59)

        return start_dt, end_dt

    def _count_rule_usage(
        self,
        rule_id: int,
        customer_id: int,
        start_dt: Optional[datetime],
        end_dt: Optional[datetime],
    ):
        """Count non-trashed invoice transactions whose ``diskon_log`` cites ``rule_id``.

        Returns ``(total_used, customer_used)``, where ``customer_used`` only
        counts transactions whose ``customers_id`` equals ``customer_id``.
        ``start_dt``/``end_dt`` bound ``transaksi.dtime`` inclusively when
        given. A database error is logged and counts as zero usage.
        """
        rule_id = int(rule_id or 0)
        if rule_id <= 0:
            return 0, 0

        params = []
        # The SQL condition is only a cheap pre-filter; ``_parse_rule_ids`` is
        # authoritative. Spaces are stripped here exactly as the parser does,
        # so logs written as "rule_id = 5" are not discarded before parsing.
        # The bare "rule_id" substring also covers "rule_ids=".
        query = """
            SELECT customers_id, diskon_log
            FROM transaksi
            WHERE COALESCE(trash, 0) = 0
              AND (jenis_label = 'invoice' OR jenis_label IS NULL OR TRIM(jenis_label) = '')
              AND diskon_log IS NOT NULL
              AND TRIM(diskon_log) != ''
              AND instr(REPLACE(diskon_log, ' ', ''), 'rule_id') > 0
        """
        if start_dt is not None:
            query += " AND dtime >= ?"
            params.append(start_dt.strftime(self._DTIME_FORMAT))
        if end_dt is not None:
            query += " AND dtime <= ?"
            params.append(end_dt.strftime(self._DTIME_FORMAT))

        conn = connect_sqlite(self.db_path)
        conn.row_factory = lambda cur, row: {cur.description[i][0]: row[i] for i in range(len(row))}
        cursor = conn.cursor()
        try:
            cursor.execute(query, tuple(params))
            rows = cursor.fetchall() or []
        except sqlite3.Error as exc:
            self._log_reason("DISKON_RULE_USAGE_DB_ERROR", exc)
            rows = []
        finally:
            cursor.close()
            conn.close()

        target_customer = int(customer_id or 0)
        total_used = 0
        customer_used = 0
        for row in rows:
            if rule_id not in self._parse_rule_ids(row.get("diskon_log")):
                continue
            total_used += 1
            try:
                trx_customer_id = int(row.get("customers_id") or 0)
            except (TypeError, ValueError):
                trx_customer_id = 0
            if trx_customer_id == target_customer:
                customer_used += 1

        return total_used, customer_used

    def _safe_non_negative_int(self, value, default=0):
        """Convert ``value`` to int, clamping negatives to 0.

        When ``value`` cannot be converted, ``int(default or 0)`` is returned.
        """
        try:
            parsed = int(value if value is not None else default)
            return parsed if parsed > 0 else 0
        except (TypeError, ValueError):
            return int(default or 0)

    def _safe_percent(self, value):
        """Convert ``value`` to a float percentage clamped to ``[0, 100]`` (0 if invalid)."""
        try:
            percent = float(value or 0)
        except (TypeError, ValueError):
            return 0.0
        if percent < 0:
            return 0.0
        if percent > 100:
            return 100.0
        return percent

    def fetch_rules(self, customer_id: int, customer_level: int, cabang_id: int) -> List[Dict]:
        """Return active, non-trashed rules applicable to this customer and branch.

        A rule applies when its branch is unset (0/NULL) or equals
        ``cabang_id``, and it targets either this customer id, this
        customer's level, or nobody in particular (both targets 0/NULL).
        Returns ``[]`` when the table is missing or the query fails.
        """
        conn = connect_sqlite(self.db_path)
        conn.row_factory = lambda cur, row: {cur.description[i][0]: row[i] for i in range(len(row))}
        cursor = conn.cursor()
        try:
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='diskon_customer'"
            )
            if not cursor.fetchone():
                return []
            # The non-zero guards on the id/level clauses matter: a customer
            # whose level falls back to 0 must not match rules that target a
            # different customer (customer_id = X, customer_level = 0).
            # NULL targets are treated as 0, like cabang_id and trash above.
            cursor.execute(
                """
                SELECT *
                FROM diskon_customer
                WHERE status = 1
                  AND (trash = 0 OR trash IS NULL)
                  AND (cabang_id = 0 OR cabang_id IS NULL OR cabang_id = ?)
                  AND (
                        (COALESCE(customer_id, 0) != 0 AND customer_id = ?)
                     OR (COALESCE(customer_level, 0) != 0 AND customer_level = ?)
                     OR (COALESCE(customer_id, 0) = 0 AND COALESCE(customer_level, 0) = 0)
                  )
                """,
                (cabang_id, customer_id, customer_level),
            )
            return cursor.fetchall() or []
        except sqlite3.Error as exc:
            self._log_reason("DISKON_FETCH_RULES_DB_ERROR", exc)
            return []
        finally:
            cursor.close()
            conn.close()

    @staticmethod
    def _empty_result() -> Dict:
        """Build a fresh "no benefit" result dict."""
        return {
            "diskon_nilai": 0.0,
            "diskon_persen": 0.0,
            "cashback_nilai": 0.0,
            "point_nilai": 0.0,
            "rule": None,
            "diskon_rule": None,
            "cashback_rule": None,
            "point_rule": None,
            "applied_rule_ids": [],
            "rules": [],
        }

    def _benefit_kind(self, rule: Dict) -> Optional[str]:
        """Return ``"diskon"``, ``"cashback"`` or ``"point"`` for a rule, or None if unsupported.

        Discounts are only supported for ``jenis`` "transaksi" and "birthday".
        """
        tipe = str(rule.get("tipe") or "").lower().strip()
        if tipe == "diskon":
            jenis = str(rule.get("jenis") or "").lower().strip()
            # NOTE(review): "birthday" rules are not checked against the
            # customer's birth date here — confirm that is enforced elsewhere
            # (e.g. through the rule's own date window).
            return "diskon" if jenis in ("transaksi", "birthday") else None
        if tipe in ("cashback", "point"):
            return tipe
        return None

    def _passes_amount_limits(self, rule: Dict, total_belanja: float) -> bool:
        """Return True when ``total_belanja`` is within the rule's minim/maxim (0 = no limit)."""
        minim = float(rule.get("minim") or 0)
        maxim = float(rule.get("maxim") or 0)
        if minim and total_belanja < minim:
            return False
        if maxim and total_belanja > maxim:
            return False
        return True

    def _passes_schedule(self, rule: Dict, now: datetime) -> bool:
        """Return True when ``now`` falls inside the rule's date range and time-of-day window.

        Dates are compared at day granularity. A configured date/time that
        fails to parse disables the rule instead of being read as "no limit";
        empty values and zero-date placeholders mean "no limit".
        """
        start_date = self._parse_date(rule.get("tanggal_start"))
        stop_date = self._parse_date(rule.get("tanggal_stop"))
        if start_date is None and self._has_date_value(rule.get("tanggal_start")):
            return False
        if stop_date is None and self._has_date_value(rule.get("tanggal_stop")):
            return False

        today = now.date()
        if start_date and stop_date and start_date.date() > stop_date.date():
            return False
        if start_date and today < start_date.date():
            return False
        if stop_date and today > stop_date.date():
            return False

        start_time = self._parse_time(rule.get("jam_start"))
        stop_time = self._parse_time(rule.get("jam_stop"))
        if start_time is None and str(rule.get("jam_start") or "").strip():
            return False
        if stop_time is None and str(rule.get("jam_stop") or "").strip():
            return False
        return self._is_time_in_range(now.time(), start_time, stop_time)

    def _passes_quota(self, rule: Dict, now: datetime, customer_id: int, usage_cache: Dict) -> bool:
        """Return True when neither the global nor the per-customer quota is exhausted.

        Quotas of 0 mean unlimited. Usage counts are memoised in
        ``usage_cache`` keyed by ``(rule_id, window_start, window_end)``.
        """
        quota_global = self._safe_non_negative_int(rule.get("quota_global"), default=0)
        quota_per_customer = self._safe_non_negative_int(rule.get("quota_per_customer"), default=0)
        if quota_global <= 0 and quota_per_customer <= 0:
            return True

        rule_id = self._safe_non_negative_int(rule.get("id"), default=0)
        if rule_id <= 0:
            # Usage is tracked by rule id, so a quota-limited rule without a
            # usable id cannot be verified and is not applied.
            return False

        quota_start, quota_end = self._resolve_quota_window(
            now,
            rule.get("periode"),
            self._parse_date(rule.get("tanggal_start")),
            self._parse_date(rule.get("tanggal_stop")),
        )
        cache_key = (rule_id, quota_start, quota_end)
        if cache_key not in usage_cache:
            usage_cache[cache_key] = self._count_rule_usage(
                rule_id,
                customer_id,
                quota_start,
                quota_end,
            )
        used_global, used_customer = usage_cache[cache_key]
        if quota_global > 0 and used_global >= quota_global:
            return False
        if quota_per_customer > 0 and used_customer >= quota_per_customer:
            return False
        return True

    def _rule_amount(self, rule: Dict, total_belanja: float) -> float:
        """Benefit amount: ``persen`` % of the total when positive, else the flat ``nilai`` (>= 0)."""
        persen = self._safe_percent(rule.get("persen"))
        nilai = float(rule.get("nilai") or 0)
        if persen > 0:
            return total_belanja * (persen / 100.0)
        return max(nilai, 0.0)

    def calculate_benefits(self, total_belanja: float, customer_id: int) -> Dict:
        """Pick the best discount, cashback and point rule for a purchase.

        For each benefit type the eligible rule yielding the largest amount
        wins (ties keep the first rule encountered). Rules that raise while
        being evaluated are logged and skipped.

        Returns a dict with ``diskon_nilai`` (capped at ``total_belanja``),
        ``diskon_persen`` (clamped to 0-100), ``cashback_nilai``,
        ``point_nilai``, the winning rules (``rule`` is an alias of
        ``diskon_rule``), ``applied_rule_ids`` (unique positive ids of the
        winners) and ``rules`` (all fetched candidates).
        """
        total_belanja = float(total_belanja or 0)
        customer_id = int(customer_id or 0)
        if total_belanja <= 0 or customer_id <= 0:
            return self._empty_result()

        customer_level = self.get_customer_level_id(customer_id)
        cabang_id = self.get_active_cabang_id()
        rules = self.fetch_rules(customer_id, customer_level, cabang_id)
        if not rules:
            return self._empty_result()

        now = datetime.now()
        usage_cache: Dict = {}
        # Best candidate per benefit type as (amount, rule).
        best = {
            "diskon": (0.0, None),
            "cashback": (0.0, None),
            "point": (0.0, None),
        }

        for r in rules:
            try:
                kind = self._benefit_kind(r)
                if kind is None:
                    continue
                if not self._passes_amount_limits(r, total_belanja):
                    continue
                if not self._passes_schedule(r, now):
                    continue
                # Quota last: it may hit the database.
                if not self._passes_quota(r, now, customer_id, usage_cache):
                    continue
                amount = self._rule_amount(r, total_belanja)
                if amount > best[kind][0]:
                    best[kind] = (amount, r)
            except (TypeError, ValueError, KeyError, AttributeError) as exc:
                self._log_reason("DISKON_RULE_SKIPPED", exc)
                continue

        diskon_best, diskon_best_rule = best["diskon"]
        cashback_best, cashback_best_rule = best["cashback"]
        point_best, point_best_rule = best["point"]

        # Clamp the same way the amount calculation did, so the reported
        # percentage agrees with diskon_nilai.
        diskon_persen = self._safe_percent(diskon_best_rule.get("persen")) if diskon_best_rule else 0.0

        applied_rule_ids: List[int] = []
        for selected_rule in (diskon_best_rule, cashback_best_rule, point_best_rule):
            if not selected_rule:
                continue
            selected_id = self._safe_non_negative_int(selected_rule.get("id"), default=0)
            if selected_id > 0 and selected_id not in applied_rule_ids:
                applied_rule_ids.append(selected_id)

        return {
            # A flat discount larger than the bill would make the payable
            # amount negative, so cap it at the purchase total.
            "diskon_nilai": min(max(diskon_best, 0.0), total_belanja),
            "diskon_persen": max(diskon_persen, 0.0),
            "cashback_nilai": max(cashback_best, 0.0),
            "point_nilai": max(point_best, 0.0),
            "rule": diskon_best_rule,
            "diskon_rule": diskon_best_rule,
            "cashback_rule": cashback_best_rule,
            "point_rule": point_best_rule,
            "applied_rule_ids": applied_rule_ids,
            "rules": rules,
        }
