from ripemd128 import ripemd128 # type: ignore
from struct import unpack
from flask import Flask, Response, send_file
from pathlib import Path
import xml.etree.ElementTree as ET
import zlib, re, time

class notMdxFileError(Exception):
    """Raised when the file being read does not look like a valid MDX/MDD file."""

class checkSumError(Exception):
    """Raised when a checksum stored in the file does not match the data read."""

class unSupportError(Exception):
    """Raised for file features this reader does not handle (e.g. LZO blocks)."""


# Source: https://github.com/wamich/js-mdx-server
# Script appended to every rendered entry page. Once the DOM is loaded it
#   * waits 500 ms, then hooks clicks on <a href="sound://..."> links so the
#     target is played through an <audio> element instead of being navigated to;
#   * on double-click, posts the selected word (if it matches ^\w+$) to the
#     parent window via postMessage.
# The text below is sent to the browser verbatim, so it is runtime data, not
# Python comments.
injectionJs = r"""<script>
  // 参考 [ninja33/mdx-server](https://github.com/ninja33/mdx-server)

  document.addEventListener("DOMContentLoaded", () => {
    const audio_type = {
      mp3: "audio/mpeg",
      mp4: "audio/mp4",
      wav: "audio/wav",
      spx: "audio/ogg",
      ogg: "audio/ogg",
    };

    function audio_content_type(ext) {
      return audio_type[ext] || "audio/mpeg";
    }

    const getAudioEl = (() => {
      let audioEl;
      return () => {
        if (audioEl) return audioEl;

        audioEl = document.querySelector("audio");
        if (!audioEl) audioEl = document.createElement("audio");
        return audioEl;
      };
    })();

    // 修复sound链接发音问题
    function fixSound() {
      const soundElements = document.querySelectorAll('a[href^="sound://"]');
      soundElements.forEach((el) => {
        el.addEventListener("click", (e) => {
          e.preventDefault();
          e.stopPropagation();

          const href = el.getAttribute("href");
          const src = href.substring("sound:/".length);

          const audio = getAudioEl();
          audio.setAttribute("src", src);

          const type = audio_content_type(href.slice(-3));
          audio.setAttribute("type", type);

          try {
            audio.play();
          } catch (err) {
            console.error(err);
          }
        });
      });
    }

    const parentWin = window.parent || window.top;

    // 双击跳转查词
    document.addEventListener("dblclick", () => {
      const select = document.getSelection().toString().trim();
      if (/^\w+$/g.test(select)) {
        parentWin.postMessage({ select }, { targetOrigin: "*" });
      }
    });

    // TODO: 等待词典其他的js文件执行结束? 是否存在优化空间
    setTimeout(() => {
      fixSound();
    }, 500);
  });
</script>"""

def decrypt(buf: bytearray, key: bytes):
    """Return the decrypted copy of ``buf`` as a new ``bytearray``.

    Each input byte has its two nibbles swapped and is then XOR-ed with the
    low 8 bits of its position, the key byte at ``position % len(key)`` and
    the previous *input* byte (``0x36`` for the first byte).  ``buf`` itself
    is not modified.
    """
    keySize = len(key)
    plain = bytearray(len(buf))
    lastCipher = 0x36  # seed standing in for the "previous byte" of byte 0
    for idx, cipher in enumerate(buf):
        swapped = ((cipher << 4) | (cipher >> 4)) & 0xFF
        plain[idx] = swapped ^ lastCipher ^ key[idx % keySize] ^ (idx & 0xFF)
        lastCipher = cipher
    return plain

def unpackBlock(buf, encrypted=False, decompSize=None):
    """Decode one block and verify its checksum.

    ``buf`` is read as: 4 bytes block type, 4 bytes big-endian Adler-32 of
    the decoded payload, then the payload itself.

    Args:
        buf: Raw block bytes as read from the file.
        encrypted: When true, the payload is first passed through
            ``decrypt`` with a key derived from the checksum bytes.
        decompSize: Expected decompressed size, used as the initial output
            buffer size for zlib; 65536 is used when it is ``None``.

    Returns:
        The decoded payload.

    Raises:
        unSupportError: The block is LZO-compressed or has an unknown type.
        checkSumError: The block is shorter than its 8-byte header, or the
            decoded payload does not match the stored Adler-32.
    """
    if len(buf) < 8:
        # Without this guard a short read surfaces later as struct.error.
        raise checkSumError("block数据不完整")
    blockType, blockAdler32 = buf[0:4], buf[4:8]
    data = buf[8:]
    if encrypted:
        data = decrypt(data, ripemd128(blockAdler32 + b"\x95\x36\x00\x00"))
    if blockType == b"\x00\x00\x00\x00":
        # Stored without compression.
        origData = data
    elif blockType == b"\x01\x00\x00\x00":
        # LZO compression is not implemented.
        raise unSupportError("不支持lzo压缩")
    elif blockType == b"\x02\x00\x00\x00":
        origData = zlib.decompress(data, bufsize=decompSize if decompSize is not None else 65536)
    else:
        # Previously this only printed and then crashed on an unbound name.
        print("blockType:", blockType.hex())
        raise unSupportError(f"未知的block类型: {blockType.hex()}")
    if zlib.adler32(origData) != unpack(">I", blockAdler32)[0]:
        print("block校验出错")
        raise checkSumError("校验错误")
    return origData

def _readKeyText(index, pos, termLen, term):
    """Return ``(rawKey, nextPos)`` for the terminator-delimited key at ``pos``.

    The scan advances ``termLen`` bytes at a time so a two-byte terminator is
    only recognised on a code-unit boundary.  If the terminator is missing,
    the scan stops at the end of ``index`` instead of looping forever.
    """
    start, end = pos, len(index)
    while pos < end and index[pos:pos+termLen] != term:
        pos = pos + termLen
    return index[start:pos], pos + termLen


def _parseIndexIndex(indexIndex, encoding, termLen):
    """Parse the decoded index-of-index block into one dict per index block."""
    indexIndexList, pos = [], 0
    while pos < len(indexIndex):
        # Number of entries in this index block.
        entryNum = unpack(">Q", indexIndex[pos:pos+8])[0]
        pos = pos + 8
        # First and last word of the block, each prefixed with its length.
        boundaryWords = []
        for _ in range(2):
            wordLen = unpack(">H", indexIndex[pos:pos+2])[0]
            pos = pos + 2
            boundaryWords.append(indexIndex[pos:pos+wordLen*termLen].decode(encoding))
            # The stored length excludes the terminator that follows the text.
            pos = pos + termLen * (wordLen+1)
        # Compressed size, then decompressed size.
        indexBlockCompSize, keyBlockSizeDecomp = unpack(">QQ", indexIndex[pos:pos+16])
        pos = pos + 16
        indexIndexList.append({"entryNum": entryNum, "entryStart": boundaryWords[0], "entryEnd": boundaryWords[1], "blockSize": indexBlockCompSize, "decompSize": keyBlockSizeDecomp})
    return indexIndexList


def _readmdxFile(f, mdxPath, queryWord, history, printLog, indexIndexList, onlyGetIndexIndexList, getAllEntryName):
    """Do the work of ``readmdx`` on the already opened binary file ``f``."""
    # Drop a URL-style "#fragment" from the query.
    queryWord = queryWord.split("#", 1)[0]
    timeList = []  # perf_counter samples; only filled when printLog is set
    if printLog:
        print("query:", queryWord)
        timeList.append(time.perf_counter())

    # ---- file header: length, XML text, checksum ----
    headerLen = unpack(">I", f.read(4))[0]
    if headerLen > 1e7:
        raise notMdxFileError("太大的头部，可能不是mdx文件")
    headerXmlHex = f.read(headerLen)
    headerHexChksum = f.read(4)
    # Unlike the block checksums, this one is compared as little-endian.
    if unpack("<I", headerHexChksum)[0] != zlib.adler32(headerXmlHex):
        raise checkSumError("不是mdx文件或已损坏")
    # The last two bytes are dropped before decoding (presumably a UTF-16 NUL).
    headerXml = ET.fromstring(headerXmlHex[:-2].decode("utf-16"))
    mdxVer = headerXml.get("RequiredEngineVersion")
    mdxEcp = headerXml.get("Encrypted")
    encoding = headerXml.get("Encoding")
    # The key terminator width depends on the text encoding.
    if encoding == "UTF-8" or encoding == "GBK":
        termLen, term = 1, b"\x00"
    elif encoding == "UTF-16" or encoding == "":
        termLen, term = 2, b"\x00\x00"
    else:
        raise unSupportError(f"不支持的编码: {encoding}")
    # An empty Encoding attribute marks an mdd (resource) file.
    isMdd = encoding == ""
    if isMdd:
        encoding, queryWord = "utf-16", queryWord.replace("/", "\\")
        # Resource keys start with a backslash; add it when missing.
        # startswith() also copes with an empty query.
        if not queryWord.startswith("\\"):
            queryWord = "\\" + queryWord.lower()
    if printLog:
        print("header len:", headerLen)
        print("header hex alder32:", headerHexChksum.hex().upper())
        print("version:", mdxVer)
        print("encrypted:", mdxEcp)
        print("encoding:", encoding)
        print("----------")
    if mdxEcp == "3" or mdxEcp == "1":
        raise unSupportError(f"此加密不支持: {mdxEcp}")
    fastEcp = mdxEcp == "2"
    keyCaseSensitive = headerXml.get("KeyCaseSensitive") != "No"
    if printLog:
        timeList.append(time.perf_counter())
        print("readHeaderTime", timeList[-1]-timeList[-2])
    if mdxVer is None or not mdxVer.startswith("2"):
        raise unSupportError(f"不支持的mdx版本: {mdxVer}")

    # ---- index section header ----
    indexSectionHeader = f.read(40)
    indexHeaderChksum = f.read(4)
    if zlib.adler32(indexSectionHeader) != unpack(">I", indexHeaderChksum)[0]:
        print("indexSectionHeader 校验不通过")
        raise checkSumError("index header 校验不通过")
    indexBlockNum, totalEntryNum, indexIndexOrigSize, indexIndexBlockSize, indexBlocksTotalSize = unpack(">5Q", indexSectionHeader)
    if printLog:
        print("entries:", totalEntryNum)
        print("index block num:", indexBlockNum)
        print("indexIndexOrigSize:", indexIndexOrigSize)
        print("indexIndexBlockSize:", indexIndexBlockSize)
        print("indexBlocksTotalSize:", indexBlocksTotalSize)
        print("indexHeaderChksum:", indexHeaderChksum.hex().upper())
    indexIndexBlockStart = f.tell()
    # Parse the index-of-index unless the caller supplied a cached copy.
    if indexIndexList is None:
        indexIndex = unpackBlock(f.read(indexIndexBlockSize), fastEcp)
        indexIndexList = _parseIndexIndex(indexIndex, encoding, termLen)
        indexBlocksTotalSize = sum(item["blockSize"] for item in indexIndexList)
    if printLog:
        print("indexIndexList:", indexIndexList[:4], "... ", len(indexIndexList), "in total")
        timeList.append(time.perf_counter())
        print("readIndexIndexTime", timeList[-1]-timeList[-2])
    if onlyGetIndexIndexList:
        return indexIndexList
    indexBlockStart = indexIndexBlockStart + indexIndexBlockSize
    recordSectionStart = indexBlockStart + indexBlocksTotalSize

    if getAllEntryName:
        # With a cached indexIndexList the index-of-index block was never read,
        # so the file position must be set explicitly.
        f.seek(indexBlockStart)
        allEntryList = []
        for info in indexIndexList:
            index, pos = unpackBlock(f.read(info["blockSize"]), decompSize=info["decompSize"]), 0
            while pos < len(index):
                # Skip the 8-byte record offset in front of every key.
                rawEntry, pos = _readKeyText(index, pos + 8, termLen, term)
                allEntryList.append(rawEntry.decode(encoding, "replace"))
        return allEntryList

    # ---- record section header and record index ----
    f.seek(recordSectionStart)
    _, totalEntryNum2, recordIndexLen, _ = unpack(">4Q", f.read(32))
    recordIndex = f.read(recordIndexLen)
    recordBlockStart = f.tell()
    recordIndexList = []
    for pos in range(0, len(recordIndex), 16):
        blockSize, decompSize = unpack(">QQ", recordIndex[pos:pos+16])
        recordIndexList.append({"blockSize": blockSize, "decompSize": decompSize})
    if totalEntryNum != totalEntryNum2:
        raise notMdxFileError("文件中记录的两个词条数不相等")
    if printLog:
        print("recordIndexListLen", len(recordIndexList))
        timeList.append(time.perf_counter())
        print("readRecordIndexTime", timeList[-1]-timeList[-2])
        print("ok, read indexIndex")

    # ---- find the index block(s) whose range covers the key, then the entries ----
    searchKey = queryWord.lower() if isMdd else re.sub(r"\W", "", queryWord.lower())
    # For case-insensitive dictionaries both sides of the comparison are lowered.
    matchWord = queryWord if keyCaseSensitive else queryWord.lower()
    if printLog:
        print("searchKey", searchKey ,"keyCaseSensitive", keyCaseSensitive)
    targetIndexBlockStart, targetRecordStarts, targetRecordEnds = indexBlockStart, [], []
    lastBlock = len(indexIndexList) - 1
    for i, info in enumerate(indexIndexList):
        if info["entryStart"] <= searchKey <= info["entryEnd"]:
            f.seek(targetIndexBlockStart)
            index, pos = unpackBlock(f.read(info["blockSize"])), 0
            while pos < len(index):
                decompedOffset = unpack(">Q", index[pos:pos+8])[0]
                rawEntry, pos = _readKeyText(index, pos + 8, termLen, term)
                entry = rawEntry.decode(encoding, "replace")
                if matchWord != (entry if keyCaseSensitive else entry.lower()):
                    continue
                history.append(entry)
                targetRecordStarts.append(decompedOffset)
                # A record ends where the next entry's record starts.
                if pos < len(index):
                    targetRecordEnds.append(unpack(">Q", index[pos:pos+8])[0])
                elif pos == len(index) and i != lastBlock:
                    # Last entry of this block: peek at the next index block.
                    f.seek(targetIndexBlockStart + info["blockSize"])
                    nextIndex = unpackBlock(f.read(indexIndexList[i+1]["blockSize"]))
                    targetRecordEnds.append(unpack(">Q", nextIndex[:8])[0])
                else:
                    # Very last entry: it ends at the total decompressed size.
                    targetRecordEnds.append(sum(item["decompSize"] for item in recordIndexList))
            if printLog:
                print("got indexBlock:", i)
        targetIndexBlockStart = targetIndexBlockStart + info["blockSize"]
    if printLog:
        timeList.append(time.perf_counter())
        print("searchKeyTime", timeList[-1]-timeList[-2])
        print("targetRecordStarts", targetRecordStarts)

    # ---- pull the matched records out of the record blocks ----
    allContent = bytearray()
    for recordStart, recordEnd in zip(targetRecordStarts, targetRecordEnds):
        # Walk the record index to the block containing recordStart;
        # ``added`` is the decompressed size of all blocks before it.
        targetFilePos, j, added = recordBlockStart, 0, 0
        while recordStart >= added + recordIndexList[j]["decompSize"]:
            added = added + recordIndexList[j]["decompSize"]
            targetFilePos = targetFilePos + recordIndexList[j]["blockSize"]
            j = j + 1
        f.seek(targetFilePos)
        decompedRecord = unpackBlock(f.read(recordIndexList[j]["blockSize"]))
        # Keep appending following blocks until the whole record is covered.
        while len(decompedRecord) < recordEnd - added and j + 1 < len(recordIndexList):
            j = j + 1
            decompedRecord += unpackBlock(f.read(recordIndexList[j]["blockSize"]))
        content = decompedRecord[recordStart-added:recordEnd-added]
        if not isMdd:
            # Text records carry a trailing terminator; mdd records do not.
            content = content[:-termLen]
            text = content.decode(encoding, "replace")
            if text.startswith("@@@LINK="):
                # The target may be followed by a line break; strip it.
                linkWord = text[len("@@@LINK="):].strip()
                # Skip self-links and targets already shown.
                if linkWord == queryWord or linkWord in history:
                    continue
                content = readmdx(mdxPath, linkWord, history, printLog=printLog, indexIndexList=indexIndexList)
        if isMdd:
            # A resource lookup returns a single record, not a concatenation.
            allContent = content
        else:
            allContent.extend(content)
    if printLog:
        timeList.append(time.perf_counter())
        print("getContentTime", timeList[-1]-timeList[-2])
        print("totalTime", timeList[-1]-timeList[0])
    return allContent


def readmdx(mdxPath: str, queryWord: str, history=None, printLog=False, indexIndexList=None, onlyGetIndexIndexList=False, getAllEntryName=False):
    """Look up ``queryWord`` in the version-2 mdx/mdd file at ``mdxPath``.

    Args:
        mdxPath: Path of the dictionary (mdx) or resource (mdd) file.
        queryWord: Word or resource path to look up.  Anything after the
            first ``#`` is ignored.  For mdd files ``/`` is turned into
            ``\\`` and a leading ``\\`` is added when missing.
        history: Entry names already matched.  Matched entries are appended
            to it in place, and ``@@@LINK=`` targets found in it are skipped
            so link cycles terminate.  A new list is used when ``None``.
        printLog: Print parsing details and timings.
        indexIndexList: Result of an earlier ``onlyGetIndexIndexList`` call
            for the same file; skips re-parsing the index-of-index block.
        onlyGetIndexIndexList: Return the parsed index-of-index and stop.
        getAllEntryName: Return every entry name in the file and stop.

    Returns:
        * ``onlyGetIndexIndexList``: list of dicts with keys ``entryNum``,
          ``entryStart``, ``entryEnd``, ``blockSize`` and ``decompSize``.
        * ``getAllEntryName``: list of entry names (``str``).
        * otherwise: for mdx, the concatenated bytes of all matching
          records with links followed; for mdd, the bytes of the matching
          record.  Empty when nothing matches.

    Raises:
        notMdxFileError: Header too large, or the two entry counts disagree.
        checkSumError: A stored checksum does not match.
        unSupportError: Unsupported encoding, encryption, compression or
            format version.
    """
    if history is None:
        history = []
    # The context manager closes the handle on every exit path.
    with open(mdxPath, "rb") as f:
        return _readmdxFile(f, mdxPath, queryWord, history, printLog, indexIndexList, onlyGetIndexIndexList, getAllEntryName)

# Opening part of every rendered entry page; the entry HTML, the injected
# script and the closing tags are appended by the request handler.
htmlText = "\n".join((
    "<!DOCTYPE html>",
    '<html lang="zh-CN">',
    "<head>",
    '  <meta charset="UTF-8">',
    '  <meta name="viewport" content="width=device-width, initial-scale=1.0">',
    "</head>",
    "<body>",
))

# Directory holding the dictionary files (and any loose resource files).
dictDir = Path("G:/A/lm5").resolve()

mdxFiles = [*dictDir.glob("*.mdx")]
mddFiles = [*dictDir.glob("*.mdd")]

# Parse every file's index-of-index once at import time so that lookups can
# pass it back in instead of re-parsing it on each request.
mdxIndexIndexLists = {}
for f in mdxFiles + mddFiles:
    mdxIndexIndexLists[f] = readmdx(f, " ", onlyGetIndexIndexList=True)

app = Flask(__name__)
@app.route("/<path:word>")
def getContent(word):
    """Serve a loose file, a dictionary entry or an mdd resource for ``word``.

    Resolution order:
      1. a regular file of that name under ``dictDir``;
      2. if ``word`` contains no ``/``: the first mdx file with a matching
         entry, rendered as an HTML page;
      3. the first mdd file with a matching resource.

    Returns a 403 for blocked names and for paths that resolve outside
    ``dictDir``, a 404 when nothing matches, and a 500 carrying the error
    text when a lookup raises.
    """
    import mimetypes  # local import keeps this block self-contained

    # Names that must never be served, e.g. "LM5Switch.js". Currently none.
    blocked = []
    if word in blocked:
        return "blocked file", 403
    path = (dictDir / word).resolve()
    if not path.is_relative_to(dictDir):
        # Answer with a client error rather than an uncaught exception (500).
        return "非法路径访问", 403
    if path.is_file():
        return send_file(path)
    try:
        # A "/" means a resource path, so the mdx dictionaries are skipped.
        if "/" not in word:
            for mdx in mdxFiles:
                print("Query: ", word)
                content = readmdx(mdx, word, printLog=False, indexIndexList=mdxIndexIndexLists.get(mdx))
                if content:
                    # NOTE(review): the entry is always decoded as UTF-8 here,
                    # whatever encoding the mdx header declares.
                    return htmlText + content.decode("utf-8", "replace").replace("entry://", "/").replace("file://", "/") + injectionJs + "</body></html>"
        for mdd in mddFiles:
            content = readmdx(mdd, word, indexIndexList=mdxIndexIndexLists.get(mdd))
            if content:
                # Guess the type from the name so that non-audio resources are
                # not labelled as audio; fall back to the previous default
                # when the extension is unknown.
                mimetype = mimetypes.guess_type(word)[0] or "audio/mpeg"
                return Response(bytes(content), mimetype=mimetype)
        return "Entry or file not found", 404
    except Exception as e:
        # Request boundary: report the failure instead of crashing the worker.
        print(e)
        return str(e), 500


if __name__ == "__main__":
    # The server binds to 0.0.0.0; the printed URL is the local way to reach it.
    print("Service running at http://127.0.0.1:5003/word")
    app.run(host="0.0.0.0", port=5003)
    # For development: app.run(host="0.0.0.0", port=5003, debug=True)