import re
from typing import List, Tuple, Optional
from difflib import SequenceMatcher
Opcode = Tuple[str, int, int, str]  # (tag, start, end, text)


class ExtractableText:
    """Split HTML into plain text plus positional opcodes, and keep them in sync.

    ``extract_html`` removes every tag from ``original``, producing ``clean``
    and a list of opcodes ``(tag, start, end, text)`` whose half-open range
    ``[start, end)`` is expressed in ``clean`` coordinates:

    - ``insert``: ``text`` is spliced in at ``start`` (``start == end``);
    - ``replace``: ``clean[start:end]`` is replaced by ``text``;
    - ``delete``: ``clean[start:end]`` is removed.

    ``restore`` replays the opcodes over the clean text to rebuild the markup.
    The editing methods change ``clean`` and re-base the stored opcodes and
    tag records so that they keep pointing at the same places.
    """

    TAG_RE = re.compile(r"<[^>]+>")
    # Elements that never take a closing tag, even when written without '/>'.
    _VOID_ELEMENTS = frozenset({
        "br", "hr", "img", "input", "meta", "link", "area", "base",
        "col", "embed", "param", "source", "track", "wbr",
    })

    def __init__(self, text: str):
        self.original = text
        self.clean: str = ""
        self.opcodes: List[Opcode] = []  # kept ordered by start after extract_html
        # One dict per paired tag: {name, start, end, depth}; start/end are
        # half-open clean-text coordinates of the tag's content.
        self._tag_records: List[dict] = []

    @staticmethod
    def _parse_tag_name(tag_text: str) -> str:
        """Return the lower-cased element name of a tag, or "" if none parses."""
        m = re.match(r"^<\s*/?\s*([A-Za-z0-9]+)", tag_text)
        return m.group(1).lower() if m else ""

    @staticmethod
    def _is_closing_tag(tag_text: str) -> bool:
        """Return True when the tag starts with '</' (whitespace tolerated)."""
        return re.match(r"^<\s*/", tag_text) is not None

    @staticmethod
    def _is_self_closing(tag_text: str) -> bool:
        """Return True for '<x/>' syntax or for a known void element."""
        if tag_text.rstrip().endswith("/>"):
            return True
        name = ExtractableText._parse_tag_name(tag_text)
        return name in ExtractableText._VOID_ELEMENTS

    def extract_html(self) -> Tuple[str, List[Opcode]]:
        """Strip tags from ``self.original`` and record them as opcodes.

        - Opening, closing and self-closing tags become ``insert`` opcodes at
          the clean position where they occurred; they add nothing to ``clean``.
        - ``<br>`` variants add a newline to ``clean`` and a ``replace`` opcode
          that swaps that newline back for the original tag text.
        - Every paired tag gets a record with its content range in ``clean``.
        - An open tag that is never closed (or is skipped over by an outer
          closing tag) receives an implicit ``</name>`` insert, so ``restore``
          yields repaired markup rather than the exact original in that case.

        Calling the method again rebuilds everything from ``self.original``.

        Returns:
            ``(clean_text, opcodes)``; the opcode list is a copy.
        """
        # Start from a clean slate so repeated calls do not duplicate entries.
        self.opcodes = []
        self._tag_records = []

        orig = self.original
        clean_parts: List[str] = []
        cursor = 0      # position in orig just after the last consumed tag
        clean_pos = 0   # length of the clean text produced so far
        stack: List[dict] = []  # records of currently open tags, outermost first

        for m in self.TAG_RE.finditer(orig):
            tag_text = m.group()
            text_before = orig[cursor:m.start()]
            if text_before:
                clean_parts.append(text_before)
                clean_pos += len(text_before)
            name = self._parse_tag_name(tag_text)

            if self._is_closing_tag(tag_text):
                # Only unwind the stack when this tag really closes something;
                # a stray closing tag must not tear down unrelated open tags.
                if any(rec["name"] == name for rec in stack):
                    while stack:
                        top = stack.pop()
                        top["end"] = clean_pos
                        self._tag_records.append(top)
                        if top["name"] == name:
                            break
                        # Skipped-over open tag: close it implicitly.
                        self.opcodes.append(
                            ("insert", clean_pos, clean_pos, f"</{top['name']}>")
                        )
                self.opcodes.append(("insert", clean_pos, clean_pos, tag_text))
            elif not name:
                # Comments, doctypes and similar have no element name and no
                # closing counterpart: keep them verbatim, never on the stack.
                self.opcodes.append(("insert", clean_pos, clean_pos, tag_text))
            elif self._is_self_closing(tag_text):
                if name == "br":
                    # A line break is represented by '\n' in the clean text.
                    clean_parts.append("\n")
                    self.opcodes.append(("replace", clean_pos, clean_pos + 1, tag_text))
                    clean_pos += 1
                else:
                    self.opcodes.append(("insert", clean_pos, clean_pos, tag_text))
            else:
                stack.append(
                    {"name": name, "start": clean_pos, "end": None, "depth": len(stack)}
                )
                self.opcodes.append(("insert", clean_pos, clean_pos, tag_text))
            cursor = m.end()

        # Text after the last tag.
        if cursor < len(orig):
            tail = orig[cursor:]
            clean_parts.append(tail)
            clean_pos += len(tail)

        # Tags still open at the end of input are closed at the end of clean.
        while stack:
            top = stack.pop()
            top["end"] = clean_pos
            self._tag_records.append(top)
            self.opcodes.append(("insert", clean_pos, clean_pos, f"</{top['name']}>"))

        # Stable sort: at equal start, inserts come before the replace that
        # begins there, and source order is kept among the inserts.
        self.opcodes.sort(key=lambda o: (o[1], 0 if o[0] == "insert" else 1))
        self.clean = "".join(clean_parts)
        self._tag_records.sort(key=lambda r: (r["start"], -r["depth"]))
        return self.clean, list(self.opcodes)

    def restore(self, clean_text: Optional[str] = None) -> str:
        """Rebuild rich text by replaying ``self.opcodes`` over clean text.

        Args:
            clean_text: text to decorate; defaults to ``self.clean``. The
                opcodes must be expressed in this text's coordinates.

        Returns:
            The text with every opcode applied in list order.

        Raises:
            ValueError: if an opcode has an unknown tag.
        """
        base = self.clean if clean_text is None else clean_text
        chars = list(base)
        offset = 0  # how far earlier opcodes have moved later coordinates
        for tag, start, end, text in self.opcodes:
            lo = start + offset
            hi = end + offset
            if tag == "insert":
                chars[lo:lo] = text
                offset += len(text)
            elif tag == "delete":
                del chars[lo:hi]
                offset -= hi - lo
            elif tag == "replace":
                chars[lo:hi] = text
                offset += len(text) - (hi - lo)
            else:
                raise ValueError(f"Unknown opcode tag: {tag}")
        return "".join(chars)

    def get_tag_context(self, pos: int) -> Tuple[List[str], List[int], List[int]]:
        """Describe the tags enclosing clean position ``pos``.

        Returns:
            ``(path, left_bounds, right_bounds)`` ordered from the outermost
            tag to the innermost; bounds are half-open clean coordinates.
            Three empty lists when ``pos`` is out of range or outside all tags.
        """
        if pos < 0 or pos > len(self.clean):
            return [], [], []

        total = len(self.clean)
        containing = [
            r for r in self._tag_records
            if r["start"] <= pos < (r["end"] if r["end"] is not None else total)
        ]
        if not containing:
            return [], [], []
        # Outer tags start earlier; depth breaks ties for tags sharing a start.
        containing.sort(key=lambda r: (r["start"], r["depth"]))
        path = [r["name"] for r in containing]
        lefts = [r["start"] for r in containing]
        rights = [r["end"] for r in containing]
        return path, lefts, rights

    def contexts_equal(self, pos_a: int, pos_b: int) -> bool:
        """Return True when both positions have the same tag path and bounds."""
        return self.get_tag_context(pos_a) == self.get_tag_context(pos_b)

    def compare(self, other):
        """Return ``difflib`` opcodes that turn ``self.clean`` into ``other``."""
        return SequenceMatcher(None, self.clean, other).get_opcodes()

    def range_crosses_tag(self, start: int, end: int) -> bool:
        """Return True when ``[start, end)`` spans more than one tag context.

        A non-empty range compares the contexts of its first and last
        characters. A zero-width range compares the characters on either side
        of the position (position 0 is never considered crossing).

        Raises:
            ValueError: if the range is not within the clean text.
        """
        size = len(self.clean)
        if start < 0 or end < 0 or start > size or end > size or start > end:
            raise ValueError("Invalid start/end for clean text")
        if start == end:
            if start == 0:
                return not self.contexts_equal(0, 0)
            return not self.contexts_equal(start - 1, start)
        return not self.contexts_equal(start, end - 1)

    # -------------------------
    # Internal helpers for re-basing tag records and opcodes
    # -------------------------
    @staticmethod
    def _span_after_delete(s: int, e: int, start: int, end: int) -> Tuple[int, int]:
        """Map span ``[s, e)`` to its coordinates after deleting ``[start, end)``.

        Parts of the span inside the deleted range collapse onto ``start``.
        """
        removed = end - start
        if e <= start:      # entirely before the deletion
            return s, e
        if s >= end:        # entirely after the deletion
            return s - removed, e - removed
        new_s = s if s < start else start
        new_e = start if e <= end else e - removed
        return new_s, max(new_s, new_e)

    def _shift_tag_records_after_insert(self, pos: int, delta: int) -> None:
        """Re-base tag records after ``delta`` characters were inserted at ``pos``.

        A record starting at ``pos`` moves right (the text lands before the
        tag); a record ending at ``pos`` grows (the text lands inside it).
        """
        for rec in self._tag_records:
            if rec["start"] >= pos:
                rec["start"] += delta
            if rec["end"] >= pos:
                rec["end"] += delta

    def _adjust_tag_records_after_delete(self, start: int, end: int) -> None:
        """Re-base tag records after ``[start, end)`` was deleted from clean."""
        for rec in self._tag_records:
            rec["start"], rec["end"] = self._span_after_delete(
                rec["start"], rec["end"], start, end
            )

    def _shift_opcodes_after_insert(self, pos: int, delta: int) -> None:
        """Re-base opcodes after ``delta`` characters were inserted at ``pos``.

        Opcodes starting at or after ``pos`` move right as a whole. A span is
        widened only when ``pos`` lies strictly inside it: text inserted right
        after a ``<br>`` newline must stay outside the replace span, otherwise
        ``restore`` would overwrite it with the tag.
        """
        shifted = []
        for tag, s, e, t in self.opcodes:
            if s >= pos:
                s += delta
                e += delta
            elif e > pos:
                e += delta
            shifted.append((tag, s, e, t))
        self.opcodes = shifted

    def _adjust_opcodes_after_delete(self, start: int, end: int) -> None:
        """Re-base opcodes after ``[start, end)`` was deleted from clean."""
        self.opcodes = [
            (tag,) + self._span_after_delete(s, e, start, end) + (t,)
            for tag, s, e, t in self.opcodes
        ]

    def _rebase_after_replace(self, start: int, end: int, new_len: int) -> None:
        """Re-base records and opcodes after ``[start, end)`` became ``new_len`` chars.

        The delete step already moves everything behind the range left by the
        old length, so the insert step must add the full new length (not the
        length difference) to end up with a net shift of ``new_len - old_len``.
        """
        if end > start:
            self._adjust_tag_records_after_delete(start, end)
            self._adjust_opcodes_after_delete(start, end)
        if new_len > 0:
            self._shift_tag_records_after_insert(start, new_len)
            self._shift_opcodes_after_insert(start, new_len)

    # -------------------------
    # Public editing operations that also log an opcode
    # -------------------------
    # NOTE(review): in the three *_temp methods the edit is appended to
    # self.opcodes *before* the re-basing helpers run, so the logged opcode is
    # itself shifted/collapsed by its own edit, and self.clean already contains
    # the change. The intended "reversible" semantics are not evident from this
    # file; the recording behaviour is kept as-is - confirm with the author.
    def insert_temp(self, pos: int, text: str):
        """Insert ``text`` at ``pos`` in clean and log an ``insert`` opcode.

        Raises:
            ValueError: if ``pos`` is outside the clean text.
        """
        if pos < 0 or pos > len(self.clean):
            raise ValueError("insert position out of bounds")
        self.clean = self.clean[:pos] + text + self.clean[pos:]
        self.opcodes.append(("insert", pos, pos, text))
        self._shift_tag_records_after_insert(pos, len(text))
        self._shift_opcodes_after_insert(pos, len(text))

    def delete_temp(self, start: int, end: int):
        """Delete ``clean[start:end]`` and log a ``delete`` opcode.

        Raises:
            ValueError: if the range is invalid.
        """
        if start < 0 or end > len(self.clean) or start > end:
            raise ValueError("invalid delete range")
        self.clean = self.clean[:start] + self.clean[end:]
        self.opcodes.append(("delete", start, end, ""))
        self._adjust_tag_records_after_delete(start, end)
        self._adjust_opcodes_after_delete(start, end)

    def replace_temp(self, start: int, end: int, text: str):
        """Replace ``clean[start:end]`` with ``text`` and log a ``replace`` opcode.

        Raises:
            ValueError: if the range is invalid.
        """
        if start < 0 or end > len(self.clean) or start > end:
            raise ValueError("invalid replace range")
        self.clean = self.clean[:start] + text + self.clean[end:]
        self.opcodes.append(("replace", start, end, text))
        self._rebase_after_replace(start, end, len(text))

    # -------------------------
    # Public editing operations that log nothing (irreversible)
    # -------------------------
    def insert(self, pos: int, text: str):
        """Insert ``text`` at ``pos`` in clean; existing opcodes are only re-based.

        Raises:
            ValueError: if ``pos`` is outside the clean text.
        """
        if pos < 0 or pos > len(self.clean):
            raise ValueError("insert position out of bounds")
        self.clean = self.clean[:pos] + text + self.clean[pos:]
        self._shift_tag_records_after_insert(pos, len(text))
        self._shift_opcodes_after_insert(pos, len(text))

    def delete(self, start: int, end: int):
        """Delete ``clean[start:end]``; overlapping opcodes collapse onto ``start``.

        Raises:
            ValueError: if the range is invalid.
        """
        if start < 0 or end > len(self.clean) or start > end:
            raise ValueError("invalid delete range")
        self.clean = self.clean[:start] + self.clean[end:]
        self._adjust_tag_records_after_delete(start, end)
        self._adjust_opcodes_after_delete(start, end)

    def replace(self, start: int, end: int, text: str):
        """Replace ``clean[start:end]`` with ``text``; opcodes are only re-based.

        Raises:
            ValueError: if the range is invalid.
        """
        if start < 0 or end > len(self.clean) or start > end:
            raise ValueError("invalid replace range")
        self.clean = self.clean[:start] + text + self.clean[end:]
        self._rebase_after_replace(start, end, len(text))

    # -------------------------
    # Regex helpers (logging and non-logging versions)
    # -------------------------
    # All helpers collect the matches first and apply them right-to-left, so
    # that an edit never invalidates the positions of matches still pending.
    def insert_re_temp(self, pattern: str, repl_text: str, insert_before: bool = True):
        """Run ``insert_temp`` before (or after) every match of ``pattern``."""
        for m in reversed(list(re.finditer(pattern, self.clean))):
            self.insert_temp(m.start() if insert_before else m.end(), repl_text)

    def delete_re_temp(self, pattern: str):
        """Run ``delete_temp`` on every match of ``pattern``."""
        for m in reversed(list(re.finditer(pattern, self.clean))):
            self.delete_temp(m.start(), m.end())

    def replace_re_temp(self, pattern: str, repl_text: str):
        """Run ``replace_temp`` on every match of ``pattern``."""
        for m in reversed(list(re.finditer(pattern, self.clean))):
            self.replace_temp(m.start(), m.end(), repl_text)

    def insert_re(self, pattern: str, repl_text: str, insert_before: bool = True):
        """Run ``insert`` before (or after) every match of ``pattern``."""
        for m in reversed(list(re.finditer(pattern, self.clean))):
            self.insert(m.start() if insert_before else m.end(), repl_text)

    def delete_re(self, pattern: str):
        """Run ``delete`` on every match of ``pattern``."""
        for m in reversed(list(re.finditer(pattern, self.clean))):
            self.delete(m.start(), m.end())

    def replace_re(self, pattern: str, repl_text: str):
        """Run ``replace`` on every match of ``pattern``."""
        for m in reversed(list(re.finditer(pattern, self.clean))):
            self.replace(m.start(), m.end(), repl_text)


def get_opcodes_and_print(text, other):
    """Diff the tag-free form of ``text`` against ``other`` and print the changes.

    ``text`` is wrapped in an ``ExtractableText`` and stripped of its tags;
    the resulting clean text is compared with ``other``. Every non-equal
    difference is printed and collected as ``(tag, start, end, new_text)`` in
    clean-text coordinates (inserts are zero-width at ``start``).

    Returns:
        ``(et, ops)``: the extractor instance and the collected operations.

    Raises:
        ValueError: if the comparison yields an unrecognised tag.
    """
    et = ExtractableText(text)
    et.extract_html()
    diff = et.compare(other)
    ops = []
    print("差异比较：")
    for tag, i1, i2, j1, j2 in diff:
        if tag == "equal":
            continue
        if tag not in ("replace", "delete", "insert"):
            raise ValueError(f"Unknown tag: {tag}")
        old_part = et.clean[i1:i2]
        new_part = other[j1:j2]
        if tag == "replace":
            # clean[i1:i2] is swapped for other[j1:j2]
            print(f"replace [{i1}, {i2}]{old_part} -> {new_part}")
            ops.append(("replace", i1, i2, new_part))
        elif tag == "delete":
            # clean[i1:i2] is dropped
            print(f"delete [{i1}]{old_part} -> {new_part}")
            ops.append(("delete", i1, i2, ""))
        else:
            # other[j1:j2] is added at clean position i1
            print(f"insert [{i1}]{old_part} -> {new_part}")
            ops.append(("insert", i1, i1, new_part))
    return et, ops

def process_insert(et, ops):
    """Apply only the ``insert`` operations of ``ops`` to ``et`` and restore it.

    The inserts are applied from the last listed to the first, so that an
    insertion does not move the positions of those still to be applied
    (assuming ``ops`` is ordered by position). Other operation kinds are
    ignored.

    Returns:
        The value of ``et.restore()`` after all insertions.
    """
    pending = [(start, val) for op, start, _end, val in ops if op == 'insert']
    for start, val in reversed(pending):
        et.insert(start, val)
    return et.restore()


def load_mdx_src(path):
    """Parse an MDX source text file into dictionary entries.

    The file is a sequence of entries terminated by a ``</>`` line. The first
    line of an entry is its headword; the remaining lines form the definition,
    except ``@@@LINK=target`` lines, which register the current headword as an
    alias in the ``links`` list of the target entry.

    Only entries with a definition are returned. A link is attached to the
    most recently loaded entry with the target headword, and is ignored when
    no such entry has been loaded yet. A final entry that is not followed by
    ``</>`` is kept as well.

    Args:
        path: path of the UTF-8 encoded source file.

    Returns:
        ``(words, duplicate_words)``: the list of entry dicts (keys ``head``,
        ``definition``, ``links``) in file order, and a dict mapping every
        headword that occurs more than once to the list of all its entries.
    """
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.read().splitlines()

    words = []
    word_map = {}         # headword -> list of entries with that headword
    duplicate_words = {}  # subset of word_map: headwords seen more than once

    def commit(entry):
        # Link-only and empty entries carry no definition and are skipped.
        if 'definition' not in entry:
            return
        words.append(entry)
        head = entry['head']
        if head in word_map:
            # The list is shared with word_map on purpose, so both views
            # always contain every entry of the headword.
            duplicate_words.setdefault(head, word_map[head]).append(entry)
        else:
            word_map[head] = [entry]

    entry = {'links': []}
    ishead = True
    for line in lines:
        if line.strip() == '</>':
            commit(entry)
            entry = {'links': []}
            ishead = True
        elif ishead:
            entry['head'] = line.strip()
            ishead = False
        elif line.startswith('@@@LINK='):
            # Headwords are stored stripped, so strip the target as well.
            link = line[len('@@@LINK='):].strip()
            targets = word_map.get(link)
            if targets:
                # word_map values are lists; attach the alias to the nearest
                # preceding entry, which is where write_mdx_source emits it.
                targets[-1]['links'].append(entry['head'])
        else:
            entry['definition'] = entry.get('definition', '') + line + '\n'
    # Keep a trailing entry whose '</>' terminator is missing.
    commit(entry)
    return words, duplicate_words


def write_mdx_source(words, output_txt_path):
    """Write dictionary entries to ``output_txt_path`` in MDX source format.

    Each entry with a non-empty ``head`` is written as the headword line, the
    ``definition`` followed by a newline, and a ``</>`` terminator. Every
    alias in the entry's optional ``links`` list then gets its own entry made
    of the stripped alias, an ``@@@LINK=<headword>`` line and ``</>``.
    Entries with an empty headword are skipped.

    A disabled variant in an earlier revision (not reproduced here) rewrote a
    trailing digit 1-5 of the headword as a superscript.
    """
    with open(output_txt_path, 'w', encoding='utf-8') as f_out:
        emit = f_out.write
        for word in words:
            headword = word['head']
            if not headword:
                continue
            aliases = word.get('links', [])
            html = word['definition']

            # Main entry.
            emit(headword + '\n')
            emit(html + '\n')
            emit('</>\n')

            # One redirect entry per alias, pointing back at the headword.
            for alias in aliases:
                emit(alias.strip() + '\n')
                emit(f'@@@LINK={headword}\n')
                emit('</>\n')

# -------------------------
# Script entry point: merge the Chinese dictionary text into the English MDX source
# -------------------------
if __name__ == "__main__":
    # Load both dictionaries; the second return value maps repeated headwords
    # to all of their entries.
    en, en_dup = load_mdx_src('Oxford Essential Dictionary.mdx.txt')
    cn, cn_dup = load_mdx_src('total 000.txt')

    # Index entries by headword with ", " normalised to ",".
    cn_map = {}
    for item in cn:
        cn_map[item['head'].replace(', ', ',')] = item

    en_map = {}
    for item in en:
        raw_head = item['head']
        key = raw_head.replace(', ', ',')
        if raw_head not in en_dup:
            en_map[key] = item
            continue
        # Repeated headwords are told apart by a superscript index (1-based).
        for idx, dup in enumerate(en_dup[raw_head], 1):
            en_map[key + '⁰¹²³⁴⁵⁶⁷⁸⁹'[idx]] = dup

    en_set = set(en_map)
    cn_set = set(cn_map)
    matches = en_set & cn_set
    en_nomatch = en_set - cn_set
    cn_nomatch = cn_set - en_set
    print(sorted(en_nomatch))
    print(sorted(cn_nomatch))

    # For every headword present in both files, insert the text that only the
    # second file has into the definition from the first one, keeping its tags.
    for word in matches:
        original = en_map[word]['definition']
        chinese = cn_map[word]['definition']
        print(f'>{word}:')
        et, ops = get_opcodes_and_print(original, chinese)
        print('去格式原文：', et.clean)
        # The entry dicts are shared with `en`, so this updates what is written.
        en_map[word]['definition'] = process_insert(et, ops)

    write_mdx_source(en, 'Oxford Essential Dictionary.cn.mdx.txt')

    
    