# !pip install aiohttp tqdm

import os
import json
import asyncio
import random
from collections import OrderedDict
from typing import Any, Dict, List, Optional

from aiohttp import ClientSession, TCPConnector, ClientTimeout
from tqdm.auto import tqdm

# ---------- Tunable settings ----------
INPUT_JSON: str = "/content/drive/MyDrive/merged_vocables.json"      # input list; deduplicated by id on load
OUTPUT_JSONL: str = "/content/drive/MyDrive/koshashri_entries.jsonl"  # one JSON record per line, appended to

BASE: str = "https://koshashri-dc.ac.in"
# Number of fetches allowed in flight at once. Raising it (e.g. to 50) is
# faster, but mind the remote server's limits.
MAX_CONCURRENCY: int = 30
RETRIES: int = 4            # maximum attempts per id
BACKOFF_BASE: float = 0.8   # base delay (seconds) of the exponential back-off
TIMEOUT_TOTAL: int = 25     # total timeout per request, in seconds
# ---------------------------------------

# Mount Google Drive when running inside Colab (force_remount=False: an
# already-mounted Drive is not remounted). Outside Colab this is a no-op.
try:
    from google.colab import drive  # noqa
except ImportError:
    pass  # not running in Colab: nothing to mount
else:
    try:
        drive.mount('/content/drive', force_remount=False)
    except Exception as e:
        # Keep going (the paths may still be reachable), but don't hide the
        # failure: otherwise the first symptom is a confusing
        # FileNotFoundError when the input file is opened.
        print(f"警告：Google Drive 挂载失败：{e!r}")

# Request headers sent with every API call: a desktop-browser User-Agent and
# an Accept header that prefers JSON.
HEADERS = {
    "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0 Safari/537.36",
    "Accept": "application/json, text/plain, */*",
    "Accept-Language": "en-US,en;q=0.9",
}

def load_input(path: str) -> List[Dict[str, Any]]:
    """Load the JSON list at *path* and return one ``{"id", "vocable"}`` record per id.

    Items whose ``id`` is missing or falsy are skipped. When an id repeats,
    the first occurrence wins, and records keep the order of first appearance.
    ``vocable`` is ``None`` when the item has no such key.
    """
    with open(path, "r", encoding="utf-8") as fh:
        records = json.load(fh)

    # Plain dicts preserve insertion order, which gives "first occurrence wins".
    unique: Dict[Any, Dict[str, Any]] = {}
    for record in records:
        key = record.get("id")
        if key and key not in unique:
            unique[key] = {"id": key, "vocable": record.get("vocable", None)}
    return list(unique.values())

def load_done_ids(out_path: str, *, retry_failed: bool = True) -> set:
    """Collect the ids already recorded in the JSONL output, so a rerun can resume.

    Args:
        out_path: JSONL file produced by this script. A missing file yields an
            empty set.
        retry_failed: when True (default), records that carry an ``"error"``
            with a status other than 403/404 are NOT counted as done. Those
            are network errors, timeouts, 5xx, 429 and similar, so their ids
            are fetched again. Because the output is append-only, such an id
            may then appear more than once; a later successful record
            supersedes the earlier error. Pass False to count every recorded
            id as done.

    Returns:
        The set of ids to skip.

    Blank lines, lines that are not valid JSON (e.g. truncated by an
    interrupted run), non-object lines and records without an id are ignored.
    """
    done = set()
    if not os.path.exists(out_path):
        return done
    with open(out_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError: damaged line, skip it
                continue
            if not isinstance(obj, dict):
                continue
            _id = obj.get("id")
            if not _id:
                continue
            # 403/404 are terminal answers (the fetcher does not retry them);
            # any other recorded failure is considered transient.
            if retry_failed and "error" in obj and obj.get("status") not in (403, 404):
                continue
            done.add(_id)
    return done

def _decode_object_field(raw: Any) -> Any:
    """Decode an ``object`` field, which the API ships as a JSON-encoded string.

    Returns the decoded value for a string, or the value itself if it is
    already a dict/list. Raises ValueError for anything else, or for a
    string that is not valid JSON.
    """
    if isinstance(raw, (dict, list)):
        return raw
    if isinstance(raw, str):
        return json.loads(raw)  # json.JSONDecodeError subclasses ValueError
    raise ValueError(f"unsupported object field type: {type(raw).__name__}")


def parse_vocable_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    Parse the payload returned by ``/search/vocable/{id}``.

    Expected shape:
      - ``payload['object']``: the main entry, as a JSON string
      - ``payload['nestedVocables']``: a list of objects whose ``object``
        field is also a JSON string

    Already-decoded (dict/list) ``object`` values are accepted as-is.
    Anything that cannot be decoded is dropped rather than raised: the entry
    becomes ``None``, and undecodable or malformed nested items are skipped.
    A payload that is not a dict (e.g. ``None`` from an empty response body)
    yields an empty result.

    Returns:
        ``{"entry": <decoded entry or None>, "nestedVocables": [<decoded items>]}``
    """
    if not isinstance(payload, dict):
        return {"entry": None, "nestedVocables": []}

    try:
        entry = _decode_object_field(payload.get("object"))
    except (ValueError, RecursionError):
        entry = None

    nested_raw = payload.get("nestedVocables") or []
    if not isinstance(nested_raw, list):
        # Iterating a dict/string here would walk keys/characters, not vocables.
        nested_raw = []

    nested_list = []
    for nv in nested_raw:
        if not isinstance(nv, dict):
            continue
        try:
            nested_list.append(_decode_object_field(nv.get("object")))
        except (ValueError, RecursionError):
            continue  # one bad nested item must not lose the others

    return {
        "entry": entry,
        "nestedVocables": nested_list
    }

async def fetch_one(session: ClientSession, _id: str) -> Dict[str, Any]:
    """Fetch and parse ``{BASE}/search/vocable/{_id}``, retrying transient failures.

    Every outcome is returned as a record dict; nothing is raised (except
    task cancellation).

      * HTTP 200: ``{"id", "url", "status", "entry", "nestedVocables"}``.
      * HTTP 403/404: returned immediately with ``"error": "HTTP <status>"``
        and no retry. These usually mean a missing entry or no permission.
      * Any other status, or an exception (network error, timeout,
        undecodable JSON): retried up to ``RETRIES`` attempts. The wait
        between attempts is ``BACKOFF_BASE * 2**(attempt-1)`` plus up to
        0.3 s of jitter. The final record carries the last failure:
        ``status`` is the last HTTP status (None if it was an exception),
        and ``error`` is ``"HTTP <status>"`` or the exception's repr.
    """
    url = f"{BASE}/search/vocable/{_id}"
    last_status: Optional[int] = None
    last_error = "Unknown error"  # only reported if RETRIES < 1

    for attempt in range(1, RETRIES + 1):
        try:
            async with session.get(url, headers=HEADERS, timeout=ClientTimeout(total=TIMEOUT_TOTAL)) as resp:
                status = resp.status
                if status == 200:
                    # content_type=None: the server's Content-Type may not declare JSON.
                    try:
                        payload = await resp.json(content_type=None)
                    except Exception:
                        text = await resp.text()
                        payload = json.loads(text)

                    return {
                        "id": _id,
                        "url": url,
                        "status": status,
                        **parse_vocable_payload(payload)
                    }

                # 403/404: missing entry or no permission; retrying won't help.
                if status in (403, 404):
                    return {"id": _id, "url": url, "status": status, "error": f"HTTP {status}"}

                # Any other status is treated as transient; remember it for the final record.
                last_status, last_error = status, f"HTTP {status}"
        except Exception as e:
            last_status, last_error = None, repr(e)

        # Back off outside `async with` so the connection is returned to the
        # pool while waiting, and skip the useless sleep after the last attempt.
        if attempt < RETRIES:
            await asyncio.sleep(BACKOFF_BASE * (2 ** (attempt - 1)) + random.random() * 0.3)

    return {"id": _id, "url": url, "status": last_status, "error": last_error}

async def bounded_fetch(session: ClientSession, sem: asyncio.Semaphore, _id: str):
    """Run ``fetch_one(session, _id)`` while holding one slot of *sem*.

    The semaphore caps how many fetches run at once, including the time they
    spend sleeping between retries. Returns whatever ``fetch_one`` returns.
    """
    await sem.acquire()
    try:
        return await fetch_one(session, _id)
    finally:
        sem.release()

def _ensure_trailing_newline(path: str) -> None:
    """Terminate a dangling last line in *path* so appended records start on a fresh line.

    An interrupted earlier run can leave a partially written final line.
    Without this, the first record appended by the next run would be glued
    onto it, and both lines would become unparseable. A missing or empty
    file is left alone.
    """
    try:
        with open(path, "rb") as f:
            f.seek(0, os.SEEK_END)
            if f.tell() == 0:
                return
            f.seek(-1, os.SEEK_END)
            last_byte = f.read(1)
    except FileNotFoundError:
        return
    if last_byte != b"\n":
        with open(path, "ab") as f:
            f.write(b"\n")


async def main():
    """Fetch every not-yet-done id from INPUT_JSON and append one JSON line per result to OUTPUT_JSONL."""
    # 1) Load the input and dedupe it by id
    items = load_input(INPUT_JSON)  # [{'id':..., 'vocable':...}, ...]
    id2vocable = {x["id"]: x.get("vocable") for x in items}

    # 2) Resume support: skip ids already recorded in the output
    done_ids = load_done_ids(OUTPUT_JSONL)
    to_fetch = [x["id"] for x in items if x["id"] not in done_ids]

    if not to_fetch:
        print("没有需要新抓取的 ID，已全部完成。")
        return

    # Count "done" against this input only; done_ids may also contain ids
    # from earlier runs over a different input file.
    n_done = len(items) - len(to_fetch)
    print(f"总计: {len(items)}, 已完成: {n_done}, 待抓取: {len(to_fetch)}")

    # 3) Prepare the output file
    out_dir = os.path.dirname(OUTPUT_JSONL)
    if out_dir:  # dirname is "" for a bare filename, and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    _ensure_trailing_newline(OUTPUT_JSONL)

    sem = asyncio.Semaphore(MAX_CONCURRENCY)

    # 4) Fetch concurrently and write each result as soon as it completes.
    # buffering=1 (line-buffered) pushes every record to the file right away,
    # so a crash or interrupt loses at most the record being written.
    with open(OUTPUT_JSONL, "a", encoding="utf-8", buffering=1) as fout:
        # NOTE(security): ssl=False disables TLS certificate verification. This
        # is kept from the original configuration; re-enable it if the site's
        # certificate validates.
        connector = TCPConnector(ssl=False, limit=MAX_CONCURRENCY * 2, ttl_dns_cache=300)
        # The session owns (and closes) the connector.
        async with ClientSession(connector=connector) as session:
            tasks = [asyncio.create_task(bounded_fetch(session, sem, _id)) for _id in to_fetch]
            for fut in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Fetching"):
                result = await fut
                # Attach the original vocable from the input
                result["vocable"] = id2vocable.get(result["id"])
                # ensure_ascii=False keeps non-ASCII (e.g. Devanagari) text readable
                fout.write(json.dumps(result, ensure_ascii=False) + "\n")

    print("完成！输出文件：", OUTPUT_JSONL)

def _run_main_blocking() -> None:
    """Run ``main()`` to completion, even if an event loop is already running.

    A plain ``python script.py`` has no running loop, so ``asyncio.run`` is
    used directly. Jupyter/Colab cells execute while the kernel's own loop is
    running in this thread, and there ``asyncio.run`` raises RuntimeError. In
    that case ``main()`` runs under ``asyncio.run`` in a worker thread and
    this thread waits for it. A KeyboardInterrupt (the notebook's "interrupt")
    cancels the worker's main task. That lets the HTTP session and output
    file close cleanly before the interrupt is re-raised.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        asyncio.run(main())
        return

    import threading

    started = threading.Event()
    state: Dict[str, Any] = {}

    async def _runner() -> None:
        # Publish the worker's loop/task so the waiting thread can cancel it.
        state["loop"] = asyncio.get_running_loop()
        state["task"] = asyncio.current_task()
        started.set()
        await main()

    def _target() -> None:
        try:
            asyncio.run(_runner())
        except BaseException as exc:  # handed back to the waiting thread
            state["error"] = exc
        finally:
            started.set()  # never leave the waiter blocked

    worker = threading.Thread(target=_target, name="koshashri-fetch", daemon=True)
    worker.start()
    try:
        # join() with a timeout keeps this thread responsive to KeyboardInterrupt.
        while worker.is_alive():
            worker.join(0.5)
    except KeyboardInterrupt:
        started.wait()
        loop, task = state.get("loop"), state.get("task")
        if loop is not None and task is not None:
            try:
                loop.call_soon_threadsafe(task.cancel)
            except RuntimeError:
                pass  # loop already closed: main() has finished on its own
        worker.join()
        raise
    error = state.get("error")
    if error is not None:
        raise error


if __name__ == "__main__":
    _run_main_blocking()