#!/usr/bin/env python3
"""BirdWing VPN - desktop client"""
import os, sys, json, time, uuid, platform, argparse, threading, subprocess, re, shutil, tempfile, hashlib, binascii, base64
from pathlib import Path
from urllib.parse import urlparse
import requests

# Host OS name from platform.system(); compared against "Windows" throughout.
SYSTEM = platform.system()
# Per-user state directory and the JSON files persisted in it.
CONFIG_DIR = Path.home() / ".birdwing"
KEY_FILE = CONFIG_DIR / "keypair.json"        # WireGuard key pair (contains the private key)
TOKEN_FILE = CONFIG_DIR / "token.json"        # API bearer token + username
SETTINGS_FILE = CONFIG_DIR / "settings.json"  # cached server settings / last known state
# Default coordination server (BirdWingClient also honours $BIRDWING_HOST).
# NOTE(security): plain http:// — the login password and bearer token travel
# unencrypted unless the URL is overridden with https://.
BIRDWING_HOST_DEFAULT = "http://birdwing.hosti.me"

def ensure_dirs():
    """Create CONFIG_DIR (and any missing parents) if it does not exist.

    The directory stores the WireGuard private key and the API token, so it
    is created owner-only (0o700, subject to the process umask). The mode
    applies only when the directory is newly created; intermediate parents
    get default permissions.
    """
    CONFIG_DIR.mkdir(mode=0o700, parents=True, exist_ok=True)

_HAS_PYSIDE = False
_HAS_TEXTUAL = False
try:
    from PySide6.QtCore import QCoreApplication, QObject, Qt, Signal, Slot, QTimer
    from PySide6.QtGui import QFont, QIcon, QPixmap, QColor
    from PySide6.QtWidgets import (
        QApplication, QDialog, QFormLayout, QHBoxLayout, QHeaderView,
        QLabel, QLineEdit, QMainWindow, QMessageBox, QPlainTextEdit, QPushButton,
        QSystemTrayIcon, QTableWidget, QTableWidgetItem, QVBoxLayout, QWidget,
        QMenu, QFrame, QSplitter,
    )
    _HAS_PYSIDE = True
except ImportError:
    pass
try:
    from textual.app import App, ComposeResult
    from textual.containers import Container, Horizontal, Vertical, VerticalScroll
    from textual.widgets import Button, DataTable, Footer, Input, Label, RichLog, Static
    _HAS_TEXTUAL = True
except ImportError:
    pass


# ===== WireGuard Manager =====
"""
Cross-platform WireGuard manager for BirdWing client.
Supports privilege elevation (sudo on Linux, runas on Windows).
"""


# Same value as the SYSTEM assigned near the top of the file (harmless re-assignment).
SYSTEM = platform.system()
# sudo password cached in memory by set_sudo_password(); None means "no password cached".
_SUDO_PASSWORD = None
# Set by _run() when elevation failed or was unavailable; read via last_elevation_failed().
# NOTE(review): a single process-wide flag — concurrent _run() calls from
# different threads overwrite each other's result.
_LAST_ELEVATION_FAILED = False
# Guards reads/writes of _LAST_ELEVATION_FAILED.
_ELEVATION_LOCK = threading.Lock()


def set_sudo_password(password: str):
    """Remember *password* so _run() can retry failing commands via ``sudo -S``.

    The value lives in a module global, in plain process memory, until
    clear_sudo_password() is called.
    """
    global _SUDO_PASSWORD
    _SUDO_PASSWORD = password


def clear_sudo_password():
    """Forget any cached sudo password; later _run() calls will not use sudo."""
    global _SUDO_PASSWORD
    _SUDO_PASSWORD = None


def last_elevation_failed() -> bool:
    """Report whether the most recent _run() call failed on its elevation path.

    _run() resets the flag on entry, so this reflects only the latest command.
    """
    with _ELEVATION_LOCK:
        failed = _LAST_ELEVATION_FAILED
    return failed


def is_elevated() -> bool:
    """Return True when a sudo password has been cached via set_sudo_password().

    Despite the name, this does not inspect the process's own privileges;
    needs_elevation() does that.
    """
    cached = _SUDO_PASSWORD
    return cached is not None


_PERMISSION_ERROR_PHRASES = (
    "permission denied",
    "not permitted",
    "operation not permitted",
    "requires root",
    "access denied",
    "a terminal is required",
    "a password is required",
    "no password was provided",
)


def _is_permission_error(e) -> bool:
    """Return True if exception *e* looks like a missing-privilege failure.

    A ``PermissionError`` always counts. For a ``CalledProcessError`` the
    captured stderr is inspected; for any other exception its string form is
    inspected. Matching is a case-insensitive substring search for phrases
    printed by the OS, ``ip``, and ``sudo``.
    """
    if isinstance(e, PermissionError):
        return True
    if isinstance(e, subprocess.CalledProcessError):
        raw = e.stderr
        # stderr is bytes when the process ran with text=False. Decode it so
        # the `phrase in text` checks below don't raise TypeError.
        if isinstance(raw, (bytes, bytearray)):
            raw = bytes(raw).decode(errors="replace")
        text = raw.lower() if isinstance(raw, str) else ""
    else:
        text = str(e).lower()
    return any(phrase in text for phrase in _PERMISSION_ERROR_PHRASES)


def _run(cmd, capture_output=True, text=True, timeout=30, check=False, input_data=None):
    """Run *cmd*, retrying with elevated privileges after a permission failure.

    Flow:
      1. Run *cmd* as-is. Exit code 0 is returned immediately. A non-zero
         exit whose stderr does not look like a permission problem (see
         _is_permission_error) is returned as-is, or raised as
         CalledProcessError when *check* is true.
      2. On a permission-looking failure, a PermissionError, or a timeout,
         fall through to the elevation path:
         - non-Windows with a cached sudo password: re-run as
           ``sudo -S <cmd>``, feeding the password as the first stdin line
           followed by *input_data*;
         - Windows: re-run the identical command (no elevation is attempted);
         - otherwise: give up.

    Args:
        cmd: argument list for subprocess.run.
        capture_output, text, timeout: passed through to subprocess.run.
        check: only honoured on the "non-permission failure" path of step 1.
        input_data: passed as ``input=`` to subprocess.run.

    Returns:
        The CompletedProcess of the last attempt, or None when no elevated
        attempt could be made or it raised. _LAST_ELEVATION_FAILED (see
        last_elevation_failed()) is reset on entry and set whenever the
        elevation path fails or is unavailable.

    Raises:
        FileNotFoundError: the executable was not found on the first attempt.
        CalledProcessError: only on the non-permission failure path with *check*.
    """
    global _LAST_ELEVATION_FAILED
    with _ELEVATION_LOCK:
        _LAST_ELEVATION_FAILED = False

    try:
        result = subprocess.run(cmd, capture_output=capture_output, text=text,
                                timeout=timeout, check=False, input=input_data)
        if result.returncode == 0:
            return result
        # Wrap the result so _is_permission_error can inspect its stderr.
        if not _is_permission_error(subprocess.CalledProcessError(
                result.returncode, cmd, output=result.stdout, stderr=result.stderr)):
            if check:
                result.check_returncode()
            return result
    except PermissionError:
        pass
    except FileNotFoundError:
        raise
    except subprocess.TimeoutExpired:
        # NOTE(review): a timeout is treated like a permission failure, so the
        # command is retried via sudo below (and may time out again).
        pass

    if _SUDO_PASSWORD and SYSTEM != "Windows":
        # `sudo -S` reads the password from stdin: password line first, then
        # the caller's own input.
        # NOTE(review): if the password is wrong, sudo may consume lines of
        # input_data as further password attempts — confirm acceptable.
        stdin_data = _SUDO_PASSWORD + "\n"
        if input_data:
            # NOTE(review): stdin_data is always str here; with text=False
            # subprocess expects bytes, which likely raises inside
            # subprocess.run and ends in the `return None` below.
            stdin_data += (input_data if isinstance(input_data, str)
                          else input_data.decode())
        try:
            result = subprocess.run(["sudo", "-S"] + cmd, capture_output=capture_output,
                                    text=text, timeout=timeout, check=False, input=stdin_data)
            if result.returncode == 0:
                return result
            with _ELEVATION_LOCK:
                _LAST_ELEVATION_FAILED = True
            return result
        except Exception:
            with _ELEVATION_LOCK:
                _LAST_ELEVATION_FAILED = True
            return None

    if SYSTEM == "Windows":
        # No elevation mechanism on Windows: simply retry the same command.
        # NOTE(review): subprocess.TimeoutExpired on this retry is not caught
        # and propagates to the caller.
        try:
            result = subprocess.run(cmd, capture_output=capture_output, text=text,
                                    timeout=timeout, check=False, input=input_data)
            if result.returncode == 0:
                return result
            with _ELEVATION_LOCK:
                _LAST_ELEVATION_FAILED = True
            return result
        except (PermissionError, OSError):
            pass
        # NOTE(review): FileNotFoundError subclasses OSError, so this clause is
        # unreachable — a missing binary is swallowed by the clause above.
        except FileNotFoundError:
            raise
        with _ELEVATION_LOCK:
            _LAST_ELEVATION_FAILED = True
        return None

    # Non-Windows without a cached password: elevation is impossible.
    with _ELEVATION_LOCK:
        _LAST_ELEVATION_FAILED = True
    return None


def _find_wg_binary() -> str:
    """Locate the WireGuard CLI, returning "" when none is found.

    On Windows, the fixed install paths for wg.exe and then wireguard.exe
    are checked before falling back to "wg" on PATH. Elsewhere the bare
    name "wg" is returned.

    NOTE(review): on non-Windows systems "wg" is also returned when only
    wg-quick is on PATH, and on Windows wireguard.exe may be returned, which
    does not implement wg subcommands such as ``show``/``genkey``.
    """
    if SYSTEM == "Windows":
        install_candidates = (
            r"C:\Program Files\WireGuard\wg.exe",
            r"C:\Program Files\WireGuard\wireguard.exe",
        )
        installed = next((path for path in install_candidates if os.path.exists(path)), None)
        if installed:
            return installed
        return "wg" if shutil.which("wg") else ""
    return "wg" if (shutil.which("wg") or shutil.which("wg-quick")) else ""


def _find_wg_quick() -> str:
    """Return the full path of wg-quick on PATH, or "" (always "" on Windows)."""
    return "" if SYSTEM == "Windows" else (shutil.which("wg-quick") or "")


def is_wireguard_installed() -> bool:
    """True when _find_wg_binary() located some WireGuard executable."""
    return _find_wg_binary() != ""


def gen_keypair() -> dict:
    """Generate a new random WireGuard key pair.

    Uses ``wg genkey`` / ``wg pubkey`` when a wg binary is found. If the
    binary is missing or either command yields nothing, it falls back to 32
    random bytes (hex-encoded private key) and derives the public key
    locally with _derive_pubkey().

    Returns:
        {"private_key": str, "public_key": str}. The private key is base64
        when produced by ``wg genkey`` and 64-char hex on the fallback path.
    """
    wg = _find_wg_binary()
    if not wg:
        priv = os.urandom(32).hex()
        return {"private_key": priv, "public_key": _derive_pubkey(priv)}

    try:
        generated = _run([wg, "genkey"], timeout=10)
    except FileNotFoundError:
        # _find_wg_binary() reports "wg" even when only wg-quick is on PATH.
        generated = None
    priv = (generated.stdout.strip() if generated and generated.stdout
            else os.urandom(32).hex())

    try:
        derived = _run([wg, "pubkey"], input_data=priv, timeout=10)
    except FileNotFoundError:
        derived = None
    if derived and derived.stdout:
        pub = derived.stdout.strip()
    else:
        # _derive_pubkey() wants the raw 32-byte key as hex. `wg genkey`
        # emits base64, so decode it. Hex-encoding its 44 characters would
        # yield 44 bytes, which X25519 rejects.
        priv_hex = priv if len(priv) == 64 else base64.b64decode(priv).hex()
        pub = _derive_pubkey(priv_hex)
    return {"private_key": priv, "public_key": pub}


def gen_deterministic_keypair(username: str, hostname: str) -> dict:
    """Derive a key pair from *username* and *hostname*.

    The private key is SHA-256 of "<username>::<hostname>::birdwing",
    returned as a 64-char hex string. The public key is derived from it with
    _derive_pubkey().

    NOTE(security): neither input is secret, so anyone who knows both can
    recompute this private key. Prefer gen_keypair() for new devices.
    """
    material = f"{username}::{hostname}::birdwing".encode()
    private_hex = hashlib.sha256(material).hexdigest()
    return {"private_key": private_hex, "public_key": _derive_pubkey(private_hex)}


def _derive_pubkey(priv_hex: str) -> str:
    """Compute the base64 X25519 public key for a hex-encoded private key.

    Requires the optional ``cryptography`` package. Without it the literal
    "PUBKEY_PLACEHOLDER" is returned.
    NOTE(review): that placeholder would be sent to the server as a real
    public key by callers — consider failing loudly instead.

    Raises:
        binascii.Error / ValueError: *priv_hex* is not hex, or does not
        decode to 32 bytes.
    """
    try:
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
        raw_public = (
            X25519PrivateKey.from_private_bytes(binascii.unhexlify(priv_hex))
            .public_key()
            .public_bytes(
                encoding=serialization.Encoding.Raw,
                format=serialization.PublicFormat.Raw,
            )
        )
    except ImportError:
        return "PUBKEY_PLACEHOLDER"
    return base64.b64encode(raw_public).decode()


def generate_config(interface_ip: str, private_key: str, dns: str = "",
                    listen_port: int = 0) -> str:
    """Build the ``[Interface]`` section of a wg-quick configuration.

    Args:
        interface_ip: a bare address, which gets a "/32" prefix appended, or
            an address already in CIDR form, which is used unchanged.
        private_key: a base64 WireGuard key, or the 64-char hex form that
            gen_deterministic_keypair() and the gen_keypair() fallback
            produce. The hex form is converted to base64 because WireGuard
            only accepts base64 keys.
        dns: optional ``DNS =`` value; omitted when empty.
        listen_port: optional ``ListenPort``; omitted when 0.

    Returns:
        The section text, each line terminated by "\n".
    """
    address = interface_ip if "/" in interface_ip else f"{interface_ip}/32"
    key = private_key
    if re.fullmatch(r"[0-9a-fA-F]{64}", key):
        key = base64.b64encode(bytes.fromhex(key)).decode()
    lines = ["[Interface]", f"Address = {address}", f"PrivateKey = {key}"]
    if dns:
        lines.append(f"DNS = {dns}")
    if listen_port:
        lines.append(f"ListenPort = {listen_port}")
    return "\n".join(lines) + "\n"


def add_peer_to_config(config: str, public_key: str, allowed_ips: str,
                       endpoint: str = "", persistent_keepalive: int = 25) -> str:
    """Return *config* with a ``[Peer]`` section appended.

    ``Endpoint`` is emitted only when *endpoint* is non-empty, and
    ``PersistentKeepalive`` only when *persistent_keepalive* is non-zero.
    The section is preceded by a blank line and every line ends in "\n".
    """
    section = ["", "[Peer]", f"PublicKey = {public_key}", f"AllowedIPs = {allowed_ips}"]
    if endpoint:
        section.append(f"Endpoint = {endpoint}")
    if persistent_keepalive:
        section.append(f"PersistentKeepalive = {persistent_keepalive}")
    return config + "\n".join(section) + "\n"


def apply_config(interface_name: str, config_content: str) -> bool:
    """Bring up *interface_name* from *config_content* using the OS-specific backend.

    Returns whatever the Windows or Linux implementation reports (True on success).
    """
    applier = _apply_config_windows if SYSTEM == "Windows" else _apply_config_linux
    return applier(interface_name, config_content)


def _list_wireguard_interfaces() -> list:
    """Return the names of WireGuard links reported by ``ip link show type wireguard``.

    Returns [] when ``ip`` is unavailable, times out, or exits non-zero.
    Only OS and subprocess errors are swallowed, so KeyboardInterrupt and
    genuine bugs still propagate.
    """
    try:
        r = subprocess.run(["ip", "link", "show", "type", "wireguard"],
                           capture_output=True, text=True, timeout=10)
    except (OSError, subprocess.SubprocessError):
        return []
    if r.returncode != 0:
        return []
    names = []
    for line in r.stdout.split("\n"):
        # Header lines look like "5: bw-1a2b3c4d: <POINTOPOINT,...> mtu 1420 ..."
        m = re.match(r'^(\d+):\s*(\S+):', line)
        if m:
            names.append(m.group(2))
    return names


def _cleanup_stale_interfaces(current_interface: str):
    """Delete WireGuard links that look like BirdWing leftovers.

    Targets every link except *current_interface* whose name starts with
    "bw-" or matches ``tmp`` followed by at least 8 word characters. Each is
    removed with ``ip link delete`` via _run(), and results are ignored.

    NOTE(review): the ``tmp...`` pattern could match interfaces created by
    other software.
    """
    stale = [
        name for name in _list_wireguard_interfaces()
        if name != current_interface
        and (name.startswith("bw-") or re.match(r'^tmp\w{8,}', name))
    ]
    for name in stale:
        _run(["ip", "link", "delete", "dev", name], timeout=10)


def _apply_config_linux(interface_name: str, config_content: str) -> bool:
    """Bring up *interface_name* on Linux from *config_content*.

    Without wg-quick, this delegates to _apply_config_linux_raw(). Otherwise:
    1. Write the config to ``~/.birdwing/<interface_name>.conf`` with mode
       0600, since it contains the private key. wg-quick derives the
       interface name from the file name.
    2. Remove stale BirdWing interfaces.
    3. Run ``wg-quick down`` then ``wg-quick up``, retrying ``up`` once.

    With a cached sudo password the commands run as ``sudo -S wg-quick ...``
    and the password is supplied on stdin.

    Returns:
        True if the interface is listed by ``ip link`` afterwards. False on
        elevation failure or an OS/subprocess error.
    """
    wg_quick = _find_wg_quick()
    if not wg_quick:
        return _apply_config_linux_raw(interface_name, config_content)
    config_path = Path.home() / ".birdwing" / f"{interface_name}.conf"
    config_path.parent.mkdir(parents=True, exist_ok=True)
    # The file holds the private key: create/tighten it to 0600 before writing.
    config_path.touch(mode=0o600, exist_ok=True)
    try:
        config_path.chmod(0o600)
    except OSError:
        pass
    config_path.write_text(config_content)
    _cleanup_stale_interfaces(interface_name)

    password = _SUDO_PASSWORD
    if password and SYSTEM != "Windows":
        # `sudo -S` reads the password from stdin, so it must be supplied.
        # Otherwise sudo reads the inherited stdin (a terminal or EOF).
        wg_cmd = ["sudo", "-S", wg_quick]
        stdin = password + "\n"
    else:
        wg_cmd = [wg_quick]
        stdin = None
    target = str(config_path)
    try:
        # Tear down by config path. A bare interface name makes wg-quick look
        # for /etc/wireguard/<name>.conf, which does not exist here.
        _run(wg_cmd + ["down", target], timeout=10, input_data=stdin)
        _run(wg_cmd + ["up", target], timeout=30, input_data=stdin)
        if last_elevation_failed():
            return False
        if interface_name in _list_wireguard_interfaces():
            return True
        # One retry: the first `up` can race with the teardown above.
        _run(wg_cmd + ["up", target], timeout=30, input_data=stdin)
        return interface_name in _list_wireguard_interfaces()
    except (OSError, subprocess.SubprocessError):
        return False


# Interface keys that only wg-quick understands; `wg setconf` rejects them.
_WG_QUICK_ONLY_KEYS = re.compile(
    r"^\s*(address|dns|mtu|table|preup|postup|predown|postdown|saveconfig)\s*=",
    re.IGNORECASE,
)


def _strip_wg_quick_only_keys(config_content: str) -> str:
    """Return *config_content* without lines using wg-quick-only keys.

    Every other line is kept verbatim, including its line ending.
    """
    return "".join(
        line for line in config_content.splitlines(keepends=True)
        if not _WG_QUICK_ONLY_KEYS.match(line)
    )


def _apply_config_linux_raw(interface_name: str, config_content: str) -> bool:
    """Bring up *interface_name* without wg-quick, using ``ip`` and ``wg setconf``.

    Steps: create the wireguard link, add the ``Address`` from the config
    (if any), load the config with ``wg setconf``, then set the link up.
    Lines using wg-quick-only keys are stripped before ``setconf``. The temp
    file (mode 0600 via NamedTemporaryFile) is always removed.

    NOTE(review): unlike wg-quick, no routes for AllowedIPs and no DNS are
    configured here — confirm whether peers are reachable in this mode.

    Returns:
        True when ``wg setconf`` and ``ip link set up`` both succeed and no
        elevation step failed.
    """
    address = _extract_address(config_content)
    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".conf", delete=False)
    try:
        with tmp:
            tmp.write(_strip_wg_quick_only_keys(config_content))
        # `link add` / `addr add` may fail harmlessly if they already exist.
        # Only elevation failures abort here.
        _run(["ip", "link", "add", "dev", interface_name, "type", "wireguard"], timeout=10)
        if last_elevation_failed():
            return False
        if address:
            _run(["ip", "addr", "add", address, "dev", interface_name], timeout=10)
            if last_elevation_failed():
                return False
        result = _run(["wg", "setconf", interface_name, tmp.name], timeout=30)
        if result is None or result.returncode != 0:
            return False
        result = _run(["ip", "link", "set", "dev", interface_name, "up"], timeout=10)
        return result is not None and result.returncode == 0
    except (OSError, subprocess.SubprocessError):
        return False
    finally:
        try:
            os.unlink(tmp.name)
        except OSError:
            pass


def _choose_wg_interface(interfaces_output: str, hint: str) -> str:
    """Pick a name from ``wg show interfaces`` output (whitespace-separated).

    Returns *hint* if it is listed (or nothing is listed), else the first name.
    """
    names = interfaces_output.split()
    if not names or hint in names:
        return hint
    return names[0]


def _find_windows_interface(hint: str) -> str:
    """Find a WireGuard interface name on Windows, preferring *hint*.

    Uses ``wg show interfaces``, which prints bare names. Plain ``wg show``
    prints "interface: NAME" headers, whose first field is the literal word
    "interface". Falls back to *hint* when wg is missing or lists nothing.
    """
    wg = _find_wg_binary()
    if not wg:
        return hint
    r = _run([wg, "show", "interfaces"], timeout=10)
    if r and r.stdout:
        return _choose_wg_interface(r.stdout, hint)
    return hint


def _set_windows_interface_ip(iface_name: str, ip_cidr: str):
    """Assign the address part of *ip_cidr* to a Windows network adapter.

    PowerShell picks the first adapter whose name contains *iface_name*.
    ``netsh`` is then run twice, once to set the address statically and once
    to add it, with output discarded. The prefix length in *ip_cidr* is
    ignored; a 255.255.255.0 mask is always used.

    NOTE(review): the fixed /24 mask may be what gives on-link routing to
    peers on the fallback path — confirm before deriving it from *ip_cidr*.
    The PowerShell exit status is not checked. A missing ``powershell``
    executable raises FileNotFoundError, and a hang raises TimeoutExpired.
    """
    address = ip_cidr.partition("/")[0]
    script = (
        f'$iface = Get-NetAdapter -Name "*{iface_name}*" | Select-Object -First 1; '
        f'if ($iface) {{ '
        f'  $i = $iface.Name; '
        f'  netsh interface ipv4 set address name="$i" source=static addr={address} mask=255.255.255.0 gateway=none 2>&1 | Out-Null; '
        f'  netsh interface ipv4 add address "$i" {address} 255.255.255.0 2>&1 | Out-Null; '
        f'}}'
    )
    subprocess.run(["powershell", "-Command", script], timeout=15, capture_output=True)


def _apply_config_windows(interface_name: str, config_content: str) -> bool:
    """Bring up *interface_name* on Windows from *config_content*.

    1. Write the config to %PROGRAMDATA%\\WireGuard\\Configurations\\<name>.conf.
    2. Primary: ``wireguard.exe /installtunnelservice <path>``.
    3. If that fails without an elevation failure, fall back: strip
       ``Address``/``DNS`` lines, load the result with ``wg setconf`` on the
       interface found by _find_windows_interface(), then set the IP via
       PowerShell. The stripped temp file is always deleted.

    Returns:
        True when either path reports success, False otherwise.

    NOTE(review): the written config contains the private key and is left
    on disk. NOTE(review): when _find_wg_binary() returned "wg" from PATH or
    wireguard.exe itself, the wg/wireguard.exe roles below get mixed up
    (e.g. ``wireguard.exe setconf``).
    """
    wg = _find_wg_binary()
    if not wg:
        return False
    # Tunnel-service management lives in wireguard.exe next to wg.exe; fall
    # back to the found binary if it is absent.
    wg_exe = wg
    if wg_exe.endswith("wg.exe"):
        wg_dir = os.path.dirname(wg_exe)
        wg_exe = os.path.join(wg_dir, "wireguard.exe")
        if not os.path.exists(wg_exe):
            wg_exe = wg
    config_dir = Path(os.environ.get("PROGRAMDATA", "C:\\ProgramData")) / "WireGuard" / "Configurations"
    config_dir.mkdir(parents=True, exist_ok=True)
    config_path = config_dir / f"{interface_name}.conf"
    config_path.write_text(config_content)

    # Primary: installtunnelservice (handles Address=/DNS= natively)
    try:
        r = _run([wg_exe, "/installtunnelservice", str(config_path)], timeout=30)
        if r is not None and r.returncode == 0 and not last_elevation_failed():
            return True
    except Exception:
        pass
    # An elevation failure would equally break the fallback, so stop here.
    if last_elevation_failed():
        return False

    # Fallback: wg setconf (strip Address=, DNS=)
    stripped = []
    for line in config_content.splitlines():
        if re.match(r'^\s*(address|dns)\s*=', line.strip().lower()):
            continue
        stripped.append(line)
    stripped_cfg = "\n".join(stripped)
    tmp = config_dir / f"{interface_name}.stripped.conf"
    tmp.write_text(stripped_cfg)
    try:
        # NOTE(review): `wg setconf` presumably needs an existing adapter;
        # nothing here creates one — confirm this fallback can succeed.
        iface = _find_windows_interface(interface_name)
        r = _run([wg, "setconf", iface, str(tmp)], timeout=30)
        if r is not None and r.returncode == 0:
            ip = _extract_address(config_content)
            if ip:
                # Result is not checked; True is returned even if this fails.
                _set_windows_interface_ip(iface, ip)
            return True
    except Exception:
        pass
    finally:
        try:
            tmp.unlink()
        except Exception:
            pass
    return False


def remove_interface(interface_name: str) -> bool:
    """Tear down *interface_name*; return True when it appears to be gone.

    Windows: ``wireguard.exe /uninstalltunnelservice <name>``. Returns False
    if wireguard.exe cannot be located.
    Elsewhere: ``wg-quick down`` (with the config path when it exists),
    supplying the cached sudo password on stdin when ``sudo -S`` is used.
    If the link still exists afterwards, or wg-quick is absent, it is
    deleted with ``ip link delete``.
    """
    if SYSTEM == "Windows":
        wg = _find_wg_binary()
        if not wg:
            return False
        wg_exe = wg
        if wg_exe.endswith("wg.exe"):
            wg_exe = os.path.join(os.path.dirname(wg_exe), "wireguard.exe")
            if not os.path.exists(wg_exe):
                return False
        try:
            result = _run([wg_exe, "/uninstalltunnelservice", interface_name], timeout=30)
        except (OSError, subprocess.SubprocessError):
            return False
        return result is not None and result.returncode == 0

    config_path = Path.home() / ".birdwing" / f"{interface_name}.conf"
    wg_quick = _find_wg_quick()
    if wg_quick:
        password = _SUDO_PASSWORD
        if password:
            # `sudo -S` reads the password from stdin, so it must be provided.
            wg_cmd = ["sudo", "-S", wg_quick, "down"]
            stdin = password + "\n"
        else:
            wg_cmd = [wg_quick, "down"]
            stdin = None
        target = str(config_path) if config_path.exists() else interface_name
        _run(wg_cmd + [target], timeout=10, input_data=stdin)
        # wg-quick down deletes the link itself. Don't report failure just
        # because a follow-up `ip link delete` finds nothing to delete.
        if interface_name not in _list_wireguard_interfaces():
            return True
    ip_link = _run(["ip", "link", "delete", "dev", interface_name], timeout=10)
    return ip_link is not None and ip_link.returncode == 0


def get_interface_status(interface_name: str) -> dict:
    """Query ``wg show <interface_name>``.

    Returns:
        {"up": True, "output": <wg show stdout>} on success, otherwise
        {"up": False, "error": <reason>}. A failed or impossible elevation
        (``_run`` returning None) is reported as "Interface not found".
    """
    wg = _find_wg_binary()
    if not wg:
        return {"up": False, "error": "WireGuard not found"}
    try:
        result = _run([wg, "show", interface_name], timeout=10)
    except Exception as exc:  # report instead of crashing status polling
        return {"up": False, "error": str(exc)}
    if result is None or result.returncode != 0:
        return {"up": False, "error": "Interface not found"}
    return {"up": True, "output": result.stdout}


def _extract_address(config: str) -> str:
    """Return the first address from the first ``Address =`` line, or "".

    The key is matched case-insensitively at the start of a line. When the
    line lists several comma-separated addresses, only the first is returned,
    e.g. "10.0.0.2/32" from "Address = 10.0.0.2/32, fd00::2/128".
    """
    m = re.search(r"^\s*Address\s*=\s*([^,\s]+)", config, re.IGNORECASE | re.MULTILINE)
    return m.group(1) if m else ""


def needs_elevation() -> bool:
    """Return True when the current process lacks administrator/root rights.

    Windows: asks ``shell32.IsUserAnAdmin``. If that cannot be queried,
    elevation is assumed to be needed. Elsewhere: True unless the effective
    UID is 0.
    """
    if SYSTEM == "Windows":
        try:
            # Imported here: ctypes was never imported at module level, so the
            # previous NameError made this branch always report True.
            import ctypes
            return not ctypes.windll.shell32.IsUserAnAdmin()
        except (ImportError, AttributeError, OSError):
            return True
    return os.geteuid() != 0


# ===== Client Core =====
class BirdWingClient:
    """High-level BirdWing client: server API calls plus local WireGuard setup.

    The key pair, API token and server settings are persisted as JSON under
    CONFIG_DIR (KEY_FILE, TOKEN_FILE, SETTINGS_FILE) and reloaded on
    construction. A per-installation interface name ``bw-<8 hex chars>`` is
    stored in ``CONFIG_DIR/interface_name`` and reused across runs.
    """

    def __init__(self, host_url: str = None, log_callback=None):
        """Create a client and load any persisted state.

        Args:
            host_url: server base URL. Falls back to $BIRDWING_HOST, then
                BIRDWING_HOST_DEFAULT. A trailing "/" is stripped.
            log_callback: optional callable receiving each log message. When
                omitted, messages are printed to stdout.
        """
        self.host_url = (host_url or os.environ.get("BIRDWING_HOST", BIRDWING_HOST_DEFAULT)).rstrip("/")
        self.log_callback = log_callback
        self.token = None
        self.username = None
        self.keypair = None
        iface_file = CONFIG_DIR / "interface_name"
        if iface_file.exists():
            self.interface_name = iface_file.read_text().strip()
        else:
            # Generated once and persisted. The "bw-" prefix is what
            # _cleanup_stale_interfaces() and reset() match on.
            self.interface_name = f"bw-{uuid.uuid4().hex[:8]}"
            ensure_dirs()
            iface_file.write_text(self.interface_name)
        self.connected = False
        self.assigned_ip = None
        self.peers = []
        self.host_public_key = None
        self.host_listen_port = 51820
        # Set by connect(): None, "auth", "config", "no_wg", "elevation" or "apply".
        self._last_error = None
        self._load_state()

    def log(self, msg, end="\n"):
        """Send *msg* to log_callback, or print it with *end* when none is set."""
        if self.log_callback:
            self.log_callback(msg)
        else:
            print(msg, end=end)

    # ─── Persistence ──────────────────────────────────────────

    def _load_state(self):
        """Best-effort load of the key pair, token/username and server settings.

        Missing or unreadable files are ignored and leave the current values.
        """
        ensure_dirs()
        if KEY_FILE.exists():
            try:
                self.keypair = json.loads(KEY_FILE.read_text())
            except Exception:
                pass
        if TOKEN_FILE.exists():
            try:
                data = json.loads(TOKEN_FILE.read_text())
                self.token = data.get("token")
                self.username = data.get("username")
            except Exception:
                pass
        if SETTINGS_FILE.exists():
            try:
                data = json.loads(SETTINGS_FILE.read_text())
                self.host_public_key = data.get("host_public_key")
                self.host_listen_port = data.get("host_listen_port", 51820)
                self.assigned_ip = data.get("assigned_ip")
            except Exception:
                pass

    @staticmethod
    def _write_private(path, text: str):
        """Write *text* to *path*, restricting the file to its owner (0600).

        The file is created with 0600 (subject to umask), and an existing
        file is chmod-ed *before* the new content is written, so secrets are
        never written into a world-readable file. chmod failures (e.g. on
        filesystems without POSIX modes) are ignored.
        """
        path.touch(mode=0o600, exist_ok=True)
        try:
            path.chmod(0o600)
        except (NotImplementedError, OSError):
            pass
        path.write_text(text)

    def _save(self):
        """Persist the key pair, token and settings. Failures are logged, not raised."""
        ensure_dirs()
        try:
            self._write_private(KEY_FILE, json.dumps(self.keypair, indent=2))
            self._write_private(TOKEN_FILE, json.dumps({"token": self.token, "username": self.username}, indent=2))
            self._write_private(SETTINGS_FILE, json.dumps({
                "host_public_key": self.host_public_key,
                "host_listen_port": self.host_listen_port,
                "assigned_ip": self.assigned_ip,
                "connected": self.connected,
                "peer_count": len(self.peers),
                "timestamp": time.time(),
            }, indent=2))
        except (OSError, TypeError, ValueError) as e:
            self.log(f"[!] Could not save local state: {e}")

    # ─── API Helpers ──────────────────────────────────────────

    def _api(self, method, path, json_data=None, auth=True, timeout=15, **kwargs):
        """Send an HTTP request to ``host_url + path``.

        Adds a Bearer token when *auth* is true and a token is held. Returns
        the Response for any HTTP status, or None (after logging) on
        connection or other request errors.
        """
        headers = {}
        if auth and self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        try:
            resp = requests.request(method, f"{self.host_url}{path}",
                                    json=json_data, headers=headers, timeout=timeout, **kwargs)
            return resp
        except requests.exceptions.ConnectionError:
            self.log(f"[!] Cannot connect to {self.host_url}")
            return None
        except Exception as e:
            self.log(f"[!] Request error: {e}")
            return None

    def _get(self, path, **kw):
        """GET helper around _api()."""
        return self._api("GET", path, **kw)

    def _post(self, path, data=None, **kw):
        """POST helper around _api(); *data* is sent as the JSON body."""
        return self._api("POST", path, json_data=data, **kw)

    # ─── Authentication ───────────────────────────────────────
    # NOTE: a requests.Response is falsy for 4xx/5xx status codes, so
    # responses are always tested with `is not None`, never by truthiness.

    def login(self, username: str, password: str) -> bool:
        """Authenticate via ``/api/auth/login`` and persist the returned token.

        On success, stores ``access_token``, saves state, refreshes server
        settings and returns True. On failure, logs the server's ``detail``
        when available and returns False.
        """
        resp = self._post("/api/auth/login", {"username": username, "password": password}, auth=False)
        if resp is not None and resp.status_code == 200:
            try:
                token = resp.json()["access_token"]
            except (ValueError, KeyError, TypeError):
                self.log("[!] Login failed: malformed server response")
                return False
            self.token = token
            self.username = username
            self._save()
            self.log(f"[+] Authenticated as {username}")
            self.fetch_settings()
            return True
        detail = "Unknown error"
        if resp is not None:
            try:
                detail = resp.json().get("detail", detail)
            except Exception:
                pass
        self.log(f"[!] Login failed: {detail}")
        return False

    def login_if_needed(self) -> bool:
        """Return True if a stored token exists and the server has not rejected it.

        A rejected token is cleared and persisted, then False is returned.
        """
        if not self.token:
            self.log("[!] Not authenticated. Use login first.")
            return False
        if self._validate_token():
            return True
        self.log("[!] Session expired. Login again.")
        self.token = None
        self._save()
        return False

    def _validate_token(self) -> bool:
        """Return False only when the server explicitly answers 401.

        Any other outcome, including an unreachable server, is treated as
        valid so that later calls surface the real error.
        """
        resp = self._get("/api/settings", timeout=8)
        if resp is not None and resp.status_code == 401:
            return False
        return True

    # ─── Settings ─────────────────────────────────────────────

    def fetch_settings(self) -> bool:
        """Refresh the host public key and listen port from ``/api/settings``."""
        if not self.token:
            return False
        resp = self._get("/api/settings")
        if resp is not None and resp.status_code == 200:
            data = resp.json()
            self.host_public_key = data.get("host_public_key", self.host_public_key)
            self.host_listen_port = data.get("host_listen_port", self.host_listen_port)
            self._save()
            return True
        return False

    # ─── Key Management ───────────────────────────────────────

    def ensure_keypair(self):
        """Create and persist a key pair if none is loaded.

        With a known username and hostname the key is derived
        deterministically (gen_deterministic_keypair). Otherwise a random
        key is generated (gen_keypair).

        NOTE(security): the deterministic key is a hash of username and
        hostname, neither of which is secret. Anyone who knows both can
        recompute this device's private key.
        """
        if self.keypair:
            return
        if self.username and platform.node():
            self.log(f"[*] Generating deterministic key from {self.username}@{platform.node()}...")
            self.keypair = gen_deterministic_keypair(self.username, platform.node())
        else:
            self.log("[*] Generating random WireGuard key pair...")
            self.keypair = gen_keypair()
        self._save()
        self.log(f"[+] Public key: {self.keypair['public_key'][:20]}...")

    # ─── Device Registration ──────────────────────────────────

    def ensure_registered(self) -> tuple:
        """Register this device's public key if needed.

        Returns:
            (ok, assigned_ip_or_error). On HTTP 409 (already registered) the
            IP is parsed from the server's ``detail`` ("IP: x.x.x.x"); if
            that fails, the cached IP is used, or "unknown".
        """
        if not self.login_if_needed():
            return False, "Not authenticated"
        self.ensure_keypair()

        device_name = f"{platform.node()}-{uuid.uuid4().hex[:4]}"
        resp = self._post("/api/devices/register", {
            "device_name": device_name,
            "public_key": self.keypair["public_key"],
        })
        if resp is None:
            return False, "Server unreachable"
        if resp.status_code == 401:
            self.token = None
            self._save()
            self.log("[!] Session expired — login again")
            return False, "Session expired"
        if resp.status_code == 409:
            self.log("[+] Device already registered")
            try:
                body = resp.json()
                ip_from_detail = body.get("detail", "")
                m = re.search(r'IP:\s*([\d.]+)', ip_from_detail)
                if m:
                    self.assigned_ip = m.group(1)
                    self._save()
            except Exception:
                pass
            if self.assigned_ip:
                return True, self.assigned_ip
            return True, "unknown"
        if resp.status_code == 200:
            data = resp.json()
            self.assigned_ip = data["assigned_ip"]
            self._save()
            self.log(f"[+] Device registered. IP: {self.assigned_ip}")
            return True, self.assigned_ip
        try:
            detail = resp.json().get("detail", "Unknown error")
        except Exception:
            detail = f"HTTP {resp.status_code}"
        self.log(f"[!] Registration failed: {detail}")
        return False, detail

    def register_device(self, device_name: str = "") -> bool:
        """Boolean wrapper around ensure_registered().

        *device_name* is accepted for backward compatibility but ignored;
        ensure_registered() chooses the name.
        """
        ok, _msg = self.ensure_registered()
        return ok

    # ─── Peer Discovery ───────────────────────────────────────

    def fetch_peers(self) -> list:
        """Fetch the peer list (and this device's IP) from ``/api/devices/peers``.

        Returns the new peer list, or [] on failure. The cached self.peers
        is left untouched on failure.
        """
        if not self.token:
            return []
        resp = self._get("/api/devices/peers")
        if resp is not None and resp.status_code == 200:
            data = resp.json()
            self.assigned_ip = data.get("assigned_ip", self.assigned_ip)
            self.peers = data.get("peers", [])
            self._save()
            return self.peers
        return []

    # ─── WireGuard Config ─────────────────────────────────────

    def build_config(self) -> str:
        """Render a wg-quick config for this device and its current peers.

        Refreshes peers first. The host peer without an explicit endpoint is
        pointed at ``<server hostname>:<host_listen_port>``. Returns "" when
        no key pair is available.
        """
        if not self.keypair:
            self.log("[!] No keypair available")
            return ""
        self.fetch_peers()
        hostname = urlparse(self.host_url).hostname
        # NOTE(review): the server hostname is used as the DNS value. wg-quick
        # treats non-IP DNS entries as search domains, so this likely sets no
        # resolver — confirm intent.
        dns = hostname or "172.255.0.1"
        config = generate_config(
            interface_ip=self.assigned_ip or "172.255.0.2",
            private_key=self.keypair["private_key"],
            dns=dns,
            listen_port=51821,
        )
        # Avoid emitting "None:<port>" when the URL has no hostname, and
        # bracket IPv6 literals as WireGuard endpoints require.
        if not hostname:
            host_endpoint = ""
        elif ":" in hostname:
            host_endpoint = f"[{hostname}]:{self.host_listen_port}"
        else:
            host_endpoint = f"{hostname}:{self.host_listen_port}"
        for peer in self.peers:
            # An explicit endpoint wins; the host peer falls back to the server address.
            ep = peer.get("endpoint", "") or (host_endpoint if peer.get("host") else "")
            config = add_peer_to_config(
                config,
                public_key=peer["public_key"],
                allowed_ips=peer["allowed_ips"],
                endpoint=ep,
            )
        return config

    # ─── Connect / Disconnect ─────────────────────────────────

    def connect(self) -> bool:
        """Register if needed, build the config and bring the tunnel up.

        The config (which contains the private key) is saved to
        ``CONFIG_DIR/<interface>.conf`` with owner-only permissions, even
        when WireGuard is missing, so it can be applied manually. On
        failure, self._last_error is set to "auth", "config", "no_wg",
        "elevation" or "apply".
        """
        self._last_error = None
        wg_up = get_interface_status(self.interface_name).get("up", False)
        if wg_up:
            self.connected = True
            self._save()
            self.log("[+] Already connected")
            return True
        self.connected = False

        ok, msg = self.ensure_registered()
        if not ok:
            if msg in ("Not authenticated", "Session expired"):
                self._last_error = "auth"
                self.log(f"[!] {msg} — cannot connect")
            return False

        # build_config() refreshes the peer list itself.
        config = self.build_config()
        if not config:
            self._last_error = "config"
            return False

        config_path = CONFIG_DIR / f"{self.interface_name}.conf"
        ensure_dirs()
        self._write_private(config_path, config)

        if not is_wireguard_installed():
            self.log("[!] WireGuard not installed")
            if SYSTEM == "Windows":
                self.log("    Install from https://www.wireguard.com/install/")
            else:
                self.log("    Run: sudo pacman -S wireguard-tools / apt install wireguard")
            self.log(f"[*] Config saved to {config_path} (manual setup)")
            self.connected = False
            self._save()
            self._last_error = "no_wg"
            return False

        self.log(f"[*] Applying WireGuard config ({self.interface_name})...")
        if apply_config(self.interface_name, config):
            self.connected = True
            self._save()
            self.log(f"[+] VPN connected. IP: {self.assigned_ip}")
            return True

        self.log(f"[!] Failed to apply config. Saved to {config_path}")
        self.connected = False
        self._save()
        if last_elevation_failed():
            self._last_error = "elevation"
        else:
            self._last_error = "apply"
        return False

    def disconnect(self):
        """Remove the WireGuard interface and mark the client disconnected."""
        self.log(f"[*] Removing interface {self.interface_name}...")
        remove_interface(self.interface_name)
        self.connected = False
        self._save()
        self.log("[+] Disconnected")

    # ─── Status ───────────────────────────────────────────────

    def get_status(self) -> dict:
        """Return a summary dict; includes ``wg`` interface details when connected."""
        status = {
            "connection": "connected" if self.connected else "disconnected",
            "username": self.username,
            "assigned_ip": self.assigned_ip or "N/A",
            "peer_count": len(self.peers),
            "platform": SYSTEM,
            "wg_installed": is_wireguard_installed(),
            "host_public_key": self.host_public_key,
        }
        if self.connected and is_wireguard_installed():
            status["wg"] = get_interface_status(self.interface_name)
        return status

    def display_status(self):
        """Log get_status() as aligned lines, plus wg transfer/handshake lines."""
        s = self.get_status()
        con = "Connected" if s["connection"] == "connected" else "Disconnected"
        self.log(f"  Connection : {con}")
        self.log(f"  Username   : {s['username'] or 'N/A'}")
        self.log(f"  Assigned IP: {s['assigned_ip']}")
        self.log(f"  Peers      : {s['peer_count']}")
        self.log(f"  Platform   : {s['platform']}")
        self.log(f"  WG Install : {'Yes' if s['wg_installed'] else 'No'}")
        if s.get("wg") and s["wg"].get("up"):
            out = s["wg"].get("output", "")
            for line in out.split("\n"):
                if "transfer:" in line.lower() or "handshake:" in line.lower():
                    self.log(f"  WG         : {line.strip()}")

    def display_peers(self):
        """Refresh and log the peer table; the host peer is marked with ★.

        Missing peer fields render as empty strings instead of raising.
        """
        self.fetch_peers()
        if not self.peers:
            self.log("  No peers found.")
            return
        lines = ["  Host    IP               Device                          Key"]
        lines.append("  " + "-" * 76)
        for p in self.peers:
            host_tag = "★" if p.get("host") else " "
            key_short = (p.get("public_key") or "")[:20] + "..."
            ip = p.get("assigned_ip") or ""
            name = p.get("device_name") or ""
            lines.append(f"  {host_tag:<6} {ip:<16} {name:<30} {key_short}")
        self.log("\n".join(lines))

    # ─── Connection Test ──────────────────────────────────────

    @staticmethod
    def _parse_ping_latency(output: str):
        """Extract the average round-trip time in ms from ping output, or None.

        Handles the summary lines of iputils/BSD/BusyBox ping
        ("... min/avg/max[/mdev] = a/b/c[/d] ms") and English Windows ping
        ("Average = Nms"). Per-reply fields such as ``icmp_seq=1`` or
        ``bytes=32`` are deliberately not matched.
        """
        m = re.search(r"min/avg/max\S*\s*=\s*[\d.]+/(\d+(?:\.\d+)?)/", output)
        if m:
            return float(m.group(1))
        m = re.search(r"Average\s*=\s*(\d+(?:\.\d+)?)\s*ms", output)
        if m:
            return float(m.group(1))
        return None

    def test_connection(self) -> dict:
        """Check the local interface and ping the VPN host (172.255.0.1).

        Returns a dict with ``reachable``, ``latency_ms`` (None when not
        parsed), ``download_speed_mbps`` (always None here) and
        human-readable ``details``.
        """
        results = {"reachable": False, "latency_ms": None, "download_speed_mbps": None, "details": ""}
        wgi = get_interface_status(self.interface_name)
        if not wgi.get("up"):
            err = wgi.get("error", "")
            results["details"] += f"WG interface {self.interface_name}: DOWN ({err})\n"
            return results
        results["details"] += f"WG interface {self.interface_name}: UP\n"
        for line in (wgi.get("output", "") or "").split("\n"):
            ls = line.strip()
            if ls:
                results["details"] += f"  {ls}\n"
        host_ip = "172.255.0.1"
        try:
            ping_cmd = ["ping", "-n" if SYSTEM == "Windows" else "-c", "3", host_ip]
            ping = subprocess.run(ping_cmd, capture_output=True, text=True, timeout=15)
            if ping.returncode == 0:
                results["reachable"] = True
                results["latency_ms"] = self._parse_ping_latency(ping.stdout)
                results["details"] += f"Ping {host_ip}: OK ({results['latency_ms']}ms)\n"
            else:
                results["details"] += f"Ping {host_ip}: FAILED (no route to host?)\n"
        except Exception as e:
            results["details"] += f"Ping error: {e}\n"
        return results

    # ─── Health ───────────────────────────────────────────────

    def health_check(self) -> dict:
        """Return the server's ``/api/health`` JSON, or an error dict."""
        resp = self._get("/api/health", auth=False)
        if resp is not None and resp.status_code == 200:
            return resp.json()
        return {"status": "error", "detail": "Server unreachable" if resp is None else "Unknown"}

    # ─── Reset ────────────────────────────────────────────────

    def reset(self):
        """Disconnect if connected, delete persisted state and bw-*.conf files, clear memory."""
        self.log("[*] Resetting all local configuration...")
        if self.connected:
            self.disconnect()
        for f in [KEY_FILE, TOKEN_FILE, SETTINGS_FILE]:
            try:
                if f.exists():
                    f.unlink()
                    self.log(f"  Removed {f.name}")
            except Exception as e:
                self.log(f"  [!] Could not remove {f.name}: {e}")
        for f in CONFIG_DIR.glob("bw-*.conf"):
            try:
                f.unlink()
                self.log(f"  Removed {f.name}")
            except Exception:
                pass
        self.token = None
        self.username = None
        self.keypair = None
        self.assigned_ip = None
        self.peers = []
        self.connected = False
        self.log("[+] Reset complete. Login again to continue.")


def print_banner():
    """Write the BirdWing ASCII-art banner, including the version tag, to stdout."""
    banner = r"""
   ____  _     _    ____ _    _ _____ _   _  ____
  | __ )(_)___| | _| __ (_)  | |_   _| \ | |/ ___|
  |  _ \| / __| |/ / |_) ||  | | | | |  \| | |  _
  | |_) | \__ \   <|  __/ |__| | | | | |\  | |_| |
  |____/|_|___/_|\_\_|   |____/| |_| |_| \_|\____|
                              Mesh VPN  v0.3
"""
    print(banner)


# Boolean / list-valued argparse destinations that each request an action.
_CLI_ACTION_FLAGS = (
    "login", "connect", "disconnect", "status", "peers", "config", "ping",
    "test", "health", "reset", "gui", "tui", "interactive", "users",
    "devices", "key",
)
# Destinations whose "not given" value is None; 0 and "" are valid values.
_CLI_OPTIONAL_VALUE_FLAGS = ("register", "disable_device", "enable_device")


def _cli_action_requested(args, ignore=()):
    """Return True if *args* requests any action, skipping destinations named in *ignore*.

    ``--host`` and ``--no-gui`` only modify behaviour and never count as actions.
    """
    if any(getattr(args, name) for name in _CLI_ACTION_FLAGS if name not in ignore):
        return True
    return any(
        getattr(args, name) is not None
        for name in _CLI_OPTIONAL_VALUE_FLAGS
        if name not in ignore
    )


def _cli_prompt_login(client, user_prompt="  Username: ", pass_prompt="  Password: "):
    """Prompt for credentials on the terminal and log *client* in.

    The password is read with getpass so it is not echoed to the screen.
    """
    import getpass
    username = input(user_prompt)
    password = getpass.getpass(pass_prompt)
    client.login(username, password)


def _cli_print_health(health):
    """Print the status/version/users/devices fields of a health_check() result."""
    print(f"  Status  : {health.get('status', '?')}")
    print(f"  Version : {health.get('version', '?')}")
    print(f"  Users   : {health.get('users', '?')}")
    print(f"  Devices : {health.get('devices', '?')}")


def _cli_print_test(results):
    """Print a test_connection() result: reachability, optional latency/speed, detail lines."""
    print(f"  {'Reachable':<20} {'Yes' if results['reachable'] else 'No'}")
    if results["latency_ms"]:
        print(f"  {'Latency':<20} {results['latency_ms']}ms")
    if results["download_speed_mbps"]:
        print(f"  {'Speed':<20} {results['download_speed_mbps']} Mbps")
    for line in results["details"].strip().split("\n"):
        if line:
            print(f"  {line}")


def _cli_print_users(client, fail_msg):
    """List users via the admin API, or print *fail_msg* if the request fails."""
    resp = client._get("/api/admin/users")
    if resp and resp.status_code == 200:
        users = resp.json()["users"]
        print(f"  Users ({len(users)}):")
        for u in users:
            print(f"    {u['id']:>3}  {u['username']:<20} {u['role']:<8} {u['created_at']}")
    else:
        print(fail_msg)


def _cli_print_devices(client, fail_msg):
    """List all devices via the admin API, or print *fail_msg* if the request fails."""
    resp = client._get("/api/admin/devices")
    if resp and resp.status_code == 200:
        devices = resp.json()["devices"]
        print(f"  Devices ({len(devices)}):")
        for d in devices:
            status = "ON" if d.get('enabled') else "OFF"
            print(f"    {d['id']:>3}  {d['device_name']:<25} {d['assigned_ip']:<15} {d['username']:<15} {status}")
    else:
        print(fail_msg)


def _cli_set_device_enabled(client, device_id, enabled, fail_msg):
    """Enable or disable a device via the admin toggle endpoint and report the outcome."""
    resp = client._api("PATCH", f"/api/admin/devices/{device_id}/toggle", json_data={"enabled": enabled})
    if resp and resp.status_code == 200:
        print(f"  [+] Device {device_id} {'enabled' if enabled else 'disabled'}")
    else:
        print(fail_msg)


def _cli_print_key(client, missing_msg):
    """Print the client's WireGuard keypair, or *missing_msg* when none is loaded."""
    if client.keypair:
        print(f"  Private Key: {client.keypair['private_key']}")
        print(f"  Public Key : {client.keypair['public_key']}")
    else:
        print(missing_msg)


def _cli_live_loop(client, stop):
    """Print client.display_status() every 5 seconds until the *stop* Event is set."""
    while not stop.is_set():
        client.display_status()
        stop.wait(5)


def _cli_interactive(client):
    """Run the ``bw>`` prompt until quit, Ctrl-C or EOF, then disconnect the client."""
    print_banner()
    print(f"  Host: {client.host_url}")
    print("  Type 'help' for commands\n")
    if not client.token:
        _cli_prompt_login(client)

    # A fresh Event per "live" session, so toggling off/on quickly can never
    # leave two refresher threads running.
    stop_live = None

    while True:
        try:
            parts = input("\n  bw> ").strip().lower().split()
            if not parts:
                continue
            cmd = parts[0]
            arg = parts[1] if len(parts) > 1 else None
            if cmd in ("q", "quit", "exit"):
                break
            elif cmd == "connect":
                client.connect()
            elif cmd == "disconnect":
                client.disconnect()
            elif cmd == "status":
                client.display_status()
            elif cmd == "live":
                if stop_live is not None:
                    stop_live.set()
                    stop_live = None
                    print("  Live off")
                else:
                    stop_live = threading.Event()
                    threading.Thread(target=_cli_live_loop, args=(client, stop_live), daemon=True).start()
                    print("  Live on (every 5s)")
            elif cmd == "peers":
                client.display_peers()
            elif cmd == "config":
                config = client.build_config()
                if config:
                    print(config)
            elif cmd == "register":
                client.register_device()
            elif cmd == "login":
                _cli_prompt_login(client, "    User: ", "    Pass: ")
            elif cmd == "test":
                _cli_print_test(client.test_connection())
            elif cmd == "health":
                _cli_print_health(client.health_check())
            elif cmd == "reset":
                client.reset()
            elif cmd == "users":
                _cli_print_users(client, "  [!] Failed (admin only)")
            elif cmd == "devices":
                _cli_print_devices(client, "  [!] Failed (admin only)")
            elif cmd in ("enable", "disable"):
                # The id goes into a URL path, so accept digits only.
                if arg is None or not arg.isdigit():
                    print(f"  Usage: {cmd} <id>")
                else:
                    _cli_set_device_enabled(client, arg, cmd == "enable", "  [!] Failed (admin only)")
            elif cmd == "key":
                _cli_print_key(client, "  [!] Login and register first.")
            elif cmd == "help":
                print("  Commands: login, register, connect, disconnect,")
                print("            status, peers, config, test, health,")
                print("            live, reset, users, devices,")
                print("            enable <id>, disable <id>, key,")
                print("            help, quit")
            else:
                print(f"  Unknown: {cmd}. Type 'help'")
        except KeyboardInterrupt:
            print()
            break
        except EOFError:
            break
    if stop_live is not None:
        stop_live.set()
    client.disconnect()


def main_cli():
    """Parse command-line arguments and run the GUI, TUI, one-shot commands or the REPL.

    When no action is requested (``--host`` / ``--no-gui`` alone count as
    none), the Qt GUI starts if PySide6 is available and ``--no-gui`` was not
    given; otherwise the interactive prompt runs. One-shot flags execute in a
    fixed order. ``--reset`` and ``--ping`` return immediately after running.
    """
    p = argparse.ArgumentParser(description="BirdWing VPN Client")
    p.add_argument("--host", default=None, help="BirdWing Host URL")
    p.add_argument("--login", nargs=2, metavar=("USER", "PASS"), help="Login")
    p.add_argument("--register", metavar="NAME", nargs="?", const="", help="Register device")
    p.add_argument("--connect", action="store_true", help="Connect VPN")
    p.add_argument("--disconnect", action="store_true", help="Disconnect VPN")
    p.add_argument("--status", action="store_true", help="Show status")
    p.add_argument("--peers", action="store_true", help="List peers")
    p.add_argument("--config", action="store_true", help="Print WG config")
    p.add_argument("--ping", action="store_true", help="Ping VPN server (172.255.0.1)")
    p.add_argument("--test", action="store_true", help="Full connection test")
    p.add_argument("--health", action="store_true", help="Server health")
    p.add_argument("--reset", action="store_true", help="Reset local state")
    p.add_argument("--gui", action="store_true", help="Qt Desktop GUI (default if PySide6 available)")
    p.add_argument("--no-gui", action="store_true", help="Force CLI mode")
    p.add_argument("--tui", action="store_true", help="Textual TUI (deprecated, use --gui)")
    p.add_argument("--interactive", action="store_true", help="Interactive CLI")
    p.add_argument("--users", action="store_true", help="List users (admin)")
    p.add_argument("--devices", action="store_true", help="List all devices (admin)")
    p.add_argument("--disable-device", metavar="ID", type=int, help="Disable device (admin)")
    p.add_argument("--enable-device", metavar="ID", type=int, help="Enable device (admin)")
    p.add_argument("--key", action="store_true", help="Show your device's key")

    args = p.parse_args()

    # Choose the default UI from what was requested, not from len(sys.argv).
    # The old check made "--no-gui" or "--host URL" on their own do nothing.
    if not _cli_action_requested(args):
        if _HAS_PYSIDE and not args.no_gui:
            args.gui = True
        else:
            args.interactive = True

    client = BirdWingClient(host_url=args.host)

    if args.gui:
        try:
            return run_qt(client)
        except (ImportError, NameError) as e:
            print(f"[!] Qt GUI unavailable: {e}")
            print("[*] Falling back to CLI")
            # If the GUI was the only request, use the prompt instead of exiting silently.
            if not _cli_action_requested(args, ignore=("gui",)):
                args.interactive = True

    if args.tui:
        print_banner()
        print(f"  Host: {client.host_url}")
        if not client.token:
            _cli_prompt_login(client)
        try:
            return run_tui(client)
        except (ImportError, NameError) as e:
            print(f"[!] TUI unavailable: {e}")
            print("[*] Falling back to interactive")

    if args.login:
        client.login(args.login[0], args.login[1])
    if args.register is not None:
        client.ensure_keypair()
        client.register_device(args.register or "")
    if args.health:
        print_banner()
        _cli_print_health(client.health_check())
    if args.connect:
        client.ensure_keypair()
        client.connect()
    if args.disconnect:
        client.disconnect()
    if args.status:
        print_banner()
        client.display_status()
    if args.peers:
        print_banner()
        client.display_peers()
    if args.config:
        client.ensure_keypair()
        config = client.build_config()
        if config:
            print(config)
    if args.reset:
        client.reset()
        return
    if args.ping:
        host = "172.255.0.1"
        cmd = ["ping", "-n" if SYSTEM == "Windows" else "-c", "4", host]
        print(f"  Pinging {host}...")
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=20)
            if r.returncode == 0:
                print("  ✅ Server reachable over VPN")
                for line in r.stdout.split("\n"):
                    low = line.lower()
                    if any(tok in low for tok in ("ttl=", "time=", "avg", "rtt")):
                        print(f"     {line.strip()}")
            else:
                print(f"  ❌ No response from {host}")
                print(f"     {r.stderr.strip() or 'Host unreachable'}")
        except Exception as e:
            print(f"  ❌ Ping error: {e}")
        return
    if args.test:
        print_banner()
        _cli_print_test(client.test_connection())

    if args.users:
        _cli_print_users(client, "  [!] Failed to list users (admin only)")

    if args.devices:
        _cli_print_devices(client, "  [!] Failed to list devices (admin only)")

    # "is not None": device id 0 is valid and must not be silently skipped.
    if args.disable_device is not None:
        _cli_set_device_enabled(client, args.disable_device, False, "  [!] Failed to disable device")

    if args.enable_device is not None:
        _cli_set_device_enabled(client, args.enable_device, True, "  [!] Failed to enable device")

    if args.key:
        _cli_print_key(client, "  [!] No keypair. Login and register first.")

    if args.tui or args.interactive:
        _cli_interactive(client)


# ===== Qt GUI =====
# NOTE(review): the PySide6 imports below are unconditional. The top of this
# same file imports PySide6 inside try/except and records the outcome in
# _HAS_PYSIDE, but a missing PySide6 raises ImportError here, at module load.
# The CLI/TUI fallbacks therefore can never run without PySide6 installed.
# Fixing that requires guarding this whole section, including the classes
# below that subclass QObject/QMainWindow, not just these import lines.
"""
BirdWing Qt Desktop GUI
"""

import os, sys, threading, time, re
import subprocess

from PySide6.QtCore import QCoreApplication, QObject, Qt, Signal, Slot, QTimer
from PySide6.QtGui import QFont, QIcon, QPixmap, QColor
from PySide6.QtWidgets import (
    QApplication, QComboBox, QDialog, QFormLayout, QHBoxLayout, QHeaderView,
    QLabel, QLineEdit, QMainWindow, QMessageBox, QPlainTextEdit, QPushButton,
    QSystemTrayIcon, QTableWidget, QTableWidgetItem, QVBoxLayout, QWidget,
    QMenu, QFrame, QSplitter,
)



class LogSignal(QObject):
    """Qt signal carrier for client log lines.

    BirdWingQtApp connects ``log_line`` to its ``_append_log`` slot and emits
    on it from the client's log callback (``_on_log``), so log text reaches
    the widget through Qt's signal/slot mechanism rather than by a direct call.
    """
    log_line = Signal(str)  # one log message per emission


class BirdWingQtApp(QMainWindow):
    def __init__(self, client):
        """Build the main window around *client* and start the 3-second status refresh.

        Routes the client's log output into the window's log panel, builds all
        widgets, performs one immediate refresh, then refreshes on a QTimer.
        """
        super().__init__()
        self.client = client
        # Set up the signal bridge first, then register the client callback
        # that emits on it.
        self._signal = LogSignal()
        self._signal.log_line.connect(self._append_log)
        self.client.log_callback = self._on_log
        self._tasks = {}
        self._initialized = False
        self._timer = QTimer(self)
        self._timer.timeout.connect(self._refresh)

        self._build_ui()
        self._refresh()
        self._timer.start(3000)

    def _build_ui(self):
        """Create every widget of the main window and wire up its signals.

        Layout, top to bottom:
        - a horizontal splitter holding a fixed-width side panel of action
          buttons and the peers table;
        - a status strip of QLabel fields;
        - a read-only log panel.

        The admin buttons and the sudo button start hidden. ``_refresh``
        shows the sudo button when needed; the admin buttons are presumably
        toggled by ``_refresh_admin_buttons``. Finally installs the tray icon.
        """
        self.setWindowTitle("BirdWing VPN")
        self.setMinimumSize(900, 600)
        self.resize(1080, 700)
        # Hover rule for red buttons fixed: it previously targeted btn_my_devices
        # (a neutral button) instead of btn_reset.
        self.setStyleSheet("""
            QMainWindow { background: #0a0e27; }
            QDialog { background: #111432; }
            QWidget { color: #e2e8f0; font-family: 'Segoe UI', system-ui, -apple-system, sans-serif; font-size: 13px; }
            QPushButton { background: #1e293b; color: #e2e8f0; border: 1px solid #334155; border-radius: 8px; padding: 10px 18px; font-weight: 500; min-height: 22px; }
            QPushButton:hover { background: #334155; border-color: #475569; }
            QPushButton:pressed { background: #0f172a; }
            QPushButton:disabled { color: #475569; border-color: #1e293b; }
            QPushButton#btn_connect { background: #059669; color: #fff; border-color: #059669; font-weight: 600; }
            QPushButton#btn_connect:hover { background: #10b981; }
            QPushButton#btn_reset, QPushButton#btn_admin_wg_stop { background: #dc2626; color: #fff; border-color: #dc2626; font-weight: 600; }
            QPushButton#btn_reset:hover, QPushButton#btn_admin_wg_stop:hover { background: #ef4444; }
            QPushButton#btn_admin_wg_start { background: #059669; color: #fff; border-color: #059669; font-weight: 600; }
            QPushButton#btn_admin_wg_start:hover { background: #10b981; }
            QPushButton#btn_sudo { background: #d97706; color: #fff; border-color: #d97706; font-weight: 600; }
            QPushButton#btn_sudo:hover { background: #f59e0b; }
            QTableWidget { background: #111432; color: #e2e8f0; border: 1px solid #1e293b; border-radius: 8px; gridline-color: #1e293b; }
            QTableWidget::item { padding: 8px 10px; border-bottom: 1px solid #1e293b; }
            QTableWidget::item:selected { background: #1e3a5f; }
            QHeaderView::section { background: #0a0e27; color: #64748b; border: none; padding: 10px 10px; font-weight: 700; font-size: 10px; letter-spacing: 1px; border-bottom: 2px solid #1e293b; }
            QFrame#statusBar { background: #111432; border: 1px solid #1e293b; border-radius: 10px; padding: 8px 14px; }
            QPlainTextEdit { background: #0a0e27; color: #94a3b8; border: 1px solid #1e293b; border-radius: 8px; padding: 10px; font-family: 'Cascadia Code', 'Fira Code', 'Courier New', monospace; font-size: 11px; }
            QLineEdit { background: #111432; color: #e2e8f0; border: 1px solid #334155; border-radius: 6px; padding: 9px 14px; font-size: 13px; }
            QLineEdit:focus { border-color: #38bdf8; }
            QComboBox { background: #111432; color: #e2e8f0; border: 1px solid #334155; border-radius: 6px; padding: 8px 14px; }
            QComboBox:hover { border-color: #475569; }
            QComboBox::drop-down { border: none; width: 24px; }
            QFrame[frameShape="4"] { border: none; border-top: 1px solid #1e293b; margin: 4px 0; }
            QFrame[frameShape="5"] { border: none; border-left: 1px solid #1e293b; margin: 0 4px; }
            QSplitter::handle { background: #1e293b; width: 1px; margin: 0 4px; }
            QScrollBar:vertical { background: transparent; width: 8px; }
            QScrollBar::handle:vertical { background: #334155; border-radius: 4px; min-height: 24px; }
            QScrollBar::handle:vertical:hover { background: #475569; }
            QScrollBar::add-line:vertical, QScrollBar::sub-line:vertical { height: 0; }
            QScrollBar:horizontal { background: transparent; height: 8px; }
            QScrollBar::handle:horizontal { background: #334155; border-radius: 4px; min-width: 24px; }
            QScrollBar::handle:horizontal:hover { background: #475569; }
            QScrollBar::add-line:horizontal, QScrollBar::sub-line:horizontal { width: 0; }
        """)

        central = QWidget()
        self.setCentralWidget(central)
        main_layout = QVBoxLayout(central)
        main_layout.setContentsMargins(24, 20, 24, 16)
        main_layout.setSpacing(16)

        # ── Splitter: side panel + peers table ──
        splitter = QSplitter(Qt.Horizontal)

        side_width = 220
        left = QWidget()
        left.setObjectName("sidePanel")
        left.setStyleSheet("QWidget#sidePanel { background: #111432; border: 1px solid #1e293b; border-radius: 12px; }")
        side = QVBoxLayout(left)
        side.setContentsMargins(12, 14, 12, 14)
        side.setSpacing(2)
        left.setFixedWidth(side_width)

        def _section_label(text):
            """Add a small cyan all-caps heading to the side panel."""
            lbl = QLabel(text)
            lbl.setStyleSheet("color: #38bdf8; font-weight: 700; font-size: 10px; letter-spacing: 1.5px; padding: 8px 4px 4px 4px;")
            side.addWidget(lbl)

        def _separator():
            """Add a horizontal rule to the side panel."""
            sep = QFrame()
            sep.setFrameShape(QFrame.HLine)
            side.addWidget(sep)

        def _button(text, object_name, slot):
            """Create a push button with *object_name* (used by the stylesheet), connect it, add it to the side panel."""
            btn = QPushButton(text)
            btn.setObjectName(object_name)
            btn.clicked.connect(slot)
            side.addWidget(btn)
            return btn

        def _confirm_reset():
            # Factory reset deletes the local keypair, token and settings, so
            # ask first instead of acting on a single (possibly accidental) click.
            answer = QMessageBox.question(
                self,
                "Factory Reset",
                "Delete the local keypair, login token and settings?\n"
                "You will need to sign in again afterwards.",
                QMessageBox.Yes | QMessageBox.No,
                QMessageBox.No,
            )
            if answer == QMessageBox.Yes:
                self._bg("reset", self.client.reset)

        # Connection group
        self.btn_login = _button("Sign In", "btn_login", self._do_login)
        self.btn_logout = _button("Sign Out", "btn_logout", self._do_logout)
        self.btn_connect = _button("Connect VPN", "btn_connect", self._do_connect)
        self.btn_disconnect = _button("Disconnect", "btn_disconnect", lambda: self._bg("disconnect", self.client.disconnect))

        _separator()

        # Tools
        _section_label("TOOLS")
        self.btn_my_devices = _button("My Devices", "btn_my_devices", self._do_my_devices)
        self.btn_test = _button("Connection Test", "btn_test", self._do_test)
        self.btn_peers = _button("Refresh Peers", "btn_peers", self._do_peers)
        self.btn_health = _button("Server Health", "btn_health", self._do_health)

        _separator()

        # Admin (hidden until the account is known to be an admin)
        _section_label("ADMIN")
        self.btn_change_pass = _button("Change Password", "btn_change_pass", self._do_change_password)
        self.btn_admin_users = _button("Manage Users", "btn_admin_users", self._do_admin_users)
        self.btn_admin_devices = _button("All Devices", "btn_admin_devices", self._do_admin_devices)
        self.btn_admin_logs = _button("View Logs", "btn_admin_logs", self._do_admin_logs)
        self.btn_admin_wg = _button("WG Diagnostics", "btn_admin_wg", self._do_admin_wg)
        self.btn_admin_wg_stop = _button("WireGuard Stop", "btn_admin_wg_stop", lambda: self._bg("wg_stop", lambda: self._admin_api("POST", "/api/admin/wg/stop")))
        self.btn_admin_wg_start = _button("WireGuard Start", "btn_admin_wg_start", lambda: self._bg("wg_start", lambda: self._admin_api("POST", "/api/admin/wg/start")))
        for admin_btn in (self.btn_change_pass, self.btn_admin_users, self.btn_admin_devices, self.btn_admin_logs, self.btn_admin_wg, self.btn_admin_wg_stop, self.btn_admin_wg_start):
            admin_btn.setVisible(False)

        _separator()

        self.btn_sudo = _button("Elevate Privileges", "btn_sudo", self._do_sudo)
        self.btn_sudo.setVisible(False)

        self.btn_reset = _button("Factory Reset", "btn_reset", _confirm_reset)

        side.addStretch()
        splitter.addWidget(left)

        # Peers table
        self.peers_table = QTableWidget(0, 4)
        self.peers_table.setHorizontalHeaderLabels(["Host", "IP", "Device", "Public Key"])
        self.peers_table.horizontalHeader().setStretchLastSection(True)
        self.peers_table.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents)
        self.peers_table.setSelectionBehavior(QTableWidget.SelectRows)
        self.peers_table.setEditTriggers(QTableWidget.NoEditTriggers)
        self.peers_table.verticalHeader().setVisible(False)
        self.peers_table.setAlternatingRowColors(True)
        splitter.addWidget(self.peers_table)
        # Initial sizes matched to the side panel's fixed width.
        splitter.setSizes([side_width, 1080 - side_width])
        main_layout.addWidget(splitter, 1)

        # ── Status bar ──
        self.sb = QFrame()
        self.sb.setObjectName("statusBar")
        self.sb.setFrameShape(QFrame.StyledPanel)
        sb_lay = QHBoxLayout(self.sb)
        sb_lay.setContentsMargins(12, 6, 12, 6)
        sb_lay.setSpacing(16)
        self.s_con = QLabel("DISCONNECTED")
        self.s_con.setStyleSheet("font-weight: bold; color: #e74c3c;")
        sb_lay.addWidget(self.s_con)
        vsep = QFrame()
        vsep.setFrameShape(QFrame.VLine)
        vsep.setStyleSheet("border: none; border-left: 1px solid #334155;")
        sb_lay.addWidget(vsep)
        self.s_self = QLabel("")
        self.s_self.setStyleSheet("color: #38bdf8; font-weight: 600;")
        sb_lay.addWidget(self.s_self)
        self.s_host = QLabel("")
        self.s_host.setStyleSheet("color: #94a3b8;")
        sb_lay.addWidget(self.s_host)
        self.s_priv = QLabel("")
        sb_lay.addWidget(self.s_priv)
        self.s_user = QLabel("")
        self.s_user.setStyleSheet("color: #94a3b8;")
        sb_lay.addWidget(self.s_user)
        self.s_ip = QLabel("")
        self.s_ip.setStyleSheet("color: #94a3b8;")
        sb_lay.addWidget(self.s_ip)
        self.s_peers = QLabel("")
        self.s_peers.setStyleSheet("color: #94a3b8;")
        sb_lay.addWidget(self.s_peers)
        self.s_wg = QLabel("")
        self.s_wg.setStyleSheet("color: #22c55e; font-size: 11px;")
        sb_lay.addWidget(self.s_wg)
        sb_lay.addStretch()
        main_layout.addWidget(self.sb)

        # ── Log panel (bounded to 500 lines) ──
        self.log_output = QPlainTextEdit()
        self.log_output.setReadOnly(True)
        self.log_output.setMaximumBlockCount(500)
        self.log_output.setFont(QFont("monospace", 9))
        self.log_output.setFixedHeight(100)
        main_layout.addWidget(self.log_output)

        self._setup_tray()

    def _setup_tray(self):
        """Install the system-tray icon and its Show/Hide, Connect, Disconnect, Quit menu.

        The icon is a plain 16x16 gray square. Double-clicking it toggles
        window visibility.
        """
        self.tray = QSystemTrayIcon(self)
        pm = QPixmap(16, 16)
        pm.fill(Qt.gray)
        self.tray.setIcon(QIcon(pm))
        self.tray.setToolTip("BirdWing VPN")
        # QSystemTrayIcon does not take ownership of its context menu. A
        # parentless local QMenu can be garbage-collected once this method
        # returns, leaving the tray icon without a menu. Parent it to the window
        # and keep a reference on self.
        menu = QMenu(self)
        self._tray_menu = menu
        menu.addAction("Show/Hide").triggered.connect(self._toggle_visible)
        menu.addSeparator()
        menu.addAction("Connect").triggered.connect(self._do_connect)
        menu.addAction("Disconnect").triggered.connect(lambda: self._bg("disconnect", self.client.disconnect))
        menu.addSeparator()
        menu.addAction("Quit").triggered.connect(self._quit)
        self.tray.setContextMenu(menu)
        self.tray.activated.connect(
            lambda reason: self._toggle_visible() if reason == QSystemTrayIcon.DoubleClick else None
        )
        self.tray.show()

    def _toggle_visible(self):
        """Hide the main window if it is shown; otherwise show it, raise it and focus it."""
        if self.isVisible():
            self.hide()
            return
        self.show()
        self.raise_()
        # raise_() only restacks the window; activateWindow() also gives it
        # keyboard focus when it is restored from the tray.
        self.activateWindow()

    def closeEvent(self, event):
        """Intercept window close and ask whether to minimize, cancel or quit.

        The close event is always ignored. Depending on the choice, the window
        is hidden to the tray (or minimized when no tray exists), left as is,
        or the application quits via ``_quit``.
        """
        event.ignore()
        tray_available = QSystemTrayIcon.isSystemTrayAvailable()
        msg = QMessageBox(self)
        msg.setWindowTitle("BirdWing VPN")
        msg.setText("Are you sure homeboy?")
        msg.setInformativeText("What do you want to do?")
        msg.setIcon(QMessageBox.Question)
        # Without a system tray, a hidden window could never be brought back,
        # so offer an ordinary minimize instead.
        min_btn = msg.addButton("Minimize to Tray" if tray_available else "Minimize", QMessageBox.AcceptRole)
        cancel_btn = msg.addButton("Cancel", QMessageBox.RejectRole)
        quit_btn = msg.addButton("Quit", QMessageBox.DestructiveRole)
        msg.setDefaultButton(cancel_btn)
        msg.exec()
        clicked = msg.clickedButton()
        if clicked == min_btn:
            if tray_available:
                self.hide()
                self.tray.showMessage("BirdWing", "Minimized to tray", QSystemTrayIcon.Information, 2000)
            else:
                self.showMinimized()
        elif clicked == quit_btn:
            self._quit()

    def _ensure_sudo_for_quit(self):
        """If WG is connected and no sudo password is in memory, prompt for one.

        Returns:
            True if quitting may proceed, False if the user chose "Leave Running".
            On a wrong or empty password a warning is shown and True is
            still returned: the app quits, and the VPN may stay up.
        """
        iface = self.client.interface_name
        if not iface:
            return True
        if SYSTEM != "Windows" and os.path.exists(f"/sys/class/net/{iface}"):
            wg_up = True
        else:
            wg_up = False
            try:
                wg_up = get_interface_status(iface).get("up", False)
            except Exception:
                pass  # best effort: an unreadable state is treated as down
        if not wg_up or is_elevated():
            return True
        if SYSTEM == "Windows":
            # There is no sudo on Windows. The prompt below could only end in a
            # misleading "Invalid password" warning before quitting anyway.
            return True

        dlg = QDialog(self)
        dlg.setWindowTitle("Privilege Required")
        dlg.setMinimumWidth(400)
        dlg.setStyleSheet(self.styleSheet())
        layout = QVBoxLayout(dlg)
        layout.setSpacing(12)
        layout.addWidget(QLabel("WireGuard is still connected.\nEnter your sudo password to disconnect cleanly."))
        pw_inp = QLineEdit()
        pw_inp.setEchoMode(QLineEdit.Password)
        pw_inp.setPlaceholderText("sudo password")
        pw_inp.returnPressed.connect(dlg.accept)  # Enter submits
        layout.addWidget(pw_inp)
        btns = QHBoxLayout()
        ok_btn = QPushButton("Disconnect & Quit")
        ok_btn.setObjectName("btn_connect")
        cancel_btn = QPushButton("Leave Running")
        btns.addWidget(ok_btn)
        btns.addWidget(cancel_btn)
        layout.addLayout(btns)
        ok_btn.clicked.connect(dlg.accept)
        cancel_btn.clicked.connect(dlg.reject)
        if dlg.exec() != QDialog.Accepted:
            return False
        pw = pw_inp.text()
        if pw:
            try:
                # -k: ignore cached sudo credentials for this check. Without it,
                # any password "succeeds" while sudo's timestamp is still valid,
                # and the wrong password is stored. -p "": no prompt text.
                r = subprocess.run(["sudo", "-k", "-S", "-p", "", "true"], input=pw + "\n",
                                   capture_output=True, text=True, timeout=10)
                if r.returncode == 0:
                    set_sudo_password(pw)
                    return True
            except (OSError, subprocess.SubprocessError):
                pass  # sudo missing or timed out: fall through to the warning
        QMessageBox.warning(self, "Error", "Invalid password — quitting anyway, VPN may stay active.")
        return True

    def _quit(self):
        """Disconnect the VPN (after an optional sudo prompt) and exit the Qt event loop.

        A failure during disconnect is reported on stderr but never prevents
        the application from quitting.
        """
        if not self._ensure_sudo_for_quit():
            return
        try:
            self.client.disconnect()
        except Exception as exc:
            # Best effort: a failed teardown must not trap the user in the app.
            print(f"[!] Disconnect on quit failed: {exc}", file=sys.stderr)
        finally:
            # Hide the tray icon explicitly so no stale icon lingers after exit.
            self.tray.hide()
            QCoreApplication.instance().quit()

    def _on_log(self, msg):
        """Client log callback: re-emit *msg* on the LogSignal bridge.

        ``__init__`` connects that signal to ``_append_log``, which updates the log widget.
        """
        emitter = self._signal.log_line
        emitter.emit(msg)

    @Slot(str)
    def _append_log(self, msg):
        """Append *msg* to the log panel with an HH:MM:SS prefix.

        Bracketed word tags such as ``[bold]`` / ``[/bold]`` are removed first.
        Markers made of non-word characters, like ``[+]`` or ``[!]``, do not
        match the pattern and are kept.
        """
        stamp = time.strftime('%H:%M:%S')
        text = re.sub(r'\[/?\w+\]', '', msg)
        self.log_output.appendPlainText("[" + stamp + "] " + text)

    def _refresh(self):
        try:
            s = self.client.get_status()
        except Exception:
            return

        authed = bool(self.client.token)
        self._refresh_admin_buttons()
        wg_up = False
        wg_info = None
        if is_wireguard_installed():
            wg_info = get_interface_status(self.client.interface_name)
            wg_up = wg_info.get("up", False)

        pw_set = is_elevated()
        need_elev = authed and needs_elevation() and is_wireguard_installed()
        self.btn_sudo.setVisible(need_elev)
        self.btn_sudo.setEnabled(not pw_set)
        self.btn_sudo.setText("Elevated ✓" if pw_set else "Elevate Privileges")

        self.btn_login.setVisible(not authed)
        self.btn_logout.setVisible(authed)
        for b in [self.btn_my_devices, self.btn_test, self.btn_peers, self.btn_health]:
            b.setVisible(authed)
        can_connect = not need_elev or pw_set
        self.btn_connect.setEnabled(can_connect)
        btn_text = "Connect" if not wg_up else "Reconnect"
        if not can_connect:
            btn_text = "Connect (elevate first)"
        self.btn_connect.setText(btn_text)
        self.btn_connect.setVisible(authed)
        self.btn_disconnect.setVisible(wg_up)

        if not wg_up and self.client.connected:
            self.client.connected = False

        if wg_up:
            label, color = "VPN CONNECTED", "#27ae60"
        elif self.client.connected and not wg_up:
            label, color = "CONFIGURED (no VPN)", "#e67e22"
        elif authed:
            label, color = "LOGGED IN", "#f39c12"
        else:
            label, color = "DISCONNECTED", "#e74c3c"

        my_ip = s.get('assigned_ip') or ''
        self.s_con.setText(label)
        self.s_con.setStyleSheet(f"font-weight: bold; color: {color};")
        self.s_self.setText(f"Me: {my_ip}" if my_ip else "")
        self.s_host.setText(f"GW: 172.255.0.1")
        self.s_priv.setText("🔓" if need_elev else "🔐" if authed else "")
        self.s_user.setText(f"User: {self.client.username or ''}")
        self.s_ip.setText(f"IP: {my_ip}")
        self.s_peers.setText(f"Peers: {s.get('peer_count') or 0}")

        wg_str = ""
        if wg_info and wg_info.get("up"):
            for l in (wg_info.get("output", "") or "").split("\n"):
                if "transfer:" in l.lower() or "handshake:" in l.lower():
                    wg_str += (" | " if wg_str else "") + l.strip()
        self.s_wg.setText(wg_str)

        # Build table: own device + host + peers
        own_public = self.client.keypair.get("public_key", "") if self.client.keypair else ""
        rows = []
        # Row 0: own device
        rows.append({"host": "me", "ip": my_ip, "name": f"{self.client.username or 'Me'} (you)", "key": own_public[:22] + "...", "highlight": True})
        # Row 1: host gateway
        rows.append({"host": "●", "ip": "172.255.0.1", "name": "BirdWing Host (gateway)", "key": (self.client.host_public_key or "")[:22] + "...", "active": True, "highlight": False})
        # Remaining: peers
        for p in self.client.peers:
            if p.get("host"):
                continue
            active = p.get("active", False)
            rows.append({"host": "●" if active else "○", "ip": p["assigned_ip"], "name": p["device_name"][:35], "key": p["public_key"][:22] + "...", "active": active, "highlight": False})

        self.peers_table.setRowCount(len(rows))
        for i, r in enumerate(rows):
            host_item = QTableWidgetItem(r["host"])
            ip_item = QTableWidgetItem(r["ip"])
            name_item = QTableWidgetItem(r["name"])
            key_item = QTableWidgetItem(r["key"])
            if r.get("active"):
                for item in (host_item, ip_item, name_item, key_item):
                    item.setForeground(QColor("#10b981"))
            elif not wg_up and not r["highlight"]:
                for item in (host_item, ip_item, name_item, key_item):
                    item.setForeground(QColor("#64748b"))
            if r["highlight"] and my_ip:
                for item in (host_item, ip_item, name_item, key_item):
                    item.setForeground(QColor("#38bdf8"))
                    font = item.font()
                    font.setBold(True)
                    item.setFont(font)
            self.peers_table.setItem(i, 0, host_item)
            self.peers_table.setItem(i, 1, ip_item)
            self.peers_table.setItem(i, 2, name_item)
            self.peers_table.setItem(i, 3, key_item)

        pm = QPixmap(16, 16)
        pm.fill(Qt.green if wg_up else Qt.gray)
        self.tray.setIcon(QIcon(pm))

        if authed and not self._initialized:
            self._initialized = True
            self._bg("init", lambda: (self.client.fetch_settings(), self.client.fetch_peers()))

    def _bg(self, name, fn, cb=None):
        if name in self._tasks and self._tasks[name]["thread"].is_alive():
            return

        def go():
            try:
                result = fn()
                if cb:
                    cb(result)
            except Exception as e:
                self._on_log(f"[!] {name}: {e}")
            finally:
                self._tasks.pop(name, None)

        t = threading.Thread(target=go, daemon=True)
        self._tasks[name] = {"thread": t, "cb": cb}
        t.start()

    def _do_connect(self):
        """Connect the VPN on a background task, reconnecting if already up.

        When WireGuard is installed and the client's interface reports it is
        up, the client is disconnected first so ``client.connect()`` starts
        fresh. A failed connect is logged according to
        ``client._last_error`` ("auth" also re-opens the login dialog,
        "no_wg" is silent).
        """
        def _connect_worker():
            self._on_log("[*] Connecting...")
            iface_up = False
            if is_wireguard_installed():
                status = get_interface_status(self.client.interface_name)
                iface_up = status.get("up", False)
            if iface_up:
                self.client.disconnect()

            outcome = self.client.connect()
            if outcome:
                return outcome

            reason = getattr(self.client, '_last_error', None)
            if reason == "auth":
                self._on_log("[!] Session expired — please login again")
                # NOTE(review): this runs on a plain threading.Thread (see _bg),
                # which has no Qt event loop; a timer created here may never
                # fire. Confirm the login dialog actually appears, or route
                # this through a signal to the GUI thread.
                QTimer.singleShot(0, self._do_login)
            elif reason == "elevation":
                self._on_log("[!] VPN connection failed: need admin privileges")
            elif reason != "no_wg":
                self._on_log("[!] VPN connection failed")
            return outcome

        self._bg("connect", _connect_worker)

    def _do_sudo(self):
        """Prompt for the sudo password, verify it, and keep it for this session.

        The password is checked by running ``sudo -k -S -p "" true`` with the
        password on stdin. ``-k`` makes sudo ignore any cached credential for
        this one invocation. Without it, ``sudo -S true`` succeeds straight
        from the timestamp cache after a recent sudo in a terminal, and
        whatever was typed would be stored as the password. ``-p ""``
        suppresses the prompt text on stderr.

        On success the password is stored with ``set_sudo_password`` and the
        dialog closes. On failure an inline error is shown and the form is
        re-armed for another attempt.

        NOTE: validation runs synchronously on the GUI thread (10 s timeout).
        """
        dlg = QDialog(self)
        dlg.setWindowTitle("Privilege Elevation")
        dlg.setMinimumWidth(420)
        layout = QFormLayout(dlg)
        label = QLabel("WireGuard commands require admin rights.\nEnter your sudo password to continue:")
        label.setWordWrap(True)
        layout.addRow(label)

        pw_inp = QLineEdit()
        pw_inp.setEchoMode(QLineEdit.Password)
        layout.addRow("Password:", pw_inp)

        err_label = QLabel("")
        err_label.setStyleSheet("color: #e74c3c; font-weight: bold;")
        err_label.setVisible(False)
        layout.addRow(err_label)

        btns = QHBoxLayout()
        ok = QPushButton("Elevate")
        cx = QPushButton("Cancel")
        btns.addWidget(ok)
        btns.addWidget(cx)
        layout.addRow(btns)
        pw_inp.setFocus()

        def show_error(message):
            """Display ``message`` and reset the form for another attempt."""
            err_label.setText(message)
            err_label.setVisible(True)
            pw_inp.clear()
            pw_inp.setFocus()
            ok.setEnabled(True)
            cx.setEnabled(True)

        def on_ok():
            pw = pw_inp.text()
            if not pw:
                return
            # Don't log anything derived from the password (e.g. its length).
            self._on_log("[*] Validating sudo password...")
            ok.setEnabled(False)
            cx.setEnabled(False)
            try:
                result = subprocess.run(
                    ["sudo", "-k", "-S", "-p", "", "true"],
                    input=pw + "\n", capture_output=True, text=True, timeout=10,
                )
            except subprocess.TimeoutExpired:
                show_error("Password validation timed out. Try again.")
                return
            except Exception as e:
                show_error(f"Validation error: {e}")
                return

            stderr = (result.stderr or "").strip()
            stderr_lower = stderr.lower()
            if result.returncode == 0:
                set_sudo_password(pw)
                self._on_log("[+] Privilege elevation active (session only)")
                dlg.accept()
            elif "incorrect password" in stderr_lower or "sorry" in stderr_lower:
                show_error("Incorrect password. Try again.")
            else:
                self._on_log(f"[!] sudo stderr: {stderr}")
                show_error(f"sudo error (exit {result.returncode})")

        ok.clicked.connect(on_ok)
        cx.clicked.connect(dlg.reject)
        dlg.exec()

    def _do_login(self):
        """Ask for credentials, then log in on a background task.

        The dialog will not close with an empty username or password. It
        shows an inline error instead, rather than closing and silently
        dropping the attempt. The entered values are read on the GUI thread
        before the worker starts, so the worker never touches Qt widgets.
        """
        dlg = QDialog(self)
        dlg.setWindowTitle("Login")
        layout = QFormLayout(dlg)
        u_inp = QLineEdit()
        p_inp = QLineEdit()
        p_inp.setEchoMode(QLineEdit.Password)
        layout.addRow("Username:", u_inp)
        layout.addRow("Password:", p_inp)
        err_label = QLabel("")
        err_label.setStyleSheet("color: #e74c3c; font-weight: bold;")
        err_label.setVisible(False)
        layout.addRow(err_label)
        btns = QHBoxLayout()
        ok = QPushButton("OK")
        cx = QPushButton("Cancel")

        def on_ok():
            # Validate in the dialog so the user gets feedback and can retry.
            if not u_inp.text() or not p_inp.text():
                err_label.setText("Enter both username and password")
                err_label.setVisible(True)
                return
            dlg.accept()

        ok.clicked.connect(on_ok)
        cx.clicked.connect(dlg.reject)
        btns.addWidget(ok)
        btns.addWidget(cx)
        layout.addRow(btns)
        u_inp.setFocus()
        if dlg.exec() != QDialog.Accepted:
            return
        # Capture values BEFORE starting thread (no Qt cross-thread access)
        username = u_inp.text()
        password = p_inp.text()
        if not username or not password:
            return

        def login_fn():
            if not self.client.login(username, password):
                self._on_log("[!] Login failed — check credentials")

        self._bg("login", login_fn)

    def _do_logout(self):
        self.client.token = None
        self.client.username = None
        self.client.peers = []
        self.client._save()
        self._on_log("[+] Logged out")

    def _do_test(self):
        def run():
            self._on_log("[*] Testing WireGuard connection...")
            r = self.client.test_connection()
            for line in r["details"].strip().split("\n"):
                if line.strip():
                    self._on_log(f"  {line.strip()}")
            if r["reachable"]:
                self._on_log(f"[+] Ping OK ({r['latency_ms']}ms) — VPN tunnel is working")
            else:
                self._on_log("[!] Ping FAILED — VPN tunnel may not be established")
        self._bg("test", run)

    def _do_peers(self):
        def run():
            self.client.fetch_peers()
            lines = []
            for p in self.client.peers:
                host = "★" if p.get("host") else " "
                lines.append(f"  {host} {p['assigned_ip']:<16} {p['device_name'][:30]} {p['public_key'][:20]}...")
            self._on_log(f"Peers ({len(self.client.peers)}):\n" + "\n".join(lines))
        self._bg("peers", run)

    def _do_health(self):
        def run():
            h = self.client.health_check()
            self._on_log(f"Server: {h.get('status')} v{h.get('version')} | "
                         f"{h.get('users', '?')} users, {h.get('devices', '?')} devices")
        self._bg("health", run)

    # ─── Admin Features ────────────────────────────────────────

    def _admin_api(self, method, path):
        """Call an admin API endpoint and log the result."""
        resp = self.client._api(method, path)
        if resp and resp.status_code == 200:
            self._on_log(f"[+] {path}: OK")
            try:
                data = resp.json()
                for k, v in data.items():
                    if k not in ("status", "methods", "results"):
                        continue
                    if isinstance(v, list):
                        for item in v:
                            self._on_log(f"  {item}")
                    elif isinstance(v, dict):
                        for sk, sv in v.items():
                            self._on_log(f"  {sk}: {sv}")
                    else:
                        self._on_log(f"  {k}: {v}")
            except Exception:
                self._on_log(f"  {resp.text[:300]}")
        else:
            detail = "unknown"
            if resp:
                try:
                    detail = resp.json().get("detail", str(resp.status_code))
                except Exception:
                    detail = resp.text[:100]
            self._on_log(f"[!] {path}: {detail}")
        return resp

    def _do_admin_logs(self):
        def run():
            resp = self._admin_api("GET", "/api/admin/logs")
            if resp and resp.status_code == 200:
                data = resp.json()
                files = data.get("files", [])
                if not files:
                    self._on_log("  No log files found")
                    return
                # Show first log file content
                first = files[0]["name"]
                self._on_log(f"  Latest log: {first} ({files[0]['size']} bytes)")
                resp2 = self.client._api("GET", f"/api/admin/logs/{first}")
                if resp2 and resp2.status_code == 200:
                    lines = resp2.text.splitlines()
                    for line in lines[-30:]:
                        self._on_log(f"  {line}")
        self._bg("admin_logs", run)

    def _do_admin_wg(self):
        def run():
            self._on_log("[*] Fetching WireGuard status...")
            resp = self.client._api("GET", "/api/admin/wg-diag")
            if resp and resp.status_code == 200:
                data = resp.json()
                wg_show = data.get("wg_show", "")
                if wg_show:
                    for line in wg_show.splitlines():
                        self._on_log(f"  {line}")
                else:
                    self._on_log("  (no wg show output)")
                if data.get("last_wg_error"):
                    self._on_log(f"  Last error: {data['last_wg_error'][:200]}")
                # Service status
                for k in data:
                    if k.startswith("svc_"):
                        svc = data[k] or ""
                        m = re.search(r"STATE\s*:\s*\d+\s+(\w+)", svc)
                        state = m.group(1) if m else "unknown"
                        self._on_log(f"  {k}: {state}")
            else:
                self._on_log("[!] Failed (admin only)")
        self._bg("admin_wg", run)

    def _refresh_admin_buttons(self):
        authed = bool(self.client.token)
        self.btn_my_devices.setVisible(authed)
        for b in [self.btn_change_pass, self.btn_admin_users, self.btn_admin_devices, self.btn_admin_logs, self.btn_admin_wg, self.btn_admin_wg_stop, self.btn_admin_wg_start]:
            b.setVisible(authed)

    def _do_my_devices(self):
        """Show a read-only dialog listing the current user's devices.

        Data comes from ``GET /api/devices/my-info`` (user, role and a
        ``devices`` list). The "+ Register New Device" button closes this
        dialog and opens the registration dialog. There is no per-device
        delete action here.

        The request is made before the dialog is built, so a failed fetch
        does not leave an unused hidden ``QDialog`` parented to the main
        window.
        """
        resp = self.client._api("GET", "/api/devices/my-info")
        data = None
        if resp is not None and resp.status_code == 200:
            try:
                data = resp.json()
            except ValueError:
                data = None
        if not isinstance(data, dict):
            QMessageBox.warning(self, "Error", "Failed to load devices")
            return
        devices = data.get("devices") or []

        dlg = QDialog(self)
        dlg.setWindowTitle("My Devices")
        dlg.setStyleSheet(self.styleSheet())
        dlg.setMinimumSize(650, 350)
        layout = QVBoxLayout(dlg)
        layout.setSpacing(12)

        info_lbl = QLabel(f"User: {data.get('user', '?')}  |  Role: {data.get('role', '?')}  |  Devices: {len(devices)}")
        info_lbl.setStyleSheet("color: #94a3b8; padding: 4px 0;")
        layout.addWidget(info_lbl)

        table = QTableWidget(len(devices), 4)
        table.setHorizontalHeaderLabels(["Name", "Public Key", "IP", "Status"])
        table.horizontalHeader().setStretchLastSection(True)
        table.setSelectionBehavior(QTableWidget.SelectRows)
        table.setEditTriggers(QTableWidget.NoEditTriggers)
        table.verticalHeader().setVisible(False)

        for i, d in enumerate(devices):
            # Fields may be absent or null in the JSON; QTableWidgetItem needs a str.
            table.setItem(i, 0, QTableWidgetItem(str(d.get("device_name") or "")))
            pk = str(d.get("public_key") or "")
            table.setItem(i, 1, QTableWidgetItem(pk[:24] + "..." if len(pk) > 24 else pk))
            table.setItem(i, 2, QTableWidgetItem(str(d.get("assigned_ip") or "")))
            enabled = d.get("enabled", 0)
            s_item = QTableWidgetItem("Active" if enabled else "Disabled")
            s_item.setForeground(QColor("#10b981" if enabled else "#94a3b8"))
            table.setItem(i, 3, s_item)

        layout.addWidget(table)

        btns = QHBoxLayout()
        register_btn = QPushButton("+ Register New Device")
        register_btn.clicked.connect(lambda: (dlg.close(), self._do_register_device()))
        btns.addWidget(register_btn)
        btns.addStretch()
        close_btn = QPushButton("Close")
        close_btn.clicked.connect(dlg.accept)
        btns.addWidget(close_btn)
        layout.addLayout(btns)
        dlg.exec()

    def _do_register_device(self):
        """Show a dialog that registers a WireGuard public key as a new device.

        By default the key field is filled from the client's own keypair
        (after ``client.ensure_keypair()``) and locked. Unchecking "Use
        deterministic key" allows entering a different key.

        The request is ``POST /api/devices/register``. On success the
        assigned IP is logged. On failure the server's ``detail`` message,
        when present, is shown inline.
        """
        dlg = QDialog(self)
        dlg.setWindowTitle("Register Device")
        dlg.setStyleSheet(self.styleSheet())
        dlg.setMinimumWidth(400)
        layout = QFormLayout(dlg)
        layout.setSpacing(10)

        name_inp = QLineEdit()
        name_inp.setPlaceholderText("e.g. My Laptop")
        layout.addRow("Device Name:", name_inp)

        self.client.ensure_keypair()
        pk = self.client.keypair["public_key"] if self.client.keypair else ""
        pk_inp = QLineEdit(pk)
        pk_inp.setPlaceholderText("WireGuard public key")
        layout.addRow("Public Key:", pk_inp)

        use_det = QPushButton("Use deterministic key")
        use_det.setCheckable(True)
        use_det.setChecked(True)
        use_det.setStyleSheet("QPushButton { background: #1e293b; color: #94a3b8; border: 1px solid #334155; border-radius: 4px; padding: 4px 10px; font-size: 11px; } QPushButton:checked { background: #059669; color: #fff; border-color: #059669; }")
        layout.addRow(use_det)

        err = QLabel("")
        err.setStyleSheet("color: #ef4444;")
        err.setVisible(False)
        layout.addRow(err)

        btns = QHBoxLayout()
        ok_btn = QPushButton("Register")
        cx_btn = QPushButton("Cancel")
        btns.addWidget(ok_btn)
        btns.addWidget(cx_btn)
        layout.addRow(btns)

        def show_error(message):
            # The server's detail may be a list/dict (validation errors); QLabel needs str.
            err.setText(str(message))
            err.setVisible(True)

        def on_use_det():
            if use_det.isChecked():
                self.client.ensure_keypair()
                pk_inp.setText(self.client.keypair["public_key"] if self.client.keypair else "")
                pk_inp.setEnabled(False)
            else:
                pk_inp.setEnabled(True)

        use_det.clicked.connect(on_use_det)
        on_use_det()

        def on_ok():
            # Pasted keys/names often carry stray whitespace or newlines.
            name = name_inp.text().strip()
            pubkey = pk_inp.text().strip()
            if not name or not pubkey:
                show_error("Fill all fields")
                return
            resp = self.client._api("POST", "/api/devices/register", {"device_name": name, "public_key": pubkey})
            if resp is not None and resp.status_code == 200:
                try:
                    assigned = resp.json().get("assigned_ip", "?")
                except (ValueError, AttributeError):
                    assigned = "?"
                self._on_log(f"[+] Device '{name}' registered ({assigned})")
                dlg.accept()
                return
            detail = "Registration failed"
            # Explicit None check: an error Response (4xx/5xx) is falsy, and
            # those are exactly the responses carrying a useful detail.
            if resp is not None:
                try:
                    detail = resp.json().get("detail", detail)
                except (ValueError, AttributeError):
                    pass
            show_error(detail)

        ok_btn.clicked.connect(on_ok)
        cx_btn.clicked.connect(dlg.reject)
        dlg.exec()

    def _show_admin_dialog(self, title, html_body_fn):
        """Show a dialog titled ``title`` around a table produced by ``html_body_fn``.

        ``html_body_fn`` is called with no arguments. A falsy return is
        treated as a load failure. Otherwise it should be a mapping whose
        ``"table"`` entry is the widget to show; despite its name, it does
        not deal in HTML.

        The data is fetched before the dialog is created, so a failure does
        not leave an orphan ``QDialog`` parented to the main window. The
        placeholder ``QTableWidget`` is created only when no table is
        supplied; it was previously built on every call as the default
        argument of ``dict.get``.
        """
        data = html_body_fn()
        if not data:
            QMessageBox.warning(self, "Error", "Failed to load data")
            return
        table = data.get("table")
        if table is None:
            table = QTableWidget()

        dlg = QDialog(self)
        dlg.setWindowTitle(title)
        dlg.setMinimumSize(600, 400)
        dlg.resize(700, 500)
        dlg.setStyleSheet(self.styleSheet())
        layout = QVBoxLayout(dlg)
        layout.addWidget(table)
        layout.addStretch()
        btns = QHBoxLayout()
        ok = QPushButton("Close")
        ok.clicked.connect(dlg.accept)
        btns.addStretch()
        btns.addWidget(ok)
        layout.addLayout(btns)
        dlg.exec()

    def _do_change_password(self):
        """Show a dialog that changes the logged-in user's own password.

        Client-side checks, in order:
        - current and new password are non-empty;
        - the confirmation matches;
        - the new password is at least 4 characters.

        The request is ``PATCH /api/auth/password``. A 403 reply is reported
        as a wrong current password.
        """
        dlg = QDialog(self)
        dlg.setWindowTitle("Change Password")
        dlg.setMinimumWidth(400)
        layout = QFormLayout(dlg)

        old_inp = QLineEdit()
        old_inp.setEchoMode(QLineEdit.Password)
        new_inp = QLineEdit()
        new_inp.setEchoMode(QLineEdit.Password)
        confirm_inp = QLineEdit()
        confirm_inp.setEchoMode(QLineEdit.Password)

        layout.addRow("Current Password:", old_inp)
        layout.addRow("New Password:", new_inp)
        layout.addRow("Confirm:", confirm_inp)

        err = QLabel("")
        err.setStyleSheet("color: #ef4444; font-weight: bold;")
        err.setVisible(False)
        layout.addRow(err)

        btns = QHBoxLayout()
        ok = QPushButton("Change")
        cx = QPushButton("Cancel")
        btns.addWidget(ok)
        btns.addWidget(cx)
        layout.addRow(btns)

        def show_error(message):
            err.setText(message)
            err.setVisible(True)

        def on_ok():
            old = old_inp.text()
            new = new_inp.text()
            confirm = confirm_inp.text()
            if not old or not new:
                show_error("Fill all fields")
                return
            if new != confirm:
                show_error("Passwords do not match")
                return
            if len(new) < 4:
                show_error("Password too short")
                return
            resp = self.client._api("PATCH", "/api/auth/password", {"old_password": old, "new_password": new})
            if resp is not None and resp.status_code == 200:
                self._on_log("[+] Password changed")
                dlg.accept()
            # Explicit None check: assuming _api returns a requests.Response,
            # a 403 response is falsy, so ``resp and ...`` could never match here.
            elif resp is not None and resp.status_code == 403:
                show_error("Current password is incorrect")
            else:
                show_error("Failed to change password")

        ok.clicked.connect(on_ok)
        cx.clicked.connect(dlg.reject)
        dlg.exec()

    def _do_admin_users(self):
        """Show the admin "Manage Users" dialog.

        Users come from ``GET /api/admin/users``. Each row has "Set Pass" and
        "Delete" actions; "Delete" is omitted for the sole remaining admin.
        "+ Add User" closes this dialog and opens the add-user dialog. The
        table is not refreshed after an action; reopen the dialog to see
        changes.

        The list is fetched before the dialog is built, so a failed fetch
        does not leave an unused hidden ``QDialog`` parented to the main
        window.
        """
        resp = self.client._api("GET", "/api/admin/users")
        users = None
        if resp is not None and resp.status_code == 200:
            try:
                users = resp.json().get("users") or []
            except (ValueError, AttributeError):
                users = None
        if users is None:
            QMessageBox.warning(self, "Error", "Failed to load users (admin only)")
            return

        # Count admins once rather than rescanning the list for every row.
        admin_count = sum(1 for x in users if x["role"] == "admin")

        dlg = QDialog(self)
        dlg.setWindowTitle("Manage Users")
        dlg.setMinimumSize(650, 450)
        dlg.setStyleSheet(self.styleSheet())
        layout = QVBoxLayout(dlg)

        table = QTableWidget(len(users), 5)
        table.setHorizontalHeaderLabels(["ID", "Username", "Role", "Created", "Actions"])
        table.horizontalHeader().setStretchLastSection(True)
        table.setSelectionBehavior(QTableWidget.SelectRows)
        table.setEditTriggers(QTableWidget.NoEditTriggers)
        table.verticalHeader().setVisible(False)

        for i, u in enumerate(users):
            uid = u["id"]
            uname = u["username"]
            table.setItem(i, 0, QTableWidgetItem(str(uid)))
            table.setItem(i, 1, QTableWidgetItem(uname))
            table.setItem(i, 2, QTableWidgetItem(u["role"]))
            # created_at may be missing or null; QTableWidgetItem needs a str.
            table.setItem(i, 3, QTableWidgetItem(str(u.get("created_at") or "")))

            # Action buttons
            act_widget = QWidget()
            act_lay = QHBoxLayout(act_widget)
            act_lay.setContentsMargins(2, 2, 2, 2)
            act_lay.setSpacing(4)

            pw_btn = QPushButton("Set Pass")
            pw_btn.setStyleSheet("background: #334155; color: #e2e8f0; border: 1px solid #475569; border-radius: 4px; padding: 3px 8px; font-size: 11px;")
            # Bind per-row values as defaults (avoids late-binding closures).
            pw_btn.clicked.connect(lambda checked, x=uid, n=uname: self._admin_set_password(x, n))
            act_lay.addWidget(pw_btn)

            # Never offer to delete the last remaining admin.
            if u["role"] != "admin" or admin_count > 1:
                del_btn = QPushButton("Delete")
                del_btn.setStyleSheet("background: #450a0a; color: #fca5a5; border: 1px solid #7f1d1d; border-radius: 4px; padding: 3px 8px; font-size: 11px;")
                del_btn.clicked.connect(lambda checked, x=uid, n=uname: self._admin_delete_user(x, n))
                act_lay.addWidget(del_btn)

            act_lay.addStretch()
            table.setCellWidget(i, 4, act_widget)

        layout.addWidget(table)
        btns = QHBoxLayout()
        add_btn = QPushButton("+ Add User")
        add_btn.clicked.connect(lambda: (dlg.close(), self._do_admin_add_user()))
        btns.addWidget(add_btn)
        btns.addStretch()
        close_btn = QPushButton("Close")
        close_btn.clicked.connect(dlg.accept)
        btns.addWidget(close_btn)
        layout.addLayout(btns)
        dlg.exec()

    def _admin_set_password(self, user_id, username):
        """Let an admin set a new password for ``username`` (id ``user_id``).

        The new password must be non-empty, match its confirmation and be at
        least 4 characters. It is sent with
        ``PATCH /api/admin/users/{user_id}/password``.
        """
        dlg = QDialog(self)
        dlg.setWindowTitle(f"Set Password for {username}")
        dlg.setMinimumWidth(350)
        form = QFormLayout(dlg)

        new_inp = QLineEdit()
        confirm_inp = QLineEdit()
        for field in (new_inp, confirm_inp):
            field.setEchoMode(QLineEdit.Password)
        form.addRow("New Password:", new_inp)
        form.addRow("Confirm:", confirm_inp)

        err = QLabel("")
        err.setStyleSheet("color: #ef4444;")
        err.setVisible(False)
        form.addRow(err)

        row = QHBoxLayout()
        ok = QPushButton("Set")
        cx = QPushButton("Cancel")
        for button in (ok, cx):
            row.addWidget(button)
        form.addRow(row)

        def complain(message):
            err.setText(message)
            err.setVisible(True)

        def on_ok():
            new = new_inp.text()
            confirm = confirm_inp.text()
            if not new:
                problem = "Enter a password"
            elif new != confirm:
                problem = "Passwords do not match"
            elif len(new) < 4:
                problem = "Password too short"
            else:
                problem = None
            if problem:
                complain(problem)
                return
            resp = self.client._api("PATCH", f"/api/admin/users/{user_id}/password", {"new_password": new})
            if not (resp and resp.status_code == 200):
                complain("Failed")
                return
            self._on_log(f"[+] Password set for {username}")
            dlg.accept()

        ok.clicked.connect(on_ok)
        cx.clicked.connect(dlg.reject)
        dlg.exec()

    def _admin_delete_user(self, user_id, username):
        """Ask for confirmation, then delete a user and all of their devices.

        On failure the server's ``detail`` message is shown when available;
        otherwise a generic message is shown.
        """
        reply = QMessageBox.question(self, "Delete User",
            f'Delete user "{username}" and all their devices?\nThis cannot be undone.',
            QMessageBox.Yes | QMessageBox.No)
        if reply != QMessageBox.Yes:
            return
        resp = self.client._api("DELETE", f"/api/admin/users/{user_id}")
        if resp is not None and resp.status_code == 200:
            self._on_log(f"[+] User {username} deleted")
            return
        detail = "Failed to delete user"
        # Explicit None check: an error Response (4xx/5xx) is falsy, so
        # ``if resp:`` would never read the server's detail.
        if resp is not None:
            try:
                detail = resp.json().get("detail", detail)
            except (ValueError, AttributeError):
                pass
        QMessageBox.warning(self, "Error", str(detail))

    def _do_admin_add_user(self):
        """Show a dialog for an admin to create a user with a chosen role.

        The request is ``POST /api/admin/users``. On failure the server's
        ``detail`` message, when present, is shown inline.
        """
        # QComboBox is not in the file's top-level PySide6 import list; without
        # this local import, opening the dialog raised NameError.
        from PySide6.QtWidgets import QComboBox

        dlg = QDialog(self)
        dlg.setWindowTitle("Add User")
        dlg.setMinimumWidth(350)
        layout = QFormLayout(dlg)
        name_inp = QLineEdit()
        pass_inp = QLineEdit()
        pass_inp.setEchoMode(QLineEdit.Password)
        role_sel = QComboBox()
        role_sel.addItems(["user", "admin"])
        layout.addRow("Username:", name_inp)
        layout.addRow("Password:", pass_inp)
        layout.addRow("Role:", role_sel)
        err = QLabel("")
        err.setStyleSheet("color: #ef4444;")
        err.setVisible(False)
        layout.addRow(err)
        btns = QHBoxLayout()
        ok = QPushButton("Add")
        cx = QPushButton("Cancel")
        btns.addWidget(ok)
        btns.addWidget(cx)
        layout.addRow(btns)

        def show_error(message):
            # The server's detail may be a list/dict; QLabel.setText needs str.
            err.setText(str(message))
            err.setVisible(True)

        def on_ok():
            username = name_inp.text()
            password = pass_inp.text()
            if not username or not password:
                show_error("Fill all fields")
                return
            resp = self.client._api("POST", "/api/admin/users", {
                "username": username, "password": password, "role": role_sel.currentText()
            })
            if resp is not None and resp.status_code == 200:
                self._on_log(f"[+] User {username} created")
                dlg.accept()
                return
            detail = "Failed"
            # Explicit None check: an error Response (4xx/5xx) is falsy.
            if resp is not None:
                try:
                    detail = resp.json().get("detail", detail)
                except (ValueError, AttributeError):
                    pass
            show_error(detail)

        ok.clicked.connect(on_ok)
        cx.clicked.connect(dlg.reject)
        dlg.exec()

    def _do_admin_devices(self):
        """Show the admin "Manage All Devices" dialog.

        Devices come from ``GET /api/admin/devices``. Non-host rows get
        Enable/Disable, Rename and Delete actions. The host row (``host``
        flag, shown with ★) gets none. Toggling sends the PATCH request on a
        background task and hides that row's toggle button; the table is not
        otherwise refreshed.

        The list is fetched before the dialog is built, so a failed fetch
        does not leave an unused hidden ``QDialog`` parented to the main
        window.
        """
        resp = self.client._api("GET", "/api/admin/devices")
        devices = None
        if resp is not None and resp.status_code == 200:
            try:
                devices = resp.json().get("devices") or []
            except (ValueError, AttributeError):
                devices = None
        if devices is None:
            QMessageBox.warning(self, "Error", "Failed to load devices (admin only)")
            return

        dlg = QDialog(self)
        dlg.setWindowTitle("Manage All Devices")
        dlg.setMinimumSize(750, 450)
        dlg.setStyleSheet(self.styleSheet())
        layout = QVBoxLayout(dlg)

        table = QTableWidget(len(devices), 6)
        table.setHorizontalHeaderLabels(["Name", "IP", "User", "Status", "Key", "Actions"])
        table.horizontalHeader().setStretchLastSection(True)
        table.setSelectionBehavior(QTableWidget.SelectRows)
        table.setEditTriggers(QTableWidget.NoEditTriggers)
        table.verticalHeader().setVisible(False)

        def toggle(device_id, was_enabled, button):
            """Flip ``device_id``'s enabled flag in the background and hide ``button``."""
            self._bg(
                f"toggle_{device_id}",
                lambda: self.client._api(
                    "PATCH", f"/api/admin/devices/{device_id}/toggle", {"enabled": not was_enabled}
                ),
            )
            button.setVisible(False)

        for i, d in enumerate(devices):
            is_host = d.get("host", False)
            # Fields may be absent or null in the JSON; QTableWidgetItem needs a str.
            name = str(d.get("device_name") or "")
            table.setItem(i, 0, QTableWidgetItem("★ " + name if is_host else name))
            table.setItem(i, 1, QTableWidgetItem(str(d.get("assigned_ip") or "")))
            table.setItem(i, 2, QTableWidgetItem(str(d.get("username") or "")))
            enabled = d.get("enabled", 0)
            status_item = QTableWidgetItem("Active" if enabled else "Disabled")
            status_item.setForeground(QColor("#22c55e" if enabled else "#94a3b8"))
            table.setItem(i, 3, status_item)
            pk = str(d.get("public_key") or "")
            table.setItem(i, 4, QTableWidgetItem(pk[:20] + "..." if pk else ""))

            if is_host:
                table.setCellWidget(i, 5, QLabel("—"))
                continue

            act_w = QWidget()
            act_l = QHBoxLayout(act_w)
            act_l.setContentsMargins(2, 2, 2, 2)
            act_l.setSpacing(4)
            did = d["id"]
            toggle_btn = QPushButton("Disable" if enabled else "Enable")
            toggle_btn.setStyleSheet(f"background: {'#450a0a' if enabled else '#052e16'}; color: {'#fca5a5' if enabled else '#86efac'}; border: 1px solid {'#7f1d1d' if enabled else '#166534'}; border-radius: 4px; padding: 3px 8px; font-size: 11px;")
            # Bind this row's button as a default argument. A bare reference to
            # ``toggle_btn`` is looked up at click time (after the loop), so
            # every row's click would hide the LAST row's button.
            toggle_btn.clicked.connect(
                lambda checked, x=did, en=int(enabled), b=toggle_btn: toggle(x, en, b)
            )
            act_l.addWidget(toggle_btn)

            rename_btn = QPushButton("Rename")
            rename_btn.setStyleSheet("background: #334155; color: #e2e8f0; border: 1px solid #475569; border-radius: 4px; padding: 3px 8px; font-size: 11px;")
            rename_btn.clicked.connect(lambda checked, x=did, n=name: self._admin_rename_device(x, n))
            act_l.addWidget(rename_btn)

            del_btn = QPushButton("Delete")
            del_btn.setStyleSheet("background: #450a0a; color: #fca5a5; border: 1px solid #7f1d1d; border-radius: 4px; padding: 3px 8px; font-size: 11px;")
            del_btn.clicked.connect(lambda checked, x=did: self._admin_delete_device(x))
            act_l.addWidget(del_btn)

            act_l.addStretch()
            table.setCellWidget(i, 5, act_w)

        layout.addWidget(table)
        btns = QHBoxLayout()
        close_btn = QPushButton("Close")
        close_btn.clicked.connect(dlg.accept)
        btns.addStretch()
        btns.addWidget(close_btn)
        layout.addLayout(btns)
        dlg.exec()

    def _admin_rename_device(self, device_id, current_name):
        """Prompt for a new name for ``device_id`` and send it to the server.

        An empty name is ignored. On failure a warning box is shown and the
        dialog stays open.
        """
        dlg = QDialog(self)
        dlg.setWindowTitle("Rename Device")
        dlg.setMinimumWidth(350)
        form = QFormLayout(dlg)
        name_field = QLineEdit(current_name)
        form.addRow("Device Name:", name_field)

        buttons = QHBoxLayout()
        rename_btn = QPushButton("Rename")
        cancel_btn = QPushButton("Cancel")
        for button in (rename_btn, cancel_btn):
            buttons.addWidget(button)
        form.addRow(buttons)

        def submit():
            new_name = name_field.text()
            if not new_name:
                return
            resp = self.client._api("PATCH", f"/api/admin/devices/{device_id}/name", {"device_name": new_name})
            if not (resp and resp.status_code == 200):
                QMessageBox.warning(self, "Error", "Failed to rename device")
                return
            self._on_log(f"[+] Device {device_id} renamed to '{new_name}'")
            dlg.accept()

        rename_btn.clicked.connect(submit)
        cancel_btn.clicked.connect(dlg.reject)
        dlg.exec()

    def _admin_delete_device(self, device_id):
        """Confirm with the user, then delete ``device_id`` on the server."""
        answer = QMessageBox.question(self, "Delete Device",
            "Delete this device?",
            QMessageBox.Yes | QMessageBox.No)
        if answer == QMessageBox.Yes:
            resp = self.client._api("DELETE", f"/api/admin/devices/{device_id}")
            succeeded = bool(resp) and resp.status_code == 200
            if succeeded:
                self._on_log(f"[+] Device {device_id} deleted")
            else:
                QMessageBox.warning(self, "Error", "Failed to delete device")


def run_qt(client):
    """Launch the PySide6 GUI for ``client`` and block until the Qt event loop exits.

    Reuses an already-created ``QApplication`` if there is one. The application
    is configured with ``setQuitOnLastWindowClosed(False)``, so closing the
    last window does not by itself end the event loop.

    The previous implementation repeated its whole body. After the first event
    loop ended, it built a second main window and ran the loop again.

    Returns:
        The exit code returned by ``QApplication.exec()``.
    """
    app = QApplication.instance() or QApplication(sys.argv)
    app.setApplicationName("BirdWing")
    app.setQuitOnLastWindowClosed(False)
    # Hold a local reference so the window is not garbage-collected while the
    # event loop is running.
    window = BirdWingQtApp(client)
    window.show()
    return app.exec()


# ===== Textual TUI =====
"""
BirdWing Textual TUI
"""

import threading
import time
import subprocess

from textual.app import App, ComposeResult
from textual.containers import Container, Horizontal, Vertical, VerticalScroll
from textual.widgets import Button, DataTable, Footer, Input, Label, RichLog, Static



class BirdWingTUI(App):
    """Textual terminal UI for the BirdWing client.

    Buttons drive the ``client`` object: login/logout, privilege elevation,
    connect/disconnect, ping test, peer listing, health check and reset.
    Blocking client calls run on daemon threads via ``_bg``. Every UI mutation
    coming from those threads is marshalled back to the Textual thread through
    ``_call_ui``. The status bar, button visibility and peer table are
    refreshed every 3 seconds and after each background task finishes.
    """

    CSS = """
    Screen {
        background: $surface;
    }
    #title-bar {
        height: 3;
        background: $primary;
        color: $text;
        content-align: center middle;
        text-style: bold;
    }
    #main-container {
        layout: horizontal;
        height: 1fr;
        min-height: 10;
    }
    #actions {
        width: 26;
        padding: 1 1 0 1;
        border: solid $primary;
        margin: 1;
        height: 100%;
    }
    #actions > Label {
        margin: 0 0 1 0;
        text-style: bold;
    }
    #actions > Button {
        margin: 0 0 1 0;
        width: 100%;
    }
    #right-panel {
        width: 1fr;
        padding: 1;
        margin: 1 1 1 0;
        border: solid $primary;
        height: 100%;
    }
    #peers-table {
        height: 100%;
    }
    #status-bar {
        layout: horizontal;
        height: 3;
        background: $panel;
        padding: 0 1;
    }
    #status-bar > Label {
        padding: 0 2;
    }
    #log-panel {
        height: 8;
        border: solid $primary;
        margin: 0 1 1 1;
    }
    #login-form {
        display: none;
        height: auto;
        margin: 0 1 1 1;
        border: solid $accent;
        padding: 1;
    }
    #login-form.visible {
        display: block;
    }
    #sudo-form {
        display: none;
        height: auto;
        margin: 0 1 1 1;
        border: solid $warning;
        padding: 1;
    }
    #sudo-form.visible {
        display: block;
    }
    Button.hidden {
        display: none;
    }
    """

    def __init__(self, client):
        super().__init__()
        self.client = client
        # Route the client's own log output into the TUI log panel.
        self.client.log_callback = self._on_log
        self._timer = None
        self._tasks = {}
        # Ident of the thread running the Textual event loop (set in on_mount).
        # call_from_thread() must not be used from that thread.
        self._ui_thread = None
        # Lines logged before the log panel exists; flushed in on_mount.
        self._pending_log = []

    # ----- logging / thread marshalling -----

    @staticmethod
    def _escape(text):
        """Escape Rich markup in ``text`` so it is displayed literally."""
        return str(text).replace("[", "\\[")

    def _call_ui(self, fn, *args):
        """Run ``fn(*args)`` on the Textual thread, whichever thread calls this."""
        if self._ui_thread is None or threading.get_ident() == self._ui_thread:
            fn(*args)
            return
        try:
            self.call_from_thread(fn, *args)
        except Exception:
            # Best effort: the app may already be shutting down, in which case
            # there is no UI left to update.
            pass

    def _write_log(self, message):
        """Append ``message`` (Rich markup) to the log panel. UI thread only."""
        try:
            panel = self.query_one("#log-panel", RichLog)
        except Exception:
            # The panel is not composed yet; keep the line for on_mount.
            self._pending_log.append(message)
            return
        panel.write(message)

    def _log_line(self, message):
        """Log a Rich-markup line to the log panel; safe to call from any thread."""
        self._call_ui(self._write_log, message)

    def _on_log(self, message):
        """``client.log_callback`` hook: show client text with markup escaped."""
        self._log_line(self._escape(message))

    # ----- layout -----

    def compose(self) -> ComposeResult:
        yield Static("BirdWing Mesh VPN", id="title-bar")
        with Container(id="main-container"):
            with Vertical(id="actions"):
                yield Label("Actions")
                yield Button("Login", id="btn-login", variant="primary")
                yield Button("Logout", id="btn-logout", variant="default")
                yield Button("Elevate", id="btn-sudo", variant="warning")
                yield Button("Connect", id="btn-connect", variant="success")
                yield Button("Disconnect", id="btn-disconnect", variant="warning")
                yield Button("Test", id="btn-test")
                yield Button("Peers", id="btn-peers")
                yield Button("Health", id="btn-health")
                yield Button("Reset", id="btn-reset", variant="error")
                yield Button("Quit", id="btn-quit")
            with Vertical(id="right-panel"):
                yield DataTable(id="peers-table")
        with Container(id="status-bar"):
            yield Label("Disconnected", id="s-con")
            yield Label("", id="s-user")
            yield Label("", id="s-ip")
            yield Label("", id="s-peers")
            yield Label("", id="s-wg")
        with VerticalScroll(id="login-form"):
            yield Label("Login")
            yield Input(placeholder="Username", id="login-user")
            yield Input(placeholder="Password", password=True, id="login-pass")
            with Horizontal():
                yield Button("Submit", id="login-submit", variant="primary")
                yield Button("Cancel", id="login-cancel")
        with VerticalScroll(id="sudo-form"):
            yield Label("Privilege Elevation")
            yield Label("WireGuard commands require admin rights.", id="sudo-label")
            yield Input(placeholder="Sudo password", password=True, id="sudo-pass")
            with Horizontal():
                yield Button("Elevate", id="sudo-submit", variant="warning")
                yield Button("Cancel", id="sudo-cancel")
        yield RichLog(id="log-panel", markup=True, max_lines=200)
        yield Footer()

    def on_mount(self):
        self._ui_thread = threading.get_ident()
        self.title = "BirdWing VPN"
        self.query_one("#peers-table", DataTable).add_columns("Host", "IP", "Device", "Key")
        pending, self._pending_log = self._pending_log, []
        for message in pending:
            self._write_log(message)
        self._log_line("BirdWing TUI ready")
        self._refresh_ui()
        self._timer = self.set_interval(3, self._refresh_ui)

    # ----- periodic state sync -----

    def _refresh_ui(self):
        """Sync the status bar, button visibility, peer table and forms with the client.

        Runs on the UI thread. Tolerates a ``None`` status and missing keys,
        because an exception in a Textual timer callback would end the app.
        """
        try:
            s = self.client.get_status() or {}
        except Exception:
            return

        authed = bool(self.client.token)
        connected = s.get("connection") == "connected"

        # Status bar
        con_lbl = self.query_one("#s-con", Label)
        con_lbl.update("CONNECTED" if connected else "DISCONNECTED" if authed else "NOT LOGGED IN")
        con_lbl.styles.color = "green" if connected else "yellow" if authed else "red"

        self.query_one("#s-user", Label).update(f"User: {self.client.username or ''}")
        self.query_one("#s-ip", Label).update(f"IP: {s.get('assigned_ip') or ''}")
        self.query_one("#s-peers", Label).update(f"Peers: {s.get('peer_count') or 0}")

        wg_info = s.get("wg") or {}
        wg_parts = []
        if wg_info.get("up"):
            for line in (wg_info.get("output") or "").split("\n"):
                lowered = line.lower()
                if "transfer:" in lowered or "handshake:" in lowered:
                    wg_parts.append(line.strip())
        self.query_one("#s-wg", Label).update(" | ".join(wg_parts))

        # Button visibility based on state
        elevated = is_elevated()
        need_elev = authed and needs_elevation()
        self.query_one("#btn-login", Button).display = not authed
        self.query_one("#btn-logout", Button).display = authed
        self.query_one("#btn-sudo", Button).display = need_elev
        self.query_one("#btn-sudo", Button).label = "Elevated ✓" if elevated else "Elevate"
        self.query_one("#btn-connect", Button).display = authed and not connected
        self.query_one("#btn-disconnect", Button).display = authed and connected
        self.query_one("#btn-test", Button).display = authed
        self.query_one("#btn-peers", Button).display = authed
        self.query_one("#btn-health", Button).display = authed

        # Peers table
        table = self.query_one("#peers-table", DataTable)
        table.clear()
        for p in self.client.peers or []:
            host_mark = "★" if p.get("host") else " "
            key = p.get("public_key") or ""
            table.add_row(
                host_mark,
                p.get("assigned_ip") or "",
                (p.get("device_name") or "")[:30],
                key[:20] + "...",
            )

        # Login form: hide if already authed
        if authed:
            self.query_one("#login-form", VerticalScroll).remove_class("visible")
        # Sudo form: hide if already elevated
        if elevated:
            self.query_one("#sudo-form", VerticalScroll).remove_class("visible")

    # ----- background tasks -----

    def _bg(self, name, fn):
        """Run ``fn`` on a daemon thread, allowing one running task per ``name``.

        A return value of exactly ``False`` is logged as a failure, and
        exceptions are logged. When the task ends, the UI is refreshed
        immediately instead of waiting for the next timer tick.
        """
        running = self._tasks.get(name)
        if running is not None and running.is_alive():
            self._log_line(f"[yellow]{name} already running[/]")
            return

        def go():
            try:
                if fn() is False:
                    self._log_line(f"[red]{name} failed[/]")
            except Exception as e:
                self._log_line(f"[red]{name} error: {self._escape(e)}[/]")
            finally:
                self._tasks.pop(name, None)
                self._call_ui(self._refresh_ui)

        t = threading.Thread(target=go, daemon=True)
        self._tasks[name] = t
        t.start()

    # ----- event handlers -----

    def on_button_pressed(self, e):
        bid = e.button.id

        if bid == "btn-login":
            self.query_one("#login-form", VerticalScroll).add_class("visible")
            self.query_one("#login-user", Input).focus()

        elif bid == "btn-logout":
            self.client.token = None
            self.client.username = None
            self.client.peers = []
            self.client._save()
            self._log_line("[+] Logged out")
            self._refresh_ui()

        elif bid == "btn-sudo":
            self.query_one("#sudo-form", VerticalScroll).add_class("visible")
            self.query_one("#sudo-pass", Input).focus()

        elif bid == "btn-connect":
            def do_connect():
                r = self.client.connect()
                if not r:
                    err = getattr(self.client, '_last_error', None)
                    if err == "auth":
                        self._log_line("[red]Session expired — please login again[/]")
                    elif err == "elevation":
                        self._log_line("[red]VPN connection failed: need admin privileges[/]")
                        self._call_ui(self._show_sudo_form)
                    else:
                        self._log_line("[red]VPN connection failed[/]")
            self._bg("connect", do_connect)

        elif bid == "btn-disconnect":
            self._bg("disconnect", self.client.disconnect)

        elif bid == "btn-test":
            def do_test():
                r = self.client.test_connection()
                for line in r["details"].strip().split("\n"):
                    if line.strip():
                        self._log_line(f"  {self._escape(line)}")
                if r["reachable"]:
                    self._log_line(f"[green]Ping OK ({r['latency_ms']}ms) — VPN tunnel working[/]")
                else:
                    self._log_line("[red]Ping FAILED — VPN tunnel not established[/]")
            self._bg("test", do_test)

        elif bid == "btn-peers":
            self._bg("peers", self.client.display_peers)

        elif bid == "btn-health":
            def show_health():
                h = self.client.health_check() or {}
                self._log_line(f"Server: {h.get('status')} | v{h.get('version')} | {h.get('users', '?')} users, {h.get('devices', '?')} devices")
            self._bg("health", show_health)

        elif bid == "btn-reset":
            self._bg("reset", self.client.reset)

        elif bid == "btn-quit":
            self.exit()

        elif bid == "login-submit":
            self._do_login()
        elif bid == "login-cancel":
            self.query_one("#login-form", VerticalScroll).remove_class("visible")
        elif bid == "sudo-submit":
            self._do_sudo()
        elif bid == "sudo-cancel":
            self.query_one("#sudo-form", VerticalScroll).remove_class("visible")

    def on_input_submitted(self, e):
        if e.input.id == "login-user":
            self.query_one("#login-pass", Input).focus()
        elif e.input.id == "login-pass":
            self._do_login()
        elif e.input.id == "sudo-pass":
            self._do_sudo()

    # ----- privilege elevation -----

    def _show_sudo_form(self):
        """Reveal the sudo password form (UI thread only)."""
        self.query_one("#sudo-form", VerticalScroll).add_class("visible")

    def _clear_sudo_input(self):
        """Blank the sudo password field (UI thread only)."""
        self.query_one("#sudo-pass", Input).value = ""

    def _do_sudo(self):
        """Validate the entered sudo password in the background and keep it on success."""
        pw = self.query_one("#sudo-pass", Input).value
        if not pw:
            self._log_line("[red]Password required[/]")
            return

        def elevate():
            try:
                # -k makes sudo ignore cached credentials, so the password is
                # actually checked. Without it, any input would "succeed" while
                # the sudo timestamp is still valid, and a wrong password would
                # be stored.
                result = subprocess.run(
                    ["sudo", "-k", "-S", "true"],
                    input=pw + "\n", capture_output=True, text=True, timeout=10,
                )
                stderr = result.stderr or ""
                if result.returncode == 0:
                    set_sudo_password(pw)
                    self._log_line("[green]Privilege elevation active (session only)[/]")
                    self._call_ui(self._hide_sudo)
                elif "incorrect password" in stderr.lower() or "sorry" in stderr.lower():
                    self._log_line("[red]Incorrect password[/]")
                else:
                    self._log_line(f"[red]sudo error (exit {result.returncode}): {self._escape(stderr.strip())}[/]")
            except Exception as e:
                self._log_line(f"[red]Elevation error: {self._escape(e)}[/]")
            self._call_ui(self._clear_sudo_input)

        self._bg("sudo", elevate)

    def _hide_sudo(self):
        sf = self.query_one("#sudo-form", VerticalScroll)
        sf.remove_class("visible")
        self.query_one("#sudo-pass", Input).value = ""

    # ----- login -----

    def _do_login(self):
        """Submit the login form on a background thread; hide the form on success."""
        u = self.query_one("#login-user", Input).value
        p = self.query_one("#login-pass", Input).value
        if not u or not p:
            self._log_line("[red]Username and password required[/]")
            return

        def login():
            if self.client.login(u, p):
                self._call_ui(self._hide_login)
        self._bg("login", login)

    def _hide_login(self):
        lf = self.query_one("#login-form", VerticalScroll)
        lf.remove_class("visible")
        self.query_one("#login-user", Input).value = ""
        self.query_one("#login-pass", Input).value = ""


def run_tui(client):
    BirdWingTUI(client).run()




if __name__ == "__main__":
    main_cli()
