"""
订阅用户服务：管理订阅用户的Xray配置同步
"""
import json
import logging
import shlex
from typing import Any, Dict, List, Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from ..core.security import decrypt_string
from ..models import Node, Server, Socks5Proxy, Subscriber
from ..models.subscriber import SubscriberStatus
from .ssh_service import SSHService


class SubscriberService:
    """Keep each node's Xray configuration in sync with its subscribers.

    Builds the VLESS client list and the full Xray config for a node from the
    database, deploys that config to the node's server over SSH, queries
    per-user traffic counters on the server, and flags subscribers that have
    expired or exceeded their traffic quota.
    """

    # Local port of the Xray API (dokodemo-door) inbound. It is bound in
    # generate_xray_config and queried in get_subscriber_traffic; keeping it in
    # one place stops the two from drifting apart.
    # NOTE(review): every generated node config binds this same port. Configs
    # are per node (config_{node_id}.json / xray-node-{node_id}), so if one
    # server runs several nodes only one xray process can bind it. Confirm
    # one-node-per-server or make the port per-node.
    XRAY_API_PORT = 10085

    # Flow value used for every VLESS client entry.
    _CLIENT_FLOW = "xtls-rprx-vision"

    _logger = logging.getLogger(__name__)

    def __init__(self, db: AsyncSession):
        """Bind the service to an async DB session and create an SSH helper."""
        self.db = db
        self.ssh_service = SSHService()

    async def get_node_clients(self, node_id: int) -> List[Dict[str, Any]]:
        """Build the Xray VLESS ``clients`` list for a node.

        Every subscriber of the node that is enabled and has status ACTIVE
        becomes one client entry (its UUID, its Xray email, the vision flow).
        If the node has no such subscribers, the node's own UUID is used as a
        single default client instead.

        Returns:
            The list of client dicts, or an empty list if the node does not exist.
        """
        # Resolve the node first: when it is missing there is nothing to build,
        # so the subscriber query can be skipped.
        node = await self.db.get(Node, node_id)
        if not node:
            return []

        query = select(Subscriber).where(
            Subscriber.node_id == node_id,
            Subscriber.is_enabled == True,  # noqa: E712 - SQL expression, not a Python comparison
            Subscriber.status == SubscriberStatus.ACTIVE,
        )
        result = await self.db.execute(query)
        subscribers = result.scalars().all()

        if not subscribers:
            # No subscribers: fall back to the node's own UUID as the default user.
            return [{
                "id": node.uuid,
                "email": f"node_{node_id}@default",
                "flow": self._CLIENT_FLOW,
            }]

        return [
            {"id": sub.uuid, "email": sub.xray_email, "flow": self._CLIENT_FLOW}
            for sub in subscribers
        ]

    @staticmethod
    def _parse_server_names(raw: Optional[str], reality_dest: str) -> List[str]:
        """Turn the comma-separated REALITY serverNames setting into a clean list.

        Whitespace around each name is stripped and empty entries are dropped,
        so "a.com, b.com," yields ["a.com", "b.com"]. If nothing usable
        remains, fall back to the host part of ``reality_dest``.
        """
        names = [name.strip() for name in (raw or "").split(",")]
        names = [name for name in names if name]
        return names or [reality_dest.split(":")[0]]

    async def generate_xray_config(self, node_id: int) -> Dict[str, Any]:
        """Generate the complete Xray config for a node, including all subscribers.

        The config contains a local API inbound (StatsService), a VLESS+REALITY
        inbound carrying the clients from get_node_clients, and a SOCKS
        outbound to the node's SK5 proxy. All VLESS traffic is routed to that
        proxy.

        Raises:
            ValueError: if the node or its SK5 proxy does not exist.
        """
        node = await self.db.get(Node, node_id)
        if not node:
            raise ValueError("节点不存在")

        socks5 = await self.db.get(Socks5Proxy, node.socks5_id)
        if not socks5:
            raise ValueError("SK5代理不存在")

        # All enabled, active subscribers (or the node's default client).
        clients = await self.get_node_clients(node_id)

        # SOCKS5 upstream; credentials are stored encrypted.
        socks5_password = decrypt_string(socks5.password) if socks5.password else ""

        outbound = {
            "protocol": "socks",
            "settings": {
                "servers": [
                    {
                        "address": socks5.ip,
                        "port": socks5.port
                    }
                ]
            },
            "tag": "proxy"
        }

        # Only send credentials when both parts are present.
        if socks5.username and socks5_password:
            outbound["settings"]["servers"][0]["users"] = [
                {
                    "user": socks5.username,
                    "pass": socks5_password
                }
            ]

        # VLESS + REALITY inbound carrying every client.
        reality_dest = node.reality_dest or "www.apple.com:443"
        server_names = self._parse_server_names(node.reality_server_names, reality_dest)

        inbound = {
            "listen": "0.0.0.0",
            "port": node.listen_port,
            "protocol": "vless",
            "settings": {
                "clients": clients,  # multiple users
                "decryption": "none"
            },
            "streamSettings": {
                "network": "tcp",
                "security": "reality",
                "realitySettings": {
                    "show": False,
                    "dest": reality_dest,
                    "xver": 0,
                    "serverNames": server_names,
                    "privateKey": node.reality_private_key,
                    "shortIds": [node.reality_short_id]
                }
            },
            "tag": "vless-in"
        }

        config = {
            "log": {
                "loglevel": "warning"
            },
            "stats": {},
            "api": {
                "tag": "api",
                "services": ["StatsService"]
            },
            "policy": {
                # Level 0 (the default level of the clients above) collects
                # per-user uplink/downlink counters.
                "levels": {
                    "0": {
                        "statsUserUplink": True,
                        "statsUserDownlink": True
                    }
                },
                "system": {
                    "statsInboundUplink": True,
                    "statsInboundDownlink": True,
                    "statsOutboundUplink": True,
                    "statsOutboundDownlink": True
                }
            },
            "inbounds": [
                {
                    "listen": "127.0.0.1",
                    "port": self.XRAY_API_PORT,
                    "protocol": "dokodemo-door",
                    "settings": {
                        "address": "127.0.0.1"
                    },
                    "tag": "api"
                },
                inbound
            ],
            "outbounds": [
                outbound,
                {"protocol": "freedom", "tag": "direct"},
                {"protocol": "blackhole", "tag": "block"}
            ],
            "routing": {
                "rules": [
                    {"type": "field", "inboundTag": ["api"], "outboundTag": "api"},
                    {"type": "field", "inboundTag": ["vless-in"], "outboundTag": "proxy"}
                ]
            }
        }

        return config

    async def sync_node_config(self, node_id: int) -> tuple[bool, str]:
        """Regenerate a node's Xray config and deploy it to its server.

        Writes the config to /etc/xray/config_{node_id}.json over SSH and
        restarts the xray-node-{node_id} systemd service.

        Returns:
            (success, message). Errors are reported in the message, not raised.
        """
        node = await self.db.get(Node, node_id)
        if not node:
            return False, "节点不存在"

        server = await self.db.get(Server, node.server_id)
        if not server:
            return False, "服务器不存在"

        try:
            config = await self.generate_xray_config(node_id)
            config_json = json.dumps(config, indent=2)

            server_password = decrypt_string(server.ssh_password)

            config_path = f"/etc/xray/config_{node_id}.json"
            service_name = f"xray-node-{node_id}"

            # `set -e` makes the script exit non-zero if writing the file or the
            # restart fails; without it the trailing `echo` always returned 0.
            # (Assumes execute_command reports failure on a non-zero exit status.)
            # The heredoc delimiter is quoted so the JSON is written verbatim;
            # indented JSON can never contain a line that is exactly XRAY_CONFIG.
            script = f'''#!/bin/bash
set -e
cat > {config_path} << 'XRAY_CONFIG'
{config_json}
XRAY_CONFIG

systemctl restart {service_name}
echo "配置已更新并重启服务"
'''

            success, output = await self.ssh_service.execute_command(
                host=server.ip,
                port=server.ssh_port,
                username=server.ssh_user,
                password=server_password,
                command=script
            )

            if success:
                return True, "配置已同步"
            return False, f"同步失败: {output}"

        except Exception as e:
            return False, str(e)

    async def get_subscriber_traffic(self, subscriber: Subscriber) -> tuple[int, int]:
        """Query a subscriber's traffic counters on its node's server.

        Best-effort: returns ``(0, 0)`` when the node or server is missing,
        the SSH command fails, or its output cannot be parsed. Failures of
        the SSH call are logged, not raised.

        Returns:
            (upload_bytes, download_bytes)
        """
        node = await self.db.get(Node, subscriber.node_id)
        if not node:
            return 0, 0

        server = await self.db.get(Server, node.server_id)
        if not server:
            return 0, 0

        server_password = decrypt_string(server.ssh_password)
        email = subscriber.xray_email

        # Build each request body with json.dumps and shell-quote it, so an
        # email containing quotes cannot break the JSON or inject shell syntax.
        uplink_body = shlex.quote(json.dumps(
            {"name": f"user>>>{email}>>>traffic>>>uplink", "reset": False}
        ))
        downlink_body = shlex.quote(json.dumps(
            {"name": f"user>>>{email}>>>traffic>>>downlink", "reset": False}
        ))
        api_url = f"http://127.0.0.1:{self.XRAY_API_PORT}/stats/query"

        # NOTE(review): the Xray API inbound serves gRPC (StatsService), not
        # HTTP/JSON. Verify this curl query really returns data; otherwise query
        # with `xray api stats --server=127.0.0.1:<port> -name ...` instead.
        # The regex backslashes are doubled so Python sees no invalid escapes.
        command = f'''#!/bin/bash
# 获取用户流量统计
UPLINK=$(curl -s "{api_url}" -d {uplink_body} 2>/dev/null | grep -oP '"value":\\s*\\K\\d+' || echo "0")
DOWNLINK=$(curl -s "{api_url}" -d {downlink_body} 2>/dev/null | grep -oP '"value":\\s*\\K\\d+' || echo "0")

echo "$UPLINK $DOWNLINK"
'''

        try:
            success, output = await self.ssh_service.execute_command(
                host=server.ip,
                port=server.ssh_port,
                username=server.ssh_user,
                password=server_password,
                command=command
            )
        except Exception:
            # `Exception` rather than a bare except so asyncio.CancelledError
            # (a BaseException) still propagates out of this coroutine.
            self._logger.warning("Traffic query failed for %s", email, exc_info=True)
            return 0, 0

        if not success or not output or not output.strip():
            return 0, 0

        parts = output.strip().split()
        if len(parts) < 2:
            return 0, 0
        try:
            return int(parts[0]), int(parts[1])
        except ValueError:
            self._logger.warning("Unparseable traffic output for %s: %r", email, output)
            return 0, 0

    async def check_and_disable_exceeded(self, subscriber: Subscriber) -> bool:
        """Disable a subscriber that has expired or exceeded its traffic quota.

        Expiry is checked first, so an expired subscriber gets status EXPIRED
        even if it is also over quota. The change is made on the ORM object
        only; nothing is committed here.

        Returns:
            True if the subscriber was disabled, False otherwise.
        """
        if subscriber.is_expired:
            subscriber.status = SubscriberStatus.EXPIRED
            subscriber.is_enabled = False
            return True

        if subscriber.is_traffic_exceeded:
            subscriber.status = SubscriberStatus.TRAFFIC_EXCEEDED
            subscriber.is_enabled = False
            return True

        return False


