import logging
import time
import uuid
import json
from urllib.parse import urlsplit, urlunsplit

try:
    import requests as _requests
except (ImportError, ModuleNotFoundError):
    _requests = None

from pypos.core.utils.config_utils import (
    get_jwt_runtime_config,
    normalize_base_url,
    read_endpoint_config,
    read_app_settings,
)
from pypos.core.utils.http_retry import request_with_retry as _request_with_retry
from pypos.core.utils.session_context import get_session_manager

# edited by glg
# Module-wide logger used for the HTTP trace and outbound-error payload logs.
LOGGER = logging.getLogger(__name__)
# Elapsed-time threshold in milliseconds: a request that takes at least this
# long is treated as slow and logged at WARNING level by
# BaseService._log_http_trace, even when it succeeded.
HTTP_TRACE_SLOW_WARN_MS = 700


class BaseService:
    def __init__(self, http_client=None):
        """Create the service with an optional HTTP client.

        Args:
            http_client: Object used to perform HTTP calls. When omitted or
                falsy, the module-level ``requests`` import is used instead,
                which is ``None`` if that package could not be imported.
        """
        self.http = http_client if http_client else _requests

    def set_http_client(self, http_client):
        """Replace the HTTP client used by this service.

        Args:
            http_client: The new client object; stored as-is on ``self.http``.

        Returns:
            This service instance, so calls can be chained.
        """
        self.http = http_client
        return self

    def _is_internal_url(self, url):
        """Return True when *url* points at the configured API base URL.

        The base URL is read from the ``api_base_url`` entry of
        ``read_endpoint_config()``; both it and *url* are passed through
        ``normalize_base_url`` before comparison.

        A plain prefix test is not sufficient: with a base of
        ``https://api.example.com`` it would also accept
        ``https://api.example.com.evil.com`` or
        ``https://api.example.com@evil.com``, and the caller uses this result
        to decide whether the JWT is attached. The match must therefore end
        exactly at the end of the base URL or on a URL component boundary
        (``/``, ``?`` or ``#``).

        Args:
            url: The request URL to classify.

        Returns:
            True if *url* is the base URL or lies beneath it; False otherwise,
            including when either value is empty or the configuration lookup
            raises ``OSError``, ``ValueError`` or ``TypeError``.
        """
        try:
            cfg = read_endpoint_config()
            base_url = normalize_base_url(str(cfg.get("api_base_url") or ""))
            current = normalize_base_url(str(url or ""))
            if not base_url or not current:
                return False
            if not current.startswith(base_url):
                return False
            if base_url.endswith("/"):
                # The base already ends on a path boundary.
                return True
            remainder = current[len(base_url):]
            # Anything other than a component delimiter right after the base
            # means the host (or port/userinfo) continues, i.e. another origin.
            return not remainder or remainder[0] in "/?#"
        except (OSError, ValueError, TypeError):
            return False

    # edited by glg
    @staticmethod
    def _sanitize_url_for_log(url):
        raw = str(url or "").strip()
        if not raw:
            return ""
        try:
            parsed = urlsplit(raw)
            if not parsed.scheme and not parsed.netloc:
                return raw.split("?", 1)[0].split("#", 1)[0]
            return urlunsplit((parsed.scheme, parsed.netloc, parsed.path, "", ""))
        except (ValueError, AttributeError):
            return raw.split("?", 1)[0].split("#", 1)[0]

    # edited by glg
    @staticmethod
    def _normalize_payload_scalar(value, max_len=500):
        if value is None:
            return None
        if isinstance(value, (bool, int, float)):
            return value
        if isinstance(value, (bytes, bytearray)):
            return {"_type": "bytes", "len": len(value)}
        text = str(value)
        if len(text) > int(max_len):
            return text[: int(max_len)] + "...(truncated)"
        return text

    # edited by glg
    @staticmethod
    def _merge_key_value(target, key, value):
        if key not in target:
            target[key] = value
            return
        existing = target[key]
        if isinstance(existing, list):
            existing.append(value)
            return
        target[key] = [existing, value]

    # edited by glg
    def _is_sensitive_key(self, key, redact_keys):
        """Tell whether *key* names a value that must be redacted.

        Matching is case-insensitive and by substring: the key is sensitive
        when any non-empty entry of *redact_keys* occurs anywhere inside it.

        Args:
            key: The payload key to test.
            redact_keys: Iterable of sensitive fragments; may be ``None``.

        Returns:
            True when the key matches a fragment, False otherwise (including
            for an empty key).
        """
        needle = str(key or "").strip().lower()
        if not needle:
            return False
        fragments = (str(token or "").strip().lower() for token in (redact_keys or ()))
        return any(fragment and fragment in needle for fragment in fragments)

    # edited by glg
    def _normalize_payload_value(self, value, redact_keys, depth=0):
        """Recursively convert a payload into a redacted, size-bounded form.

        Dicts are walked key by key, replacing the value of every sensitive
        key (see ``_is_sensitive_key``) with ``"***REDACTED***"``. Lists and
        tuples are walked for at most 20 items, followed by a marker giving
        the number of omitted items. Everything else goes through
        ``_normalize_payload_scalar``.

        Recursion stops at depth 2. A dict, list or tuple found at that depth
        is replaced by a ``{"_type": ..., "len": n}`` summary rather than
        being stringified: ``str()`` of a nested container would embed the
        values of sensitive keys verbatim and so bypass redaction (for
        example ``{"items": [{"password": "..."}]}``).

        Args:
            value: The payload (or payload fragment) to normalize.
            redact_keys: Iterable of sensitive key fragments.
            depth: Current nesting level; callers start at 0.

        Returns:
            A structure made of dicts, lists and the scalar forms described
            above.
        """
        if depth >= 2:
            if isinstance(value, dict):
                return {"_type": "dict", "len": len(value)}
            if isinstance(value, (list, tuple)):
                return {"_type": "list", "len": len(value)}
            return self._normalize_payload_scalar(value)
        if isinstance(value, dict):
            normalized = {}
            for key, raw in value.items():
                if self._is_sensitive_key(key, redact_keys):
                    normalized[str(key)] = "***REDACTED***"
                else:
                    normalized[str(key)] = self._normalize_payload_value(
                        raw, redact_keys, depth=depth + 1
                    )
            return normalized
        if isinstance(value, (list, tuple)):
            max_items = 20
            items = [
                self._normalize_payload_value(item, redact_keys, depth=depth + 1)
                for item in value[:max_items]
            ]
            if len(value) > max_items:
                items.append(f"...(+{len(value) - max_items} item)")
            return items
        return self._normalize_payload_scalar(value)

    # edited by glg
    def _extract_files_payload_for_log(self, files, redact_keys):
        """Split a multipart ``files`` argument into loggable parts.

        *files* may be a dict or a list/tuple of ``(field, payload)`` pairs;
        list entries that are not sequences of at least two items are
        ignored. A payload that is a list/tuple is read as
        ``(filename, content[, content_type])``.

        Entries with a usable filename are described by metadata only
        (filename, optional content type, and size when the content is
        bytes-like). All other entries are treated as plain form fields and
        their values are normalized, or redacted when the field name is
        sensitive. Repeated field names are accumulated via
        ``_merge_key_value``.

        Args:
            files: The ``files`` value passed to the HTTP call.
            redact_keys: Iterable of sensitive key fragments.

        Returns:
            A ``(form_fields, file_fields)`` tuple of dicts.
        """
        if isinstance(files, dict):
            pairs = list(files.items())
        elif isinstance(files, (list, tuple)):
            pairs = [
                (entry[0], entry[1])
                for entry in files
                if isinstance(entry, (list, tuple)) and len(entry) >= 2
            ]
        else:
            pairs = []

        plain_fields = {}
        upload_fields = {}
        for raw_name, part in pairs:
            # An empty field name is logged under the placeholder "-".
            label = str(raw_name or "") or "-"
            upload_name, body, mime = None, part, None
            if isinstance(part, (list, tuple)):
                part_len = len(part)
                if part_len >= 2:
                    upload_name, body = part[0], part[1]
                if part_len >= 3:
                    mime = part[2]

            if upload_name in (None, "", "None"):
                # No filename: this is an ordinary form field.
                if self._is_sensitive_key(label, redact_keys):
                    rendered = "***REDACTED***"
                else:
                    rendered = self._normalize_payload_value(body, redact_keys, depth=0)
                self._merge_key_value(plain_fields, label, rendered)
                continue

            # Real upload: record metadata only, never the content itself.
            descriptor = {"filename": self._normalize_payload_scalar(upload_name)}
            if mime not in (None, ""):
                descriptor["content_type"] = self._normalize_payload_scalar(mime)
            if isinstance(body, (bytes, bytearray)):
                descriptor["size"] = len(body)
            self._merge_key_value(upload_fields, label, descriptor)

        return plain_fields, upload_fields

    # edited by glg
    def _get_redact_keys(self):
        """Return the set of key fragments whose values are redacted in logs.

        A built-in set is extended with the ``log_redact_keys`` entry of the
        application settings, which may be a comma-separated string or a
        list/tuple/set. Configured entries are stripped and lower-cased, and
        empty ones are dropped.

        Returns:
            The combined set, or only the built-in set when reading the
            settings raises ``OSError``, ``TypeError`` or ``ValueError``.
        """
        builtin_keys = {
            "password",
            "passwd",
            "token",
            "authorization",
            "api_key",
            "pin",
            "kode_voucher",
            "voucher_code",
            "voucher",
        }
        try:
            configured = (read_app_settings() or {}).get("log_redact_keys")
            if isinstance(configured, str):
                tokens = [part.strip() for part in configured.split(",")]
            elif isinstance(configured, (list, tuple, set)):
                tokens = [str(part).strip() for part in configured]
            else:
                tokens = []
            return builtin_keys | {token.lower() for token in tokens if token}
        except (OSError, TypeError, ValueError):
            return builtin_keys

    # edited by glg
    def _build_outbound_payload_snapshot(self, *, params, data, json_payload, request_kwargs):
        """Build a redacted summary of an outgoing request for error logs.

        Args:
            params: Query parameters, or ``None``.
            data: Request body passed as ``data``, or ``None``.
            json_payload: Request body passed as ``json``, or ``None``.
            request_kwargs: Remaining keyword arguments of the request; it is
                copied, not modified.

        Returns:
            A dict that may contain ``params``, ``json`` and ``data``
            (normalized and redacted), ``multipart_form`` and
            ``multipart_files`` (derived from a ``files`` entry), and
            ``extra_kwargs_keys`` (sorted names of any other keyword
            arguments; their values are not recorded).
        """
        keys = self._get_redact_keys()
        sections = (("params", params), ("json", json_payload), ("data", data))
        snapshot = {
            label: self._normalize_payload_value(section, keys, depth=0)
            for label, section in sections
            if section is not None
        }

        remaining = dict(request_kwargs or {})
        uploads = remaining.pop("files", None)
        if uploads is not None:
            plain_fields, upload_fields = self._extract_files_payload_for_log(uploads, keys)
            if plain_fields:
                snapshot["multipart_form"] = plain_fields
            if upload_fields:
                snapshot["multipart_files"] = upload_fields

        if remaining:
            snapshot["extra_kwargs_keys"] = sorted(map(str, remaining))
        return snapshot

    # edited by glg
    def _log_outbound_payload_on_error(
        self,
        *,
        request_id,
        method,
        url,
        payload_snapshot,
        status_code=None,
        error=None,
        response=None,
    ):
        """Log the redacted request payload of a failed outbound call.

        Nothing is logged when *payload_snapshot* is not a non-empty dict.
        The event is written at WARNING level as JSON, falling back to
        ``str(event)`` if serialization raises ``TypeError`` or
        ``ValueError``.

        Args:
            request_id: Correlation id of the request.
            method: HTTP method name.
            url: Request URL; sanitized before logging.
            payload_snapshot: Dict produced by
                ``_build_outbound_payload_snapshot``.
            status_code: HTTP status of the failure, if known.
            error: The exception that was raised, if any.
            response: Response object of the failure, if any; its ``text``
                attribute is logged, truncated to 1200 characters.
        """
        if not (isinstance(payload_snapshot, dict) and payload_snapshot):
            return

        status_value = int(status_code or 0) if status_code else None
        error_name = error.__class__.__name__ if error else ""
        event = {
            "req_id": str(request_id or ""),
            "method": str(method or "").upper(),
            "url": self._sanitize_url_for_log(url),
            "status": status_value,
            "error_type": error_name,
            "error_detail": self._normalize_payload_scalar(str(error or ""), max_len=300),
            "request_payload": payload_snapshot,
        }
        if response is not None:
            body_text = getattr(response, "text", "") or ""
            body_text = self._normalize_payload_scalar(body_text, max_len=1200)
            if body_text:
                event["response_text"] = body_text

        log_format = "[HTTP_OUTBOUND_ERROR_PAYLOAD] %s"
        try:
            LOGGER.warning(log_format, json.dumps(event, ensure_ascii=False))
        except (TypeError, ValueError):
            LOGGER.warning(log_format, str(event))

    # edited by glg
    def _log_http_trace(
        self,
        *,
        request_id,
        method,
        url,
        elapsed_ms,
        auth_required,
        status_code=None,
        refreshed=False,
        error=None,
    ):
        """Write one trace line for a request that is worth noticing.

        A line is written only when the request raised, took at least
        ``HTTP_TRACE_SLOW_WARN_MS`` milliseconds, or ended with a status of
        400 or above. Errors, slow requests and 5xx statuses are logged at
        WARNING level; other 4xx statuses at INFO level.

        Args:
            request_id: Correlation id of the request.
            method: HTTP method name.
            url: Request URL; sanitized before logging.
            elapsed_ms: Duration of the request in milliseconds.
            auth_required: Whether the request asked for auth headers.
            status_code: HTTP status, if known.
            refreshed: Whether the request was sent after a token refresh.
            error: The exception that was raised, if any.
        """
        status_int = int(status_code or 0)
        has_error = error is not None
        is_slow = int(elapsed_ms or 0) >= int(HTTP_TRACE_SLOW_WARN_MS)
        if not has_error and not is_slow and status_int < 400:
            return

        template = (
            "[HTTP_TRACE] req_id=%s method=%s status=%s elapsed_ms=%s "
            "auth_required=%s refreshed=%s url=%s"
        )
        fields = (
            str(request_id or ""),
            str(method or "").upper(),
            status_int if status_int else "-",
            int(elapsed_ms or 0),
            int(bool(auth_required)),
            int(bool(refreshed)),
            self._sanitize_url_for_log(url),
        )
        if has_error:
            LOGGER.warning(
                template + " error=%s detail=%s",
                *fields,
                error.__class__.__name__,
                str(error),
            )
        elif is_slow or status_int >= 500:
            LOGGER.warning(template, *fields)
        else:
            LOGGER.info(template, *fields)

    def _build_auth_headers(self, url, headers, auth_required):
        """Return a copy of *headers* with the JWT ``Authorization`` added.

        The header is added only when auth is requested, JWT is enabled in
        the runtime config, and (if ``attach_internal_only`` is set) the URL
        is internal. Before the token is read, it is refreshed if it is about
        to expire.

        Args:
            url: Request URL, used for the internal-URL check.
            headers: Caller-supplied headers, or ``None``; never modified.
            auth_required: Whether the request wants auth headers at all.

        Returns:
            A new dict of headers.

        Raises:
            RuntimeError: If the config marks JWT as ``required`` and either
                the session manager or the access token is unavailable.
        """
        merged = dict(headers or {})
        if not auth_required:
            return merged

        jwt_cfg = get_jwt_runtime_config()
        if not jwt_cfg.get("enabled"):
            return merged
        if jwt_cfg.get("attach_internal_only") and not self._is_internal_url(url):
            return merged

        session = get_session_manager()
        if session:
            self._ensure_fresh_access_token(session=session, jwt_cfg=jwt_cfg)
            access_token = str(session.get_access_token() or "").strip()
            token_type = str(session.get_token_type() or "Bearer").strip() or "Bearer"
            if access_token:
                merged["Authorization"] = f"{token_type} {access_token}"
                return merged
            failure = "Token JWT belum tersedia."
        else:
            failure = "Session manager belum tersedia untuk JWT."

        # No usable token: fatal only when the config demands JWT.
        if jwt_cfg.get("required"):
            raise RuntimeError(failure)
        return merged

    def _ensure_fresh_access_token(self, session, jwt_cfg):
        """Refresh the access token when the session reports it is expiring.

        The session is asked via ``is_access_token_expiring`` with the
        configured ``refresh_skew_seconds``. If that call raises
        ``AttributeError``, ``TypeError`` or ``ValueError``, a refresh is
        attempted only when the session holds no access token.

        Args:
            session: The session manager, or a falsy value.
            jwt_cfg: JWT runtime configuration mapping.

        Returns:
            True when a refresh was attempted and succeeded; False otherwise.
        """
        if not session or not jwt_cfg.get("enabled"):
            return False

        skew = int(jwt_cfg.get("refresh_skew_seconds") or 0)
        try:
            expiring = bool(session.is_access_token_expiring(skew_seconds=skew))
        except (AttributeError, TypeError, ValueError):
            # Expiry cannot be determined: refresh only if no token is held.
            expiring = not str(session.get_access_token() or "").strip()

        if expiring:
            return bool(self._try_refresh_token())
        return False

    def _try_refresh_token(self):
        """Try to obtain new JWT tokens using the stored refresh token.

        A refresh is attempted only when the JWT runtime config has both
        ``enabled`` and ``auto_refresh`` set, a session manager exists, and
        it holds a non-empty refresh token. The new tokens are stored on the
        session through ``set_jwt_tokens``.

        Returns:
            True when new tokens were stored; False when a precondition is
            not met or the refresh raises one of the handled exception
            types.
        """
        cfg = get_jwt_runtime_config()
        if not (cfg.get("enabled") and cfg.get("auto_refresh")):
            return False
        session = get_session_manager()
        if not session:
            return False
        stored_refresh = str(session.get_refresh_token() or "").strip()
        if not stored_refresh:
            return False
        try:
            # NOTE(review): imported at call time, presumably to avoid a
            # circular import with the auth module; confirm before moving.
            from pypos.modules.auth.services.jwt_auth_service import JwtAuthService

            auth_service = JwtAuthService(http_client=self.http)
            bundle = auth_service.refresh_token(refresh_token=stored_refresh)
            session.set_jwt_tokens(
                access_token=bundle.get("access_token"),
                refresh_token=bundle.get("refresh_token"),
                token_type=bundle.get("token_type"),
                expires_in=bundle.get("expires_in"),
            )
        except (ImportError, ModuleNotFoundError, RuntimeError, ValueError, TypeError, AttributeError):
            return False
        return True

    def request_with_retry(
        self,
        method,
        url,
        *,
        retries=None,
        backoff_seconds=None,
        retry_on=None,
        timeout=None,
        auth_required=True,
        **kwargs,
    ):
        """Send an HTTP request through the shared retry helper.

        Auth headers are added via ``_build_auth_headers``. If the helper
        raises an exception carrying a 401 response, auth was requested and a
        token refresh succeeds, the request is sent once more with fresh
        headers. Every outcome is passed to ``_log_http_trace``; failures
        additionally log the redacted request payload.

        Args:
            method: HTTP method name.
            url: Target URL.
            retries: Number of retries; forwarded to the helper as
                ``max_retry = max(1, retries + 1)``, or ``None`` when omitted.
            backoff_seconds: Forwarded to the helper as ``backoff_sec``.
            retry_on: Forwarded to the helper unchanged.
            timeout: Forwarded to the helper unchanged.
            auth_required: Whether to attach auth headers and allow the
                refresh-and-resend step.
            **kwargs: ``params``, ``data``, ``json`` and ``headers`` are
                extracted; everything else is forwarded as
                ``request_kwargs``.

        Returns:
            The response object returned by the retry helper.

        Raises:
            RuntimeError: If no HTTP client is set on this service.
            Exception: Any exception raised by the helper is re-raised after
                logging.
        """
        if self.http is None:
            raise RuntimeError("HTTP client belum diset pada BaseService.")
        trace_id = uuid.uuid4().hex[:12]
        started_at = time.perf_counter()
        method_name = str(method or "").upper() or "GET"

        if retries is None:
            max_retry = None
        else:
            max_retry = max(1, int(retries) + 1)

        request_kwargs = dict(kwargs or {})
        params = request_kwargs.pop("params", None)
        data = request_kwargs.pop("data", None)
        json_payload = request_kwargs.pop("json", None)
        headers = request_kwargs.pop("headers", None)
        payload_snapshot = self._build_outbound_payload_snapshot(
            params=params,
            data=data,
            json_payload=json_payload,
            request_kwargs=request_kwargs,
        )
        headers = self._build_auth_headers(url, headers, auth_required=auth_required)

        def elapsed_since_start():
            # Whole milliseconds since this call began (includes any resend).
            return int((time.perf_counter() - started_at) * 1000)

        def send(with_headers):
            # One pass through the shared retry helper with the given headers.
            return _request_with_retry(
                method,
                url,
                params=params,
                data=data,
                json=json_payload,
                headers=with_headers,
                timeout=timeout,
                max_retry=max_retry,
                backoff_sec=backoff_seconds,
                session=self.http,
                retry_on=retry_on,
                request_kwargs=request_kwargs,
            )

        def trace_success(result, refreshed):
            # Trace line for a call that returned a response.
            self._log_http_trace(
                request_id=trace_id,
                method=method_name,
                url=url,
                elapsed_ms=elapsed_since_start(),
                auth_required=auth_required,
                status_code=getattr(result, "status_code", None),
                refreshed=refreshed,
            )

        def report_failure(failure, failed_response, failed_status, elapsed, refreshed):
            # Payload log followed by the trace line for a call that raised.
            self._log_outbound_payload_on_error(
                request_id=trace_id,
                method=method_name,
                url=url,
                payload_snapshot=payload_snapshot,
                status_code=failed_status,
                error=failure,
                response=failed_response,
            )
            self._log_http_trace(
                request_id=trace_id,
                method=method_name,
                url=url,
                elapsed_ms=elapsed,
                auth_required=auth_required,
                status_code=failed_status,
                refreshed=refreshed,
                error=failure,
            )

        try:
            response = send(headers)
            trace_success(response, refreshed=False)
            return response
        except Exception as exc:
            failed_response = getattr(exc, "response", None)
            status_code = getattr(failed_response, "status_code", None)
            may_resend = (
                int(status_code or 0) == 401
                and auth_required
                and self._try_refresh_token()
            )
            if not may_resend:
                elapsed = elapsed_since_start()
                report_failure(exc, failed_response, status_code, elapsed, refreshed=False)
                raise

            # 401 with a successful token refresh: resend once with new headers.
            refreshed_headers = self._build_auth_headers(url, headers, auth_required=auth_required)
            try:
                response = send(refreshed_headers)
                trace_success(response, refreshed=True)
                return response
            except Exception as retry_exc:
                elapsed = elapsed_since_start()
                retry_response = getattr(retry_exc, "response", None)
                retry_status = getattr(retry_response, "status_code", None)
                report_failure(retry_exc, retry_response, retry_status, elapsed, refreshed=True)
                raise
