#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
把 koshashri_entries.jsonl 转为 Goldendict/MDict 的 MDX 源 txt（鲁棒迭代版）

改进与修复：
- 修复 shades/subShades/subSubShades 中可能出现的 None 或非 dict 导致的 AttributeError
- 对 citations/references/pages/volume 等统一做类型与空值防护
- 依然完整展开层级义项、补齐 nestedVocables 与 trailing_articles、丰富别名（IAST ASCII/HK）

依赖：
pip install indic-transliteration

用法：
python koshashri_to_mdx.py -i koshashri_entries.jsonl -o koshashri.mdx.txt
"""
import argparse
import json
import re
import sys
import unicodedata
from html import escape
from typing import Dict, Iterable, List, Optional, Set, Tuple

try:
    from indic_transliteration import sanscript
    from indic_transliteration.sanscript import transliterate
except Exception:
    print("请先安装 indic-transliteration： pip install indic-transliteration", file=sys.stderr)
    raise

# -----------------------------
# Constants and utilities
# -----------------------------

SUP_MAP = str.maketrans({
    "0": "⁰", "1": "¹", "2": "²", "3": "³", "4": "⁴",
    "5": "⁵", "6": "⁶", "7": "⁷", "8": "⁸", "9": "⁹",
})
SUPERSCRIPTS = "⁰¹²³⁴⁵⁶⁷⁸⁹"
MARKER_CHARS = set("*" + SUPERSCRIPTS + "0123456789")

ALLOWED_INLINE_TAGS: Set[str] = {"em", "i", "b", "strong", "sub", "sup", "small"}

DEV_RANGE = (
    ("\u0900", "\u097F"),  # Devanagari
    ("\u1CD0", "\u1CFF"),  # Vedic extensions
)

def ensure_list(x) -> List:
    if x is None:
        return []
    if isinstance(x, list):
        return x
    return [x]

def first_dict(lst: Iterable) -> Optional[Dict]:
    for it in ensure_list(lst):
        if isinstance(it, dict):
            return it
    return None

def is_deva_char(ch: str) -> bool:
    cp = ord(ch)
    for a, b in DEV_RANGE:
        if ord(a) <= cp <= ord(b):
            return True
    return False

def strip_diacritics(s: str) -> str:
    if not s:
        return ""
    return "".join(c for c in unicodedata.normalize("NFD", s) if unicodedata.category(c) != "Mn")

def html_sup_to_superscripts(s: str) -> str:
    def repl(m):
        inside = m.group(1)
        digits = re.sub(r"\D+", "", inside or "")
        return digits.translate(SUP_MAP)
    return re.sub(r"<\s*sup\s*>(.*?)<\s*/\s*sup\s*>", repl, s, flags=re.IGNORECASE | re.DOTALL)

def escape_allowing_inline_tags(s: str, allowed: Set[str] = ALLOWED_INLINE_TAGS) -> str:
    if not s:
        return ""
    esc = escape(s)
    if not allowed:
        return esc
    tags_pattern = "|".join(sorted(allowed, key=len, reverse=True))
    return re.sub(
        fr"&lt;(/?)\s*({tags_pattern})\s*&gt;",
        lambda m: f"<{m.group(1)}{m.group(2).lower()}>",
        esc,
        flags=re.IGNORECASE,
    )
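
# Illustrative: escape_allowing_inline_tags("<em>fire</em> & <script>")
# -> "<em>fire</em> &amp; &lt;script&gt;"  (whitelisted tags restored, everything else stays escaped)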

def strip_html_tags_except_done_sup(s: str) -> str:
    # Drop any remaining HTML tags. <sup>…</sup> has already been converted to
    # Unicode superscript characters by html_sup_to_superscripts, so it survives.
    return re.sub(r"<[^>]*?>", "", s)

def normalize_deva_raw_keep_markers(s: str) -> str:
    if not s:
        return ""
    s = html_sup_to_superscripts(s)
    s = strip_html_tags_except_done_sup(s)
    s = s.strip()
    s = re.sub(r"\s+", " ", s)
    return s

def split_prefix_markers(s: str) -> Tuple[str, str]:
    """Split leading marker characters (*, digits, superscripts) from the
    Devanagari core; plain digits in the prefix are mapped to superscripts."""
    if not s:
        return "", ""
    s = s.lstrip()
    i = 0
    while i < len(s) and (s[i] in MARKER_CHARS or s[i].isspace()):
        i += 1
    prefix = s[:i].strip()
    core = s[i:].strip()
    prefix_digits_to_sup = prefix.replace(" ", "").translate(SUP_MAP)
    return prefix_digits_to_sup, core
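
# Illustrative: split_prefix_markers("2* अग्नि") -> ("²*", "अग्नि")
# (digits in the prefix become superscripts; the Devanagari core is returned intact)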

def normalize_key_for_mdx(key: str) -> str:
    return re.sub(r"\s+", " ", (key or "").strip())

def compact_spaces(s: str) -> str:
    return re.sub(r"\s+", " ", (s or "").strip())

def is_nonempty(s: Optional[str]) -> bool:
    return bool(s and s.strip())

def join_nonempty(parts: Iterable[str], sep: str) -> str:
    return sep.join([p for p in parts if is_nonempty(p)])

# -----------------------------
# Headword parenthetical-variant expansion
# -----------------------------

PAREN_RX = re.compile(r"""
    (?P<head>^[^()]*?)
    KATEX_INLINE_OPEN
        (?P<inside>[^()]+)
    KATEX_INLINE_CLOSE
    (?P<tail>.*)$
""", re.VERBOSE)

def expand_parenthetical_variants(core: str) -> List[str]:
    core = compact_spaces(core)
    if not core:
        return []
    m = PAREN_RX.search(core)
    if not m:
        return [core]
    head = (m.group("head") or "").strip()
    inside = (m.group("inside") or "").strip()
    tail = (m.group("tail") or "").strip()
    variants = [v.strip() for v in re.split(r"\|", inside) if v.strip()]
    results: Set[str] = set()
    results.add((head + tail).strip())
    for v in variants:
        results.add((head + v + tail).strip())
    return sorted(results)
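
# Illustrative: expand_parenthetical_variants("अग्नि(क|ल)म्") yields the base form
# "अग्निम्" plus one form per |-separated alternative, "अग्निकम्" and "अग्निलम्" (sorted).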

# -----------------------------
# Transliteration
# -----------------------------

def to_slp1(deva: str) -> str:
    if not deva:
        return ""
    out = transliterate(deva, sanscript.DEVANAGARI, sanscript.SLP1)
    out = out.replace(" ", "").replace("-", "")
    return out

def to_iast(deva: str) -> str:
    if not deva:
        return ""
    out = transliterate(deva, sanscript.DEVANAGARI, sanscript.IAST)
    out = re.sub(r"\s+", " ", out).strip()
    out = out.replace("-", "").strip()
    return out

def to_hk(deva: str) -> str:
    if not deva:
        return ""
    out = transliterate(deva, sanscript.DEVANAGARI, sanscript.HK)
    out = out.replace(" ", "").replace("-", "")
    return out
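
# Scheme comparison for कृष्ण (illustrative):
#   to_slp1("कृष्ण") -> "kfzRa"   (compact; used as the main MDX key)
#   to_iast("कृष्ण") -> "kṛṣṇa"
#   to_hk("कृष्ण")   -> "kRSNa"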

# -----------------------------
# POS
# -----------------------------

def get_pos(entry: Dict) -> str:
    g0 = compact_spaces(entry.get("grammaticalCategory") or "")
    g1 = compact_spaces(entry.get("grammaticalCategoryFullName") or "")
    sub = compact_spaces(entry.get("subCategory") or "")
    items: List[str] = []
    for v in [g0, g1, sub]:
        if v and v not in items:
            items.append(v)
    raw = join_nonempty(items, " · ")
    return escape_allowing_inline_tags(raw)

# -----------------------------
# Citations and references
# -----------------------------

def guess_lang_tag(s: str) -> str:
    if any(is_deva_char(ch) for ch in s or ""):
        return "sa-Deva"
    return "sa-Latn"

def build_ref_span(r: Dict) -> Optional[str]:
    if not isinstance(r, dict):
        return None
    book = compact_spaces(r.get("referenceBook") or "")
    if not book:
        return None
    modes = compact_spaces(" ".join(ensure_list(r.get("referenceMode"))))
    full = compact_spaces(r.get("referenceBookFullName") or "")
    author = compact_spaces(r.get("author") or "")
    year = compact_spaces(r.get("year") or "")
    mode_desc = compact_spaces(r.get("referenceModeDesc") or "")
    title_parts = []
    if full:
        title_parts.append(full)
    if author:
        title_parts.append(f"Author: {author}")
    if year:
        title_parts.append(f"Year: {year}")
    if mode_desc:
        title_parts.append(f"Mode: {mode_desc}")
    title_attr = escape("; ".join(title_parts)) if title_parts else ""
    book_html = f"<span class='ls'>{escape(book)}</span>"
    mode_html = escape_allowing_inline_tags(modes) if modes else ""
    if title_attr:
        return f"<span class='ref' title='{title_attr}'>{book_html} {mode_html}</span>"
    return f"<span class='ref'>{book_html} {mode_html}</span>"

def build_citations_html(cits: List[Dict]) -> str:
    rows: List[str] = []
    count = 0
    for c in ensure_list(cits):
        if not isinstance(c, dict):
            continue
        txt = compact_spaces(c.get("citationText") or "")
        refs_html_parts: List[str] = []
        for r in ensure_list(c.get("references")):
            piece = build_ref_span(r)  # build_ref_span itself rejects non-dicts
            if piece:
                refs_html_parts.append(piece)
        if not is_nonempty(txt) and not refs_html_parts:
            continue
        # Citation body (the lang attribute lets CSS switch fonts per script)
        txt_html = ""
        if is_nonempty(txt):
            lang_tag = guess_lang_tag(txt)
            txt_html = f"<div class='cit-txt' lang='{lang_tag}'>{escape_allowing_inline_tags(txt)}</div>"
        refs_html = "<div class='refs'>" + " ".join(refs_html_parts) + "</div>" if refs_html_parts else ""
        rows.append("<div class='cit'>" + txt_html + refs_html + "</div>")
        count += 1

    if not rows:
        return ""

    # Collapsible via details/summary (closed by default).
    # Note: the .citations container is kept; CSS shows it only when open, and the summary carries a "▶" icon.
    return (
        "<details class='cit-block'>"
        "<summary class='cit-sum'>"
        "<span class='tri' aria-hidden='true'></span>"
        "<span class='lab'>e.g.</span>"
        f"<span class='cnt'>{count}</span>"
        "</summary>"
        "<div class='citations'>" + "".join(rows) + "</div>"
        "</details>"
    )
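
# Emitted markup sketch (illustrative, for styling reference):
#   <details class='cit-block'>
#     <summary class='cit-sum'><span class='tri'></span><span class='lab'>e.g.</span><span class='cnt'>2</span></summary>
#     <div class='citations'>
#       <div class='cit'><div class='cit-txt' lang='sa-Deva'>…</div><div class='refs'><span class='ref'>…</span></div></div>
#       …
#     </div>
#   </details>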

# -----------------------------
# Sense building (full hierarchy, defensively guarded)
# -----------------------------

def build_subsubshades_html(subsub_list) -> str:
    parts: List[str] = []
    for ss in ensure_list(subsub_list):
        if not isinstance(ss, dict):
            continue
        v = compact_spaces(ss.get("subSubShadeValue") or "")
        cit = build_citations_html(ensure_list(ss.get("citations")))
        gloss = ensure_list(ss.get("gloss"))
        gloss_html = f" <span class='gloss'>[{escape_allowing_inline_tags('; '.join(gloss))}]</span>" if gloss else ""
        if v or cit or gloss_html:
            parts.append(f"<div class='subsubshade'><span class='def'>{escape_allowing_inline_tags(v)}</span>{gloss_html}{cit}</div>")
    if not parts:
        return ""
    return "<div class='subsubshades'>" + "".join(parts) + "</div>"

def build_subshades_html(subs_list, label_prefix: str) -> str:
    rows: List[str] = []
    for sub in ensure_list(subs_list):
        if not isinstance(sub, dict):
            continue
        v = compact_spaces(sub.get("subShadeValue") or "")
        cit = build_citations_html(ensure_list(sub.get("citations")))
        gloss = ensure_list(sub.get("gloss"))
        gloss_html = f" <span class='gloss'>[{escape_allowing_inline_tags('; '.join(gloss))}]</span>" if gloss else ""
        subsub_html = build_subsubshades_html(sub.get("subSubShades"))
        label_html = f"<span class='lbl'>{escape(label_prefix)}</span> " if label_prefix else ""
        def_html = f"<span class='def'>{escape_allowing_inline_tags(v)}</span>" if is_nonempty(v) else ""
        rows.append(f"<div class='subsense' data-key='{escape(label_prefix)}'>{label_html}{def_html}{gloss_html}{cit}{subsub_html}</div>")
    if not rows:
        return ""
    return "<div class='subsenses'>" + "".join(rows) + "</div>"

def build_shades_html(shades_list, sense_key: str) -> str:
    parts: List[str] = []
    for sh in ensure_list(shades_list):
        if not isinstance(sh, dict):
            continue
        shade_key = compact_spaces(sh.get("shadeKey") or "")
        label_prefix = (sense_key + shade_key) if (sense_key and shade_key) else (sense_key or shade_key)
        subs = sh.get("subShades")
        cit = build_citations_html(ensure_list(sh.get("citations")))
        gloss = ensure_list(sh.get("gloss"))
        gloss_html = f" <span class='gloss'>[{escape_allowing_inline_tags('; '.join(gloss))}]</span>" if gloss else ""
        sub_html = build_subshades_html(subs, label_prefix=label_prefix)
        shade_value = compact_spaces(sh.get("shadeValue") or "")
        shade_value_html = f"<span class='def'>{escape_allowing_inline_tags(shade_value)}</span>" if is_nonempty(shade_value) else ""
        if sub_html or shade_value_html or cit or gloss_html:
            parts.append(f"<div class='shade' data-key='{escape(shade_key)}'>{shade_value_html}{gloss_html}{sub_html}{cit}</div>")
    if not parts:
        return ""
    return "<div class='shades'>" + "".join(parts) + "</div>"

def build_meanings_html(entry: Dict) -> str:
    senses = ensure_list(entry.get("meaning"))
    pos_html_text = get_pos(entry)
    parts: List[str] = []
    if senses:
        for m in senses:
            if not isinstance(m, dict):
                continue
            key = compact_spaces(m.get("meaningKey") or "")
            mv = compact_spaces(m.get("meaningValue") or "")
            gloss = ensure_list(m.get("gloss"))
            gloss_html = f" <span class='gloss'>[{escape_allowing_inline_tags('; '.join(gloss))}]</span>" if gloss else ""
            cit_html = build_citations_html(ensure_list(m.get("citations")))
            shades_html = build_shades_html(m.get("shades"), sense_key=key)
            pos_html = f"<span class='pos'>{pos_html_text}</span> " if pos_html_text else ""
            label_html = f"<span class='sense-key'>{escape(key)}.</span> " if key else ""
            mv_html = f"<span class='def'>{escape_allowing_inline_tags(mv)}</span>" if is_nonempty(mv) else ""
            body = "".join([mv_html, gloss_html, shades_html, cit_html])

            if not is_nonempty(body):
                mb = compact_spaces(entry.get("meaningBlock") or "")
                ot = compact_spaces(entry.get("ocrText") or "")
                text = mb if len(mb) >= len(ot) else ot
                text = escape_allowing_inline_tags(text)
                if text:
                    body = f"<span class='def'>{text}</span>"
            if not body:
                body = "<span class='def'>—</span>"

            parts.append(f"<div class='sense' data-key='{escape(key)}'>{label_html}{pos_html}{body}</div>")
    else:
        mb = compact_spaces(entry.get("meaningBlock") or "")
        ot = compact_spaces(entry.get("ocrText") or "")
        text = mb if len(mb) >= len(ot) else ot
        text = escape_allowing_inline_tags(text)
        if text:
            parts.append(f"<div class='sense'><span class='def'>{text}</span></div>")
    if not parts:
        parts.append("<div class='sense'><span class='def'>—</span></div>")
    return "<div class='senses'>" + "".join(parts) + "</div>"

# -----------------------------
# Metadata and the assembled entry
# -----------------------------

def build_meta_html(entry: Dict, outer: Dict) -> str:
    eid = entry.get("id") or outer.get("id") or ""
    url = outer.get("url") or ""
    volume = entry.get("volume") if isinstance(entry.get("volume"), dict) else {}
    vol_no = volume.get("volumeNumber", "")
    vol_name = volume.get("name", "")
    p0 = first_dict(entry.get("pages"))
    page_no = p0.get("pageNumber", "") if p0 else ""
    col = p0.get("column", "") if p0 else ""
    url_html = f"<a class='src' href='{escape(url)}' target='_blank' rel='noopener'>source</a>" if url else ""
    vol_title = escape(join_nonempty([str(vol_no) if vol_no != "" else "", vol_name], " · "))
    return (
        f"<div class='meta' data-id='{escape(str(eid))}' "
        f"data-vol='{escape(str(vol_no))}' data-page='{escape(str(page_no))}' "
        f"data-col='{escape(str(col))}' data-vol-name='{vol_title}'>"
        f"{url_html}</div>"
    )

def build_entry_html(deva_display: str, iast_base: str, slp: str, entry: Dict, outer: Dict, stylesheet: Optional[str]) -> str:
    link_html = f"<link rel='stylesheet' href='{escape(stylesheet)}' type='text/css'>" if stylesheet else ""
    hw = (
        "<div class='mw-entry'>"
        "<div class='hw'>"
        f"<span class='hw-deva' lang='sa-Deva'>{escape(deva_display)}</span>"
        f"<span class='hw-iast' lang='sa-Latn' data-scheme='IAST'>{escape(iast_base)}</span>"
        f"<span class='hw-slp' lang='sa-Latn' data-scheme='SLP1'>{escape(slp)}</span>"
        "</div>"
    )
    meanings = build_meanings_html(entry)
    meta = build_meta_html(entry, outer)
    return link_html + hw + meanings + meta + "</div>"

# -----------------------------
# Emit to MDX
# -----------------------------

def emit_one_article(out, key: str, html_body: str):
    # MDX source txt record layout: headword line, article body, then a "</>"
    # separator line terminating the record. (Writing "</>" before the key would
    # create an empty leading record and leave the final record unterminated.)
    out.write(f"{key}\n")
    out.write(html_body)
    out.write("\n</>\n")

def emit_link_alias(out, alias_key: str, target_key: str):
    out.write(f"{alias_key}\n")
    out.write(f"@@@LINK={target_key}\n")
    out.write("</>\n")

# -----------------------------
# Recursively walk entry / nestedVocables / trailing_articles
# -----------------------------

def iter_entry_like_from_entry(entry: Dict) -> Iterable[Dict]:
    if not isinstance(entry, dict):
        return
    yield entry
    for t in ensure_list(entry.get("trailing_articles")):
        if isinstance(t, dict):
            yield from iter_entry_like_from_entry(t)

def iter_all_entries(obj: Dict) -> Iterable[Dict]:
    main_entry = obj.get("entry") or {}
    if isinstance(main_entry, dict):
        yield from iter_entry_like_from_entry(main_entry)
    for nv in ensure_list(obj.get("nestedVocables")):
        if isinstance(nv, dict):
            yield from iter_entry_like_from_entry(nv)

# -----------------------------
# Core processing
# -----------------------------

def build_aliases(prefix: str, core_form: str, deva_display: str, iast_base: str, slp_key: str) -> List[str]:
    aliases: List[str] = []
    def add_alias(a: str):
        a = normalize_key_for_mdx(a)
        if a and a != slp_key and a not in aliases:
            aliases.append(a)
    # Deva
    add_alias(core_form)
    add_alias(deva_display)
    # IAST
    add_alias(iast_base)
    iast_marked = (prefix + iast_base).strip() if prefix else ""
    add_alias(iast_marked)
    # IAST ASCII
    iast_ascii = strip_diacritics(iast_base)
    add_alias(iast_ascii)
    iast_ascii_marked = strip_diacritics(iast_marked)
    add_alias(iast_ascii_marked)
    # HK
    try:
        hk = to_hk(core_form)
        add_alias(hk)
        hk_marked = (prefix + hk).strip() if prefix else ""
        add_alias(hk_marked)
    except Exception:
        pass
    return aliases
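
# Illustrative alias set for prefix "²" and core "कृष्ण" (main SLP1 key "kfzRa"):
#   ["कृष्ण", "²कृष्ण", "kṛṣṇa", "²kṛṣṇa", "krsna", "²krsna", "kRSNa", "²kRSNa"]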

def choose_headword(obj: Dict, entry: Dict) -> str:
    return entry.get("entryWord") or entry.get("vocable") or obj.get("vocable") or ""

def process_entry_like(obj: Dict, entry: Dict, stylesheet: Optional[str]) -> List[Tuple[str, str, List[str]]]:
    deva_raw = choose_headword(obj, entry)
    deva_raw = normalize_deva_raw_keep_markers(deva_raw)
    prefix, core = split_prefix_markers(deva_raw)
    if not core:
        return []
    core_forms = expand_parenthetical_variants(core) or [core]
    results: List[Tuple[str, str, List[str]]] = []
    for core_form in core_forms:
        deva_display = (prefix + core_form).strip()
        try:
            slp = normalize_key_for_mdx(to_slp1(core_form))
            iast_base = normalize_key_for_mdx(to_iast(core_form))
        except Exception:
            continue
        if not slp:
            continue
        html_body = build_entry_html(deva_display, iast_base, slp, entry, obj, stylesheet)
        alias_keys = build_aliases(prefix, core_form, deva_display, iast_base, slp)
        results.append((slp, html_body, alias_keys))
    return results

def process_jsonl_line(obj: Dict, stylesheet: Optional[str]) -> List[Tuple[str, str, List[str]]]:
    results: List[Tuple[str, str, List[str]]] = []
    seen_entry_ids: Set[str] = set()
    for e in iter_all_entries(obj):
        if not isinstance(e, dict):
            continue
        eid = str(e.get("id") or "")
        sig = eid + "\t" + compact_spaces(choose_headword(obj, e))
        if sig in seen_entry_ids:
            continue
        seen_entry_ids.add(sig)
        results.extend(process_entry_like(obj, e, stylesheet))
    return results

# -----------------------------
# main
# -----------------------------

def main():
    ap = argparse.ArgumentParser(description="Convert koshashri_entries.jsonl to MDX source txt (rich, hierarchical, robust)")
    ap.add_argument("-i", "--input", required=True, help="Path to koshashri_entries.jsonl")
    ap.add_argument("-o", "--output", required=True, help="Path to output mdx.txt")
    ap.add_argument("-s", "--stylesheet", default="mw.css", help="Stylesheet link placed in each article (set empty to disable)")
    args = ap.parse_args()

    stylesheet = args.stylesheet.strip() or None

    emitted_main: Set[Tuple[str, int]] = set()
    emitted_alias: Set[Tuple[str, str]] = set()

    with open(args.input, "r", encoding="utf-8") as fin, open(args.output, "w", encoding="utf-8", newline="\n") as fout:
        for ln in fin:
            ln = ln.strip()
            if not ln:
                continue
            try:
                obj = json.loads(ln)
            except Exception:
                continue

            items = process_jsonl_line(obj, stylesheet)
            for slp_key, html_body, alias_keys in items:
                # Dedup is per-run only: str hashes are salted per process, which
                # is fine because this set never outlives the current invocation.
                uniq_key = (slp_key, hash(html_body))
                if uniq_key not in emitted_main:
                    emit_one_article(fout, slp_key, html_body)
                    emitted_main.add(uniq_key)

                for ak in alias_keys:
                    ak_norm = normalize_key_for_mdx(ak)
                    if not ak_norm:
                        continue
                    pair = (ak_norm, slp_key)
                    if pair not in emitted_alias:
                        emit_link_alias(fout, ak_norm, slp_key)
                        emitted_alias.add(pair)

    print(f"Done. Wrote MDX source to: {args.output}")

if __name__ == "__main__":
    main()