"""
订阅用户管理API
"""
import uuid
import secrets
import base64
import io
from datetime import datetime
from typing import Optional
from urllib.parse import quote

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

try:
    import qrcode
except ImportError:
    qrcode = None

from ..core.database import get_db
from ..models import Node, Server, Subscriber
from ..models.subscriber import SubscriberStatus
from ..schemas.subscriber import (
    SubscriberCreate,
    SubscriberUpdate,
    SubscriberResponse,
    SubscriberListResponse,
    SubscriberSettingsRequest,
    SubscriberBatchCreate,
    SubscriberShareInfo,
    SubscriberTrafficInfo
)
from ..services.subscriber_service import SubscriberService
from .auth import get_current_user

# FastAPI router that holds every subscriber-management endpoint defined below.
# Routes here use relative paths ("", "/{subscriber_id}", ...); the URL prefix is
# applied wherever this router is included (not visible in this file).
router = APIRouter()


def generate_uuid() -> str:
    """Return a new random (version 4) UUID in canonical hyphenated text form."""
    fresh = uuid.uuid4()
    return f"{fresh}"


def generate_token() -> str:
    """Return a fresh URL-safe random token (32 random bytes) for subscription URLs."""
    token = secrets.token_urlsafe(nbytes=32)
    return token


async def build_subscriber_response(
    subscriber: Subscriber,
    node: Node,
    server: Server,
    base_url: str = ""
) -> SubscriberResponse:
    """Assemble the API response model for a single subscriber.

    Args:
        subscriber: the subscriber row to serialize.
        node: the subscriber's node, or None; node-derived fields are then None.
        server: the node's server, or None; ``server_ip`` is then None.
        base_url: when non-empty, prefixed to build an absolute ``subscribe_url``;
            when empty, ``subscribe_url`` is "".

    Returns:
        A ``SubscriberResponse``. Missing traffic counters are reported as 0.
    """
    link = generate_share_link(subscriber, node, server)

    subscribe_url = ""
    if base_url:
        subscribe_url = f"{base_url}/api/subscribe/user/{subscriber.subscribe_token}"

    node_name = None
    listen_port = None
    if node:
        node_name = node.name
        listen_port = node.listen_port
    server_ip = server.ip if server else None

    return SubscriberResponse(
        id=subscriber.id,
        node_id=subscriber.node_id,
        name=subscriber.name,
        email=subscriber.email,
        remark=subscriber.remark,
        uuid=subscriber.uuid,
        subscribe_token=subscriber.subscribe_token,
        # Counters may be NULL for never-synced subscribers; expose them as 0
        traffic_up=subscriber.traffic_up or 0,
        traffic_down=subscriber.traffic_down or 0,
        traffic_total=subscriber.traffic_total or 0,
        traffic_limit=subscriber.traffic_limit,
        last_traffic_sync=subscriber.last_traffic_sync,
        expire_at=subscriber.expire_at,
        status=subscriber.status,
        is_enabled=subscriber.is_enabled,
        is_expired=subscriber.is_expired,
        is_traffic_exceeded=subscriber.is_traffic_exceeded,
        node_name=node_name,
        server_ip=server_ip,
        listen_port=listen_port,
        share_link=link,
        subscribe_url=subscribe_url,
        created_at=subscriber.created_at,
        updated_at=subscriber.updated_at,
    )


def generate_share_link(subscriber: Subscriber, node: Node, server: Server) -> str:
    """Build the VLESS + Reality share link (``vless://...``) for one subscriber.

    Args:
        subscriber: supplies the UUID (URI userinfo) and the name (URI fragment).
        node: supplies the listen port and the Reality parameters.
        server: supplies the host address.

    Returns:
        The share link, or "" when ``node`` or ``server`` is missing.

    The fragment (display name) is percent-encoded so that names containing
    spaces, ``#``, ``&``, ``?`` or non-ASCII characters still yield a well-formed
    URI. Plain ASCII names such as ``user1`` are unchanged. IPv6 server
    addresses are wrapped in brackets as URI syntax requires.
    """
    if not node or not server:
        return ""

    host = server.ip
    # An IPv6 literal inside a URI authority must be enclosed in brackets (RFC 3986).
    if host and ":" in host and not host.startswith("["):
        host = f"[{host}]"

    # SNI is the host part of the Reality destination ("host:port"), else the default.
    sni = node.reality_dest.split(':')[0] if node.reality_dest else 'www.apple.com'
    # safe="" so that '/', '#', '?' etc. in the name are escaped too.
    remark = quote(subscriber.name or "", safe="")

    return (
        f"vless://{subscriber.uuid}@{host}:{node.listen_port}"
        "?encryption=none&flow=xtls-rprx-vision&security=reality"
        f"&sni={sni}"
        f"&fp=chrome&pbk={node.reality_public_key}"
        f"&sid={node.reality_short_id}&type=tcp"
        f"#{remark}"
    )


@router.get("", response_model=SubscriberListResponse)
async def list_subscribers(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    node_id: Optional[int] = None,
    keyword: Optional[str] = None,
    status: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Return one page of subscribers, newest first, with optional filters.

    Args:
        page / page_size: 1-based pagination.
        node_id: only subscribers of this node.
        keyword: substring match on the subscriber name. It is matched literally,
            so ``%`` and ``_`` are not treated as LIKE wildcards.
        status: exact match on the status column.

    Returns:
        ``SubscriberListResponse`` with the page items and the total number of
        rows that match the filters.
    """
    # Build the filters once so the page query and the count query cannot drift apart.
    conditions = []
    if node_id is not None:
        conditions.append(Subscriber.node_id == node_id)
    if keyword:
        conditions.append(Subscriber.name.contains(keyword, autoescape=True))
    if status:
        conditions.append(Subscriber.status == status)

    total = await db.scalar(select(func.count(Subscriber.id)).where(*conditions))

    query = (
        select(Subscriber)
        .where(*conditions)
        .order_by(Subscriber.created_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    result = await db.execute(query)
    subscribers = result.scalars().all()

    # Many rows usually share a node, so each (node, server) pair is resolved only once.
    node_cache = {}
    items = []
    for sub in subscribers:
        if sub.node_id not in node_cache:
            node = await db.get(Node, sub.node_id)
            server = await db.get(Server, node.server_id) if node else None
            node_cache[sub.node_id] = (node, server)
        node, server = node_cache[sub.node_id]
        items.append(await build_subscriber_response(sub, node, server))

    return SubscriberListResponse(items=items, total=total or 0)


@router.post("", response_model=SubscriberResponse)
async def create_subscriber(
    data: SubscriberCreate,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Create an enabled, ACTIVE subscriber on an existing node.

    The subscriber gets a fresh UUID and subscription token. After the commit,
    the node's config is re-synced via ``SubscriberService.sync_node_config``;
    the result of that sync is not checked here.

    Raises:
        HTTPException(404): ``data.node_id`` does not reference an existing node.
    """
    node = await db.get(Node, data.node_id)
    if not node:
        raise HTTPException(status_code=404, detail="节点不存在")

    server = await db.get(Server, node.server_id)

    new_subscriber = Subscriber(
        node_id=data.node_id,
        name=data.name,
        email=data.email,
        remark=data.remark,
        uuid=generate_uuid(),
        subscribe_token=generate_token(),
        traffic_limit=data.traffic_limit,
        expire_at=data.expire_at,
        status=SubscriberStatus.ACTIVE,
        is_enabled=True,
    )
    db.add(new_subscriber)
    await db.commit()
    await db.refresh(new_subscriber)

    # Push the node's updated user list (now including the new subscriber)
    await SubscriberService(db).sync_node_config(node.id)

    return await build_subscriber_response(new_subscriber, node, server)


async def _next_batch_number(db: AsyncSession, node_id: int, prefix: str) -> int:
    """Return the number after which new ``<prefix><n>`` names on *node_id* start.

    This is the larger of the node's current subscriber count (the original
    numbering scheme) and the highest purely numeric suffix already used with
    *prefix* on that node. Taking the suffix into account means a deletion can
    no longer make a new name collide with an existing one.
    """
    count_query = select(func.count(Subscriber.id)).where(Subscriber.node_id == node_id)
    highest = await db.scalar(count_query) or 0

    names_query = select(Subscriber.name).where(
        Subscriber.node_id == node_id,
        # autoescape: '%' / '_' in the prefix must match literally
        Subscriber.name.startswith(prefix, autoescape=True),
    )
    result = await db.execute(names_query)
    for name in result.scalars():
        suffix = (name or "")[len(prefix):]
        # isascii() guards against Unicode digits (e.g. '²') that int() rejects
        if suffix.isascii() and suffix.isdigit():
            highest = max(highest, int(suffix))
    return highest


@router.post("/batch", response_model=SubscriberListResponse)
async def batch_create_subscribers(
    data: SubscriberBatchCreate,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Create ``data.count`` enabled subscribers on one node in a single transaction.

    Names are ``data.name_prefix`` followed by consecutive numbers. Numbering
    starts after both the node's subscriber count and the highest numeric
    suffix already used with that prefix on the node. After the commit, the
    node's config is re-synced once for the whole batch.

    Raises:
        HTTPException(404): ``data.node_id`` does not reference an existing node.
    """
    node = await db.get(Node, data.node_id)
    if not node:
        raise HTTPException(status_code=404, detail="节点不存在")

    server = await db.get(Server, node.server_id)

    # str() mirrors the f-string formatting the names have always used
    prefix = str(data.name_prefix)
    start = await _next_batch_number(db, data.node_id, prefix)

    subscribers = [
        Subscriber(
            node_id=data.node_id,
            name=f"{prefix}{start + offset}",
            uuid=generate_uuid(),
            subscribe_token=generate_token(),
            traffic_limit=data.traffic_limit,
            expire_at=data.expire_at,
            status=SubscriberStatus.ACTIVE,
            is_enabled=True,
        )
        for offset in range(1, data.count + 1)
    ]
    db.add_all(subscribers)
    # One flush for the whole batch; refresh then loads the database-assigned fields
    await db.flush()

    items = []
    for subscriber in subscribers:
        await db.refresh(subscriber)
        items.append(await build_subscriber_response(subscriber, node, server))

    await db.commit()

    # Sync the node's config once for the whole batch
    subscriber_service = SubscriberService(db)
    await subscriber_service.sync_node_config(data.node_id)

    return SubscriberListResponse(items=items, total=len(items))


@router.get("/{subscriber_id}", response_model=SubscriberResponse)
async def get_subscriber(
    subscriber_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Return a single subscriber together with its node/server-derived details.

    Raises:
        HTTPException(404): no subscriber with ``subscriber_id``.
    """
    subscriber = await db.get(Subscriber, subscriber_id)
    if not subscriber:
        raise HTTPException(status_code=404, detail="订阅用户不存在")

    node = await db.get(Node, subscriber.node_id)
    server = None
    if node:
        server = await db.get(Server, node.server_id)

    return await build_subscriber_response(subscriber, node, server)


@router.put("/{subscriber_id}", response_model=SubscriberResponse)
async def update_subscriber(
    subscriber_id: int,
    data: SubscriberUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Apply a partial update to a subscriber, recompute its status and re-sync config.

    Only fields explicitly present in the request body are written
    (``model_dump(exclude_unset=True)``). If ``node_id`` changes, the target node
    must exist.

    The status is re-derived with this priority: EXPIRED, then
    TRAFFIC_EXCEEDED, then DISABLED (when ``is_enabled`` is false), else ACTIVE.

    When the update changes anything the node config depends on (node, enabled
    flag, status or UUID), the affected node is re-synced, as create, delete
    and toggle already do. A move re-syncs both the old and the new node.

    Raises:
        HTTPException(404): unknown subscriber, or ``node_id`` refers to a missing node.
    """
    subscriber = await db.get(Subscriber, subscriber_id)
    if not subscriber:
        raise HTTPException(status_code=404, detail="订阅用户不存在")

    update_data = data.model_dump(exclude_unset=True)

    # Refuse to move a subscriber onto a node that does not exist.
    new_node_id = update_data.get("node_id")
    if new_node_id is not None and new_node_id != subscriber.node_id:
        if not await db.get(Node, new_node_id):
            raise HTTPException(status_code=404, detail="节点不存在")

    # Snapshot of the fields that affect the generated node config.
    old_node_id = subscriber.node_id
    before = (subscriber.node_id, subscriber.is_enabled, subscriber.status, subscriber.uuid)

    for key, value in update_data.items():
        setattr(subscriber, key, value)

    if subscriber.is_expired:
        subscriber.status = SubscriberStatus.EXPIRED
    elif subscriber.is_traffic_exceeded:
        subscriber.status = SubscriberStatus.TRAFFIC_EXCEEDED
    elif not subscriber.is_enabled:
        subscriber.status = SubscriberStatus.DISABLED
    else:
        subscriber.status = SubscriberStatus.ACTIVE

    after = (subscriber.node_id, subscriber.is_enabled, subscriber.status, subscriber.uuid)

    await db.commit()
    await db.refresh(subscriber)

    if after != before:
        subscriber_service = SubscriberService(db)
        # dict.fromkeys de-duplicates while keeping order (old node first)
        for affected_node_id in dict.fromkeys((old_node_id, subscriber.node_id)):
            await subscriber_service.sync_node_config(affected_node_id)

    node = await db.get(Node, subscriber.node_id)
    server = await db.get(Server, node.server_id) if node else None

    return await build_subscriber_response(subscriber, node, server)


@router.delete("/{subscriber_id}")
async def delete_subscriber(
    subscriber_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Delete a subscriber, then re-sync the config of the node it belonged to.

    Raises:
        HTTPException(404): no subscriber with ``subscriber_id``.
    """
    target = await db.get(Subscriber, subscriber_id)
    if not target:
        raise HTTPException(status_code=404, detail="订阅用户不存在")

    # Capture the node before the row is deleted; the re-sync below needs it
    affected_node_id = target.node_id

    await db.delete(target)
    await db.commit()

    await SubscriberService(db).sync_node_config(affected_node_id)

    return {"message": "删除成功"}


@router.put("/{subscriber_id}/settings")
async def update_subscriber_settings(
    subscriber_id: int,
    data: SubscriberSettingsRequest,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Update a subscriber's traffic limit and/or expiry time.

    A ``traffic_limit`` of 0 or less clears the limit (stored as None). Fields
    left as None in the request are not touched.

    The status is re-derived with the same priority as ``update_subscriber``:
    EXPIRED, then TRAFFIC_EXCEEDED, then DISABLED, else ACTIVE. This means a
    disabled subscriber is no longer flipped back to ACTIVE. If the status
    changes, the node config is re-synced so the change takes effect there.

    Raises:
        HTTPException(404): no subscriber with ``subscriber_id``.
    """
    subscriber = await db.get(Subscriber, subscriber_id)
    if not subscriber:
        raise HTTPException(status_code=404, detail="订阅用户不存在")

    previous_status = subscriber.status

    if data.traffic_limit is not None:
        subscriber.traffic_limit = data.traffic_limit if data.traffic_limit > 0 else None

    if data.expire_at is not None:
        subscriber.expire_at = data.expire_at

    if subscriber.is_expired:
        subscriber.status = SubscriberStatus.EXPIRED
    elif subscriber.is_traffic_exceeded:
        subscriber.status = SubscriberStatus.TRAFFIC_EXCEEDED
    elif not subscriber.is_enabled:
        subscriber.status = SubscriberStatus.DISABLED
    else:
        subscriber.status = SubscriberStatus.ACTIVE

    status_changed = subscriber.status != previous_status
    # Read before commit so nothing has to be reloaded afterwards
    node_id = subscriber.node_id

    await db.commit()

    if status_changed:
        await SubscriberService(db).sync_node_config(node_id)

    return {"message": "设置已保存"}


@router.post("/{subscriber_id}/toggle")
async def toggle_subscriber(
    subscriber_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Flip a subscriber's ``is_enabled`` flag and re-sync its node's config.

    Disabling always sets the status to DISABLED. Enabling re-derives the
    status from expiry and traffic usage, so an expired or over-quota
    subscriber is not reported as ACTIVE just because it was switched back on.

    Returns:
        ``{"message": ..., "is_enabled": <new flag>}``.

    Raises:
        HTTPException(404): no subscriber with ``subscriber_id``.
    """
    subscriber = await db.get(Subscriber, subscriber_id)
    if not subscriber:
        raise HTTPException(status_code=404, detail="订阅用户不存在")

    is_enabled = not subscriber.is_enabled
    subscriber.is_enabled = is_enabled

    if not is_enabled:
        subscriber.status = SubscriberStatus.DISABLED
    elif subscriber.is_expired:
        subscriber.status = SubscriberStatus.EXPIRED
    elif subscriber.is_traffic_exceeded:
        subscriber.status = SubscriberStatus.TRAFFIC_EXCEEDED
    else:
        subscriber.status = SubscriberStatus.ACTIVE

    # Read what is needed before commit. Depending on the session's
    # expire_on_commit setting, reading attributes afterwards may require a reload.
    node_id = subscriber.node_id

    await db.commit()

    await SubscriberService(db).sync_node_config(node_id)

    return {
        "message": "已启用" if is_enabled else "已禁用",
        "is_enabled": is_enabled
    }


@router.post("/{subscriber_id}/traffic/reset")
async def reset_subscriber_traffic(
    subscriber_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Zero a subscriber's traffic counters.

    A subscriber whose status was TRAFFIC_EXCEEDED gets its status re-derived
    (EXPIRED, then TRAFFIC_EXCEEDED, then DISABLED, else ACTIVE) instead of
    being set to ACTIVE unconditionally. This keeps an expired or disabled
    subscriber from being reactivated by a reset. When the status changes, the
    node config is re-synced so a restored subscriber can actually connect
    again.

    Raises:
        HTTPException(404): no subscriber with ``subscriber_id``.
    """
    subscriber = await db.get(Subscriber, subscriber_id)
    if not subscriber:
        raise HTTPException(status_code=404, detail="订阅用户不存在")

    previous_status = subscriber.status

    subscriber.traffic_up = 0
    subscriber.traffic_down = 0
    subscriber.traffic_total = 0

    # Only a subscriber blocked for over-quota usage is restored. Other
    # statuses are not caused by traffic and are left alone.
    if previous_status == SubscriberStatus.TRAFFIC_EXCEEDED:
        if subscriber.is_expired:
            subscriber.status = SubscriberStatus.EXPIRED
        elif subscriber.is_traffic_exceeded:
            subscriber.status = SubscriberStatus.TRAFFIC_EXCEEDED
        elif not subscriber.is_enabled:
            subscriber.status = SubscriberStatus.DISABLED
        else:
            subscriber.status = SubscriberStatus.ACTIVE

    status_changed = subscriber.status != previous_status
    node_id = subscriber.node_id

    await db.commit()

    if status_changed:
        await SubscriberService(db).sync_node_config(node_id)

    return {"message": "流量已重置"}


def _qrcode_data_uri(payload: str) -> str:
    """Render *payload* as a PNG QR code and return it as a ``data:image/png;base64,`` URI.

    Returns "" when the optional ``qrcode`` package is not installed, or when
    *payload* is empty (an empty QR code would be useless to scan).
    """
    if not qrcode or not payload:
        return ""

    qr = qrcode.QRCode(version=1, box_size=10, border=5)
    qr.add_data(payload)
    qr.make(fit=True)  # fit=True lets the library grow the version to fit the data
    img = qr.make_image(fill_color="black", back_color="white")

    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()


@router.get("/{subscriber_id}/share", response_model=SubscriberShareInfo)
async def get_subscriber_share(
    subscriber_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Return a subscriber's share link, relative subscription URL and QR code.

    The share link is "" when the subscriber's node or server is missing. In
    that case no QR code is rendered either.

    Raises:
        HTTPException(404): no subscriber with ``subscriber_id``.
    """
    subscriber = await db.get(Subscriber, subscriber_id)
    if not subscriber:
        raise HTTPException(status_code=404, detail="订阅用户不存在")

    node = await db.get(Node, subscriber.node_id)
    server = await db.get(Server, node.server_id) if node else None

    share_link = generate_share_link(subscriber, node, server)
    subscribe_url = f"/api/subscribe/user/{subscriber.subscribe_token}"

    return SubscriberShareInfo(
        name=subscriber.name,
        share_link=share_link,
        subscribe_url=subscribe_url,
        qrcode_base64=_qrcode_data_uri(share_link)
    )


# ============ 节点下的订阅用户管理 ============

@router.post("/node/{node_id}/sync")
async def sync_node_subscribers(
    node_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Push a node's subscriber config to its server on demand.

    Returns:
        ``{"success": True, "message": <service message>}`` when the sync succeeds.

    Raises:
        HTTPException(404): no node with ``node_id``.
        HTTPException(500): the sync reported failure; ``detail`` carries its message.
    """
    if not await db.get(Node, node_id):
        raise HTTPException(status_code=404, detail="节点不存在")

    ok, message = await SubscriberService(db).sync_node_config(node_id)

    if not ok:
        raise HTTPException(status_code=500, detail=message)
    return {"success": True, "message": message}


@router.get("/node/{node_id}", response_model=SubscriberListResponse)
async def list_node_subscribers(
    node_id: int,
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user)
):
    """Return one page of a node's subscribers, newest first.

    Returns:
        ``SubscriberListResponse`` with the page items and the node's total
        subscriber count.

    Raises:
        HTTPException(404): no node with ``node_id``.
    """
    node = await db.get(Node, node_id)
    if not node:
        raise HTTPException(status_code=404, detail="节点不存在")

    # Every row shares this node, so the server is resolved only once
    server = await db.get(Server, node.server_id)

    total = await db.scalar(
        select(func.count(Subscriber.id)).where(Subscriber.node_id == node_id)
    )

    page_query = (
        select(Subscriber)
        .where(Subscriber.node_id == node_id)
        .order_by(Subscriber.created_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    rows = (await db.execute(page_query)).scalars().all()

    items = [await build_subscriber_response(sub, node, server) for sub in rows]

    return SubscriberListResponse(items=items, total=total or 0)

