#!/usr/bin/env python3
"""
RAG 记忆检索脚本
小桉用：用户消息 → embedding → Qdrant 检索 → 返回记忆片段
调用方式：python3 rag_query.py "用户的问题"
"""
import sys
import json
import urllib.request
import urllib.error
import os

# Base URL of the embedding service; the RAG_EMBEDDING_URL environment
# variable overrides the built-in default.
EMBEDDING_URL = os.environ.get("RAG_EMBEDDING_URL", "http://100.105.196.63:8081")
# Base URL of the Qdrant server; the RAG_QDRANT_URL environment variable
# overrides the built-in default.
QDRANT_URL = os.environ.get("RAG_QDRANT_URL", "http://100.105.196.63:6333")
# Name of the Qdrant collection that is searched.
COLLECTION = "memory"
# Value sent as the search request's "limit" (maximum number of hits).
TOP_K = 5
# Value sent as the search request's "score_threshold".
SCORE_THRESHOLD = 0.65  # deliberately lenient; can be tuned later


def embed(text):
    """Convert text into an embedding vector via the embedding service.

    Sends ``{"text": text}`` as a JSON POST to ``{EMBEDDING_URL}/embed`` and
    returns the ``"vector"`` field of the JSON reply.

    NOTE(review): the original note said the vector has 512 dimensions;
    nothing in this function checks that -- confirm against the service.

    Args:
        text: The text to embed.

    Returns:
        The value stored under ``"vector"`` in the service's JSON response.

    Raises:
        urllib.error.URLError: If the request fails (``HTTPError`` for HTTP
            error statuses).
        KeyError: If the response has no ``"vector"`` field.
    """
    body = json.dumps({"text": text}).encode()
    request = urllib.request.Request(
        f"{EMBEDDING_URL}/embed",
        data=body,
        method="POST",
        headers={"Content-Type": "application/json"},
    )
    # Give up after 10 seconds rather than hanging on an unreachable service.
    with urllib.request.urlopen(request, timeout=10) as response:
        reply = json.loads(response.read())
    return reply["vector"]


def search(vector):
    """Run a vector search against the Qdrant collection.

    POSTs the query vector to
    ``{QDRANT_URL}/collections/{COLLECTION}/points/search`` together with the
    module-level ``TOP_K`` limit and ``SCORE_THRESHOLD``, asking for payloads
    to be included, and returns the ``"result"`` field of the JSON reply.

    Args:
        vector: The query embedding (as returned by ``embed``).

    Returns:
        The value stored under ``"result"`` in Qdrant's JSON response.

    Raises:
        urllib.error.URLError: If the request fails (``HTTPError`` for HTTP
            error statuses).
        KeyError: If the response has no ``"result"`` field.
    """
    query = {
        "vector": vector,
        "limit": TOP_K,
        "score_threshold": SCORE_THRESHOLD,
        "with_payload": True,
    }
    request = urllib.request.Request(
        f"{QDRANT_URL}/collections/{COLLECTION}/points/search",
        data=json.dumps(query).encode(),
        method="POST",
        headers={"Content-Type": "application/json"},
    )
    # Give up after 10 seconds rather than hanging on an unreachable server.
    with urllib.request.urlopen(request, timeout=10) as response:
        reply = json.loads(response.read())
    return reply["result"]


def format_memory(hits):
    """Format search hits as human-readable memory text.

    Each hit becomes a numbered entry showing its content, followed by
    indented source / date / tags lines (each only when non-empty) and the
    relevance score with two decimals. Entries are separated by a ``---``
    line.

    Args:
        hits: Sequence of hit dicts. Each must carry a numeric ``"score"``
            and may carry a ``"payload"`` dict with the optional keys
            ``"content"``, ``"source"``, ``"date"`` and ``"tags"``.

    Returns:
        The formatted text, or ``None`` when ``hits`` is empty or ``None``.

    Raises:
        KeyError: If a hit has no ``"score"``.
    """
    if not hits:
        return None

    entries = []
    for index, hit in enumerate(hits, 1):
        # Tolerate a hit whose payload is null or absent instead of crashing
        # on ``None.get`` / a missing key.
        payload = hit.get("payload") or {}
        score = hit["score"]

        content = payload.get("content", "")
        if content is None:
            # A null content must not be rendered as the literal text "None".
            content = ""
        source = payload.get("source", "未知来源")
        date = payload.get("date", "")

        tags = payload.get("tags") or []
        if not isinstance(tags, (list, tuple)):
            # A lone tag (e.g. a plain string) would otherwise be joined
            # character by character, or fail to iterate at all.
            tags = [tags]
        # str() keeps non-string tags (numbers, ...) from breaking the join.
        tags_str = "、".join(str(tag) for tag in tags)

        parts = [f"[{index}] {content}"]
        if source:
            parts.append(f"    来源: {source}")
        if date:
            parts.append(f"    日期: {date}")
        if tags_str:
            parts.append(f"    标签: {tags_str}")
        parts.append(f"    相关度: {score:.2f}")
        entries.append("\n".join(parts))

    return "\n---\n".join(entries)


def main():
    """Command-line entry point: look up memories for the given question.

    All command-line arguments are joined with spaces into one query, which
    is embedded, searched and formatted. The formatted hits are printed to
    stdout, or the marker ``NO_RELEVANT_MEMORY`` when nothing was formatted.

    Exit codes:
        1: No query argument was supplied (usage is printed to stderr).
        2: An HTTP error, network error or any other exception occurred
           during retrieval (the error is printed to stderr).
    """
    words = sys.argv[1:]
    if not words:
        print("用法: python3 rag_query.py '用户问题'", file=sys.stderr)
        sys.exit(1)

    question = " ".join(words)

    try:
        memory_text = format_memory(search(embed(question)))
        print(memory_text if memory_text else "NO_RELEVANT_MEMORY")
        return
    # HTTPError is a subclass of URLError, so it has to be matched first.
    except urllib.error.HTTPError as exc:
        failure = f"HTTP 错误: {exc.code} {exc.reason}"
    except urllib.error.URLError as exc:
        failure = f"网络错误: {exc.reason}"
    except Exception as exc:
        # Top-level boundary: report anything unexpected instead of a traceback.
        failure = f"检索异常: {exc}"

    print(failure, file=sys.stderr)
    sys.exit(2)


# Run the query only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
