# src/guards.py
import logging
import math
from datetime import datetime
from src.utils.helpers import haversine_distance, jitter_seconds

log = logging.getLogger(__name__)

# ----------------------------------------------------------------------
# Helper utilities (kept lightweight here; heavy maths live in utils.helpers)
# ----------------------------------------------------------------------
def _ensure_positive(value, name):
    if value < 0:
        log.warning(f"{name} negative ({value}); clamping to 0.")
        return 0
    return value


# ----------------------------------------------------------------------
# Guard 1 – Wi‑Fi ↔ Location consistency
# ----------------------------------------------------------------------
def guard_wifi_location(device_state, venue_record):
    """
    Keep a device's Wi-Fi SSID and RSSI consistent with its current venue.

    device_state: dict with keys `ssid`, `location` (lat/lon), `timestamp`;
                  may also carry `rssi` and `device`. Mutated in place.
    venue_record: dict from airports.yaml / venues.yaml containing
                  `wifi_ssids` (list) and optional `lat`, `lon`. May be None.

    If the venue lists SSIDs and the device's SSID is not among them, the SSID
    is replaced by the venue's first SSID and `rssi` is set to -55. If the venue
    is None or lists no SSIDs, the SSID is left untouched and a warning is
    logged. Afterwards an existing `rssi` is clamped to [-90, -30]. An absent
    `rssi` is treated as -70, which is in range, so it is not written.
    Returns the same (possibly corrected) device_state.
    """
    # Tolerate a missing venue, a missing key, or an explicit None list.
    venue_ssids = (venue_record or {}).get("wifi_ssids") or []
    ssid = device_state.get("ssid")

    if not venue_ssids:
        # Nothing realistic to substitute; previously this raised KeyError/IndexError.
        log.warning(
            f"wifi_no_venue_ssids – device {device_state.get('device')} "
            f"kept SSID {ssid!r}; venue lists no SSIDs"
        )
    elif ssid not in venue_ssids:
        # pick a realistic SSID from the venue
        corrected = venue_ssids[0]
        device_state["ssid"] = corrected
        # give a plausible RSSI for indoor venue
        device_state["rssi"] = -55
        log.info(
            f"wifi_mismatch_corrected – device {device_state.get('device')} "
            f"changed SSID to {corrected}"
        )

    # Optional: enforce RSSI range for indoor vs outdoor
    rssi = device_state.get("rssi", -70)
    if rssi > -30:
        device_state["rssi"] = -30
        log.info("rssi_clamped – too strong for indoor")
    elif rssi < -90:
        device_state["rssi"] = -90
        log.info("rssi_clamped – too weak for indoor")
    return device_state


# ----------------------------------------------------------------------
# Guard 2 – Step count ↔ travelled distance
# ----------------------------------------------------------------------
def guard_step_distance(prev_state, curr_state, avg_stride_m=0.78):
    """
    Reconcile the step delta between two samples with the distance travelled.

    prev_state / curr_state: dicts with `lat`, `lon`, `steps` (missing `steps`
                             counts as 0). curr_state is mutated in place.
    avg_stride_m: stride length used to convert metres into steps.

    The expected step delta is the haversine distance divided by the stride,
    truncated to an int. If the observed delta differs from it by more than
    15 % of the expected value, curr_state["steps"] is rewritten to
    prev steps + expected delta. Returns curr_state.
    """
    distance_m = haversine_distance(
        (prev_state["lat"], prev_state["lon"]),
        (curr_state["lat"], curr_state["lon"]),
    )  # metres
    expected_steps = int(distance_m / avg_stride_m)

    start_steps = prev_state.get("steps", 0)
    actual_steps = curr_state.get("steps", 0) - start_steps

    # Within 15 % slack: leave the reading alone.
    if abs(actual_steps - expected_steps) <= 0.15 * expected_steps:
        return curr_state

    curr_state["steps"] = start_steps + expected_steps
    log.info(
        f"step_distance_mismatch – corrected steps from {actual_steps} "
        f"to {expected_steps} (dist {distance_m:.1f} m)"
    )
    return curr_state


# ----------------------------------------------------------------------
# Guard 3 – Heart‑rate ↔ activity state
# ----------------------------------------------------------------------
def guard_hr_activity(device_state, activity_state):
    """
    Keep heart rate inside a plausible band for the current activity.

    device_state: dict with `heart_rate` (default 70) and optional
                  `baseline_hr` (default 70). Mutated in place.
    activity_state: string e.g. "walking", "seated", "security_assist",
                    "seatbelt_alert". Unknown values use a ±5 bpm band.
    A heart rate below the band is raised to its lower edge; one above it is
    lowered to its upper edge. Returns device_state.
    """
    baseline = device_state.get("baseline_hr", 70)

    # (low, high) offsets in bpm relative to the wearer's baseline
    band_offsets = {
        "walking": (5, 20),
        "running": (20, 40),
        "seated": (-5, 5),
        "security_assist": (5, 12),
        "seatbelt_alert": (5, 12),
        "elevator": (-2, 2),
        "conveyor_belt": (-2, 2),
    }
    low_off, high_off = band_offsets.get(activity_state, (-5, 5))
    low = baseline + low_off
    high = baseline + high_off

    current = device_state.get("heart_rate", 70)
    if current < low:
        device_state["heart_rate"] = low
        log.info(f"hr_adjusted – raised HR to {low} for activity {activity_state}")
    elif current > high:
        device_state["heart_rate"] = high
        log.info(f"hr_adjusted – lowered HR to {high} for activity {activity_state}")

    return device_state


# ----------------------------------------------------------------------
# Guard 4 – Battery drain / charge sanity
# ----------------------------------------------------------------------
def guard_battery(device_state, event_type):
    """
    Apply per-event battery bookkeeping and keep the level within 0-100.

    device_state: dict that may contain `battery_percent`; mutated in place.
    event_type: string, e.g. "charging_start", "charging_stop", "regular_use".

    - "charging_start" / "charging_stop" only set the `charging` flag.
    - Any other event drains 0.1 percentage points (a missing level is treated
      as 100), floored at 0 and rounded to 2 decimals.
    - If `battery_percent` is present afterwards, it is clamped to [0, 100].
    Note: this does not detect a level that rose without a charging event.
    Returns the same dict.
    """
    if event_type == "charging_start":
        # No immediate change; just mark that charging is active
        device_state["charging"] = True
    elif event_type == "charging_stop":
        device_state["charging"] = False
    else:
        # Regular use – apply a tiny drain (0.1 % per minute simulated)
        drain = 0.1
        battery = device_state.get("battery_percent", 100)
        device_state["battery_percent"] = round(max(0, battery - drain), 2)

    # Charging events may carry no level at all. Only clamp what is present,
    # otherwise a bare charging_start/charging_stop payload raised KeyError.
    if "battery_percent" in device_state:
        device_state["battery_percent"] = max(
            0, min(100, device_state["battery_percent"])
        )
    return device_state


# ----------------------------------------------------------------------
# Guard 5 – Monotonic timestamps per device
# ----------------------------------------------------------------------
def guard_timestamp_monotonic(prev_ts_str, new_ts_str, device_name):
    """
    Ensure a device's timestamps strictly increase.

    prev_ts_str / new_ts_str: ISO-8601 strings accepted by datetime.fromisoformat.
    device_name: used only in the warning message.

    If the new timestamp is later than the previous one, `new_ts_str` is
    returned unchanged (same string, not re-formatted). Otherwise the previous
    timestamp plus `jitter_seconds(1, 3)` is returned as an ISO string and a
    warning is logged. NOTE(review): this assumes jitter_seconds returns a
    timedelta-compatible value — confirm in src.utils.helpers.
    """
    previous = datetime.fromisoformat(prev_ts_str)
    candidate = datetime.fromisoformat(new_ts_str)

    if candidate > previous:
        return new_ts_str

    # Out of order or duplicate: bump forward by a minimal jitter (1-3 seconds)
    bumped = (previous + jitter_seconds(1, 3)).isoformat()
    log.warning(
        f"timestamp_non_monotonic – device {device_name} corrected from "
        f"{new_ts_str} to {bumped}"
    )
    return bumped


# ----------------------------------------------------------------------
# Guard 6 – Clipboard size & type sanity
# ----------------------------------------------------------------------
def guard_clipboard(entry, max_bytes=10 * 1024):
    """
    Normalise a clipboard entry's type and content before emission.

    entry: dict with `content_type`, `content`, `description`; mutated in place.
    max_bytes: maximum UTF-8 size of `content` (default 10 KB).

    - A `content_type` outside ("url", "text", "pdf_hash") is forced to "text".
    - Missing or None `content` becomes ""; any other non-string is str()-ed.
    - Content whose UTF-8 encoding exceeds `max_bytes` is cut to at most
      `max_bytes // 2` bytes (never splitting a multi-byte character) and
      suffixed with "...[truncated]".
    Returns the same entry dict.
    """
    ctype = entry.get("content_type")
    content = entry.get("content")

    if ctype not in ("url", "text", "pdf_hash"):
        log.warning("clipboard_type_mismatch – forced to 'text'")
        entry["content_type"] = "text"

    if content is None:
        # run_all_guards forwards payload.get("content"), which is None when the
        # payload has no content; treat that as an empty clipboard, not "None".
        content = ""
    elif not isinstance(content, str):
        log.warning("clipboard_content_not_string – converting to str")
        content = str(content)

    # Measure AND cut in bytes: slicing by characters could leave multi-byte
    # text (e.g. emoji, CJK) still above the limit after "truncation".
    encoded = content.encode("utf-8")
    if len(encoded) > max_bytes:
        head = encoded[: max_bytes // 2].decode("utf-8", errors="ignore")
        content = head + "...[truncated]"
        log.info("clipboard_too_large – truncated content")

    entry["content"] = content
    return entry


# ----------------------------------------------------------------------
# Guard 7 – Cross‑device event cohesion (timestamp window)
# ----------------------------------------------------------------------
def guard_cross_device_sync(event_batch, max_delta_seconds=2):
    """
    Pull a batch of related events into a common time window.

    event_batch: list of dicts that belong to the same logical action
                 (e.g., clipboard_sync emitted for phone, laptop, watch).
                 Each needs an ISO-8601 `timestamp`; `device` is optional and
                 only used for logging. Events are mutated in place.
    max_delta_seconds: allowed distance from the batch median, in seconds.

    The reference time is the median timestamp (the upper median for
    even-sized batches). Any event further than `max_delta_seconds` from it
    has its timestamp replaced with the median's isoformat().
    Returns a new list of the same event dicts; an empty batch yields [].
    """
    if not event_batch:
        # Previously raised IndexError when picking the median.
        return []

    timestamps = [datetime.fromisoformat(ev["timestamp"]) for ev in event_batch]
    median_ts = sorted(timestamps)[len(timestamps) // 2]
    median_iso = median_ts.isoformat()

    corrected_batch = []
    for ev, ts in zip(event_batch, timestamps):
        if abs((ts - median_ts).total_seconds()) > max_delta_seconds:
            ev["timestamp"] = median_iso
            log.info(
                f"cross_device_time_skew – adjusted {ev.get('device')} "
                f"from {ts.isoformat()} to {median_iso}"
            )
        corrected_batch.append(ev)
    return corrected_batch


# ----------------------------------------------------------------------
# Master guard dispatcher – call this before emitting an event
# ----------------------------------------------------------------------
def run_all_guards(device_name, event_name, event_payload, persona_state):
    """
    Run every applicable guard on an event payload before it is emitted.

    device_name: "phone", "laptop", or "watch"
    event_name: string identifier of the event
    event_payload: dict that will be emitted; corrected in place where possible
    persona_state: the shared Persona instance (allows access to previous state).
                   Reads/writes `last_timestamps` and `prev_location_state`;
                   reads `current_venue`, `lat`, `lon`, `step_counter`,
                   `current_activity`.
    Returns the payload ready for emission. Results of the step and clipboard
    guards are merged INTO the payload rather than replacing it, so fields such
    as `timestamp` and `device` survive.
    """
    # 1️⃣ Timestamp monotonicity (compare with last emitted timestamp)
    last_ts = persona_state.last_timestamps.get(device_name)
    if last_ts:
        event_payload["timestamp"] = guard_timestamp_monotonic(
            last_ts, event_payload["timestamp"], device_name
        )
    persona_state.last_timestamps[device_name] = event_payload["timestamp"]

    # 2️⃣ Device‑specific guards
    if event_name in ("wifi_connect", "wifi_disconnect"):
        venue = persona_state.current_venue  # should be set by itinerary before the call
        if venue is None:
            log.warning(
                f"wifi_guard_skipped – no current venue for device {device_name}"
            )
        else:
            event_payload = guard_wifi_location(event_payload, venue)

    if event_name == "step_increment":
        prev = persona_state.prev_location_state
        curr = {
            "lat": persona_state.lat,
            "lon": persona_state.lon,
            "steps": persona_state.step_counter,
        }
        # The very first sample has nothing to compare against.
        if prev is not None:
            curr = guard_step_distance(prev, curr)
        persona_state.prev_location_state = curr
        # NOTE(review): a corrected step total is not written back to
        # persona_state.step_counter — confirm whether it should be.
        event_payload.update(curr)

    if event_name in ("heart_rate_update", "seatbelt_announcement", "hr_normalization"):
        activity = persona_state.current_activity
        event_payload = guard_hr_activity(event_payload, activity)

    if event_name in ("battery_update", "charging_start", "charging_stop"):
        event_payload = guard_battery(event_payload, event_name)

    if event_name.startswith("clipboard"):
        # clipboard_copy, clipboard_sync, clipboard_paste
        entry = {
            "content_type": event_payload.get("content_type"),
            "content": event_payload.get("content"),
            "description": event_payload.get("description", ""),
        }
        event_payload.update(guard_clipboard(entry))

    # 3️⃣ Cross‑device sync (only needed when we have a batch;
    #    callers can invoke guard_cross_device_sync manually)
    return event_payload