# -*- coding: utf-8 -*-
import os
import sys
from google import genai
from google.genai import types  # 用来构造图片 Part
import re # Import re module for regular expressions
from PIL import Image, ImageDraw # 导入 Image 和 ImageDraw
import io
import json # Import json module for parsing JSON responses

from dataclasses import dataclass

@dataclass
class BoundingBox:
    """A single text region: its box coordinates plus the text found inside it.

    Instances are built from the JSON objects in the model response, so the
    field names match the keys requested in the prompt.
    """

    # Box corners in the order [y_min, x_min, y_max, x_max].
    box_2d: list[int]
    # Text associated with the box.
    text_content: str

# Prefer UTF-8 on stdout so that printing Chinese text does not fail on
# consoles whose default encoding cannot represent it.
try:
    sys.stdout.reconfigure(encoding="utf-8")
except (AttributeError, ValueError, OSError):
    # Best effort only: stdout may have been replaced by an object without
    # reconfigure(), or by a stream that does not support reconfiguration.
    pass

# SECURITY: a key committed to source control should be treated as leaked.
# The GEMINI_API_KEY environment variable takes precedence; the literal is kept
# only as a backward-compatible fallback and should be removed and rotated.
API_KEY = os.environ.get("GEMINI_API_KEY") or "AIzaSyBuaw"
client = genai.Client(api_key=API_KEY)

# Multimodal helper that sends one image together with a text prompt to Gemini
class ImageChat:
    """Send an image file and a question to a Gemini model and collect the reply.

    Attributes:
        image_path: Path of the image that is attached to every question.
        model: Model name passed to the API.
        client: Client object used for the request.
        mime_type: MIME type declared for the image bytes.
        image_width / image_height: Pixel size of the image, filled in by
            ``ask`` (``None`` until ``ask`` has run).
    """

    # Size assumed when the image cannot be decoded. It equals the 0-1000
    # coordinate range used by normalize_to_pixel, so scaling is then a no-op.
    _FALLBACK_SIZE = 1000

    def __init__(self, image_path, client=None, model="gemini-2.5-flash", mime_type="image/png"):
        """Store the request settings; no file or network access happens here.

        Args:
            image_path: Path of the image file to send.
            client: Client to use. When ``None``, the module-level ``client``
                is used if one is defined.
            model: Model name.
            mime_type: MIME type of the image file.
        """
        self.image_path = image_path
        self.model = model
        # The parameter shadows the module-level ``client``, so the global has
        # to be looked up explicitly (``client or client`` never reached it).
        self.client = client if client is not None else globals().get("client")
        self.mime_type = mime_type
        self.image_width = None
        self.image_height = None

    def _update_image_size(self, image_bytes: bytes) -> None:
        """Record the pixel size of ``image_bytes`` on the instance.

        If the bytes cannot be decoded, a size known from an earlier call is
        kept; otherwise the fallback size is used.
        """
        try:
            with Image.open(io.BytesIO(image_bytes)) as image:
                self.image_width = image.width
                self.image_height = image.height
            print(f"DEBUG: 图片原始尺寸: 宽度={self.image_width}, 高度={self.image_height}")
        except Exception as e:
            # Deliberately best effort: a size problem must not block the request.
            print(f"DEBUG: 获取图片尺寸时发生错误: {e}")
            if self.image_width is None or self.image_height is None:
                self.image_width = self._FALLBACK_SIZE
                self.image_height = self._FALLBACK_SIZE

    def ask(self, question: str, response_schema=None) -> str:
        """Ask the model ``question`` with the image attached as context.

        The image is read as bytes and wrapped with ``Part.from_bytes``. A
        sentence stating the image size is prepended to the question.

        Args:
            question: Prompt text.
            response_schema: Optional schema. When given, the request asks for
                an ``application/json`` response that follows this schema.

        Returns:
            The concatenated text of all streamed response chunks.

        Raises:
            FileNotFoundError: If ``image_path`` is empty or is not a file.
        """
        if not self.image_path or not os.path.isfile(self.image_path):
            print(f"提示: 未指定有效的图片文件或文件不存在：{self.image_path}")
            raise FileNotFoundError(f"未找到图片文件: {self.image_path if self.image_path else '(未指定)'}")
        with open(self.image_path, "rb") as f:
            image_bytes = f.read()

        self._update_image_size(image_bytes)

        # Prepend the image size to the prompt.
        dimension_info = f"图片原始尺寸：宽度={self.image_width}, 高度={self.image_height}. "
        question = dimension_info + question

        image_part = types.Part.from_bytes(
            data=image_bytes,
            mime_type=self.mime_type,
        )

        generation_config = {
            "temperature": 0.5,
        }
        if response_schema:
            generation_config["response_mime_type"] = "application/json"
            generation_config["response_schema"] = response_schema

        stream = self.client.models.generate_content_stream(
            model=self.model,
            contents=[
                image_part,
                question,
            ],
            # Previously built but never sent, so the schema was ignored.
            config=generation_config,
        )

        pieces = []
        for chunk in stream:
            text = getattr(chunk, "text", None)
            if text:
                pieces.append(text)

        return "".join(pieces)


# Helper: convert a normalised coordinate to pixels (model output assumed to be in the 0-1000 range)
def normalize_to_pixel(coord: float, dimension: int) -> int:
    """Scale ``coord`` from the 0-1000 range onto ``dimension`` pixels.

    The result is truncated toward zero by ``int``; it is not rounded.
    """
    fraction = coord / 1000
    scaled = fraction * dimension
    return int(scaled)

# Helper: draw bounding boxes on an image
def draw_bounding_boxes_on_image(
    original_image_path: str,
    output_image_path: str,
    word_head_boxes: list[tuple],  # each element is (word_text, x1, y1, x2, y2)
    column_boxes: list[tuple],  # each element is (column_name, x1, y1, x2, y2)
):
    """Save a copy of the image with headword and column boxes outlined.

    Headword boxes are outlined in red and column boxes in blue, both 2 px
    wide. The label in each tuple is not drawn. Failures are reported with a
    printed warning instead of an exception, so drawing stays best effort.

    Args:
        original_image_path: Image to annotate.
        output_image_path: Where the annotated image is saved.
        word_head_boxes: ``(word_text, x1, y1, x2, y2)`` tuples in pixels.
        column_boxes: ``(column_name, x1, y1, x2, y2)`` tuples in pixels.
    """
    try:
        # Close the source file as soon as an independent RGB copy exists.
        with Image.open(original_image_path) as source:
            image = source.convert("RGB")
        draw = ImageDraw.Draw(image)

        for boxes, color in ((word_head_boxes, "red"), (column_boxes, "blue")):
            for _label, x1, y1, x2, y2 in boxes:
                # Order the corners so that one box with swapped coordinates
                # cannot make rectangle() raise and abort the whole drawing.
                left, right = sorted((x1, x2))
                top, bottom = sorted((y1, y2))
                draw.rectangle([left, top, right, bottom], outline=color, width=2)

        image.save(output_image_path)
        print(f"DEBUG: 带有边界框的图片已保存到: {output_image_path}")
    except Exception as e:
        print(f"警告: 绘制边界框图片时发生错误: {e}")

# Run OCR on a page image and write the recognised headwords to a text file
def write_ocr_text(image_path: str, output_path: str, mime_type: str = "image/png") -> str:
    """Ask the model for headword boxes, then save them as text and as an annotated image.

    The text file has a header line, one ``text<TAB>x<TAB>y`` line per
    headword (top-left corner in pixels) and a trailing page-information
    section. An image with the boxes drawn on it is saved next to the text
    file as ``<output base name>_boxes.png``.

    Args:
        image_path: Image of the dictionary page.
        output_path: Path of the text file to write.
        mime_type: MIME type declared for the image.

    Returns:
        The text written to ``output_path``, or ``""`` when the model response
        is not valid JSON or does not have the expected structure (nothing is
        written in that case).

    Raises:
        FileNotFoundError: Propagated from ``ImageChat.ask`` when
            ``image_path`` is not an existing file.
    """
    system_prompt = """
    你是一个OCR专家，请仔细分析提供的词典页面图片。
    请识别页面中的所有词头，并为每个词头提供其文本内容和标准化边界框坐标。标准化坐标的范围是 0 到 1000。
    请按照阅读顺序输出：从左到右、从上到下、从左栏到右栏。
    
    必须严格按照以下 JSON 格式输出，每个对象必须包含 "box_2d" 和 "text_content" 两个字段：
    [
      {"box_2d": [y_min, x_min, y_max, x_max], "text_content": "词头文本"},
      {"box_2d": [y_min, x_min, y_max, x_max], "text_content": "词头文本"},
      ...
    ]
    
    仅输出 JSON 格式，不包含任何额外文本、代码块标记或说明文字。
    """

    print("使用的 prompt 已硬编码到程序中。")

    chat = ImageChat(image_path, client=client, mime_type=mime_type)

    # Schema describing the expected JSON: a list of {box_2d, text_content} objects.
    response_schema = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "box_2d": {"type": "array", "items": {"type": "integer"}},
                "text_content": {"type": "string"}
            },
            "required": ["box_2d", "text_content"]
        }
    }

    model_response_text = chat.ask(system_prompt, response_schema=response_schema)

    # Print the raw model response for debugging.
    print("\n--- 原始模型响应 ---\n")
    print(model_response_text)
    print("\n----------------------\n")

    # Parse the JSON response.
    try:
        # Tolerate a Markdown code fence around the payload, with or without
        # the "json" tag and with any surrounding whitespace.
        fence = re.search(r"```(?:json)?\s*(.*)```", model_response_text, re.DOTALL)
        json_string = fence.group(1).strip() if fence else model_response_text

        model_response = json.loads(json_string)
        word_head_boxes_raw = [BoundingBox(**item) for item in model_response]
    except json.JSONDecodeError as e:
        print(f"错误: 无法解析模型响应为 JSON: {e}")
        print(f"原始响应文本: {model_response_text}")
        return ""
    except TypeError as e:
        # Valid JSON of the wrong shape: not a list of objects, or objects
        # with missing or unexpected keys.
        print(f"错误: 模型响应的 JSON 结构不符合预期: {e}")
        print(f"原始响应文本: {model_response_text}")
        return ""

    parsed_output = []
    word_head_pixel_boxes = []  # headword boxes as (text, x1, y1, x2, y2)
    column_pixel_boxes = []  # column boxes as (name, x1, y1, x2, y2)

    # Fall back to 1000 (the normalised range) so that coordinates pass
    # through unchanged if the size is unknown; this matches ImageChat.
    image_width = chat.image_width if chat.image_width is not None else 1000
    image_height = chat.image_height if chat.image_height is not None else 1000

    parsed_output.append("词头\tX\tY")
    for bbox in word_head_boxes_raw:
        text_content = bbox.text_content
        try:
            ymin_norm, xmin_norm, ymax_norm, xmax_norm = bbox.box_2d
            x1_pixel = normalize_to_pixel(xmin_norm, image_width)
            y1_pixel = normalize_to_pixel(ymin_norm, image_height)
            x2_pixel = normalize_to_pixel(xmax_norm, image_width)
            y2_pixel = normalize_to_pixel(ymax_norm, image_height)
        except (TypeError, ValueError) as e:
            # Skip a single malformed box instead of losing the whole page.
            print(f"警告: 跳过格式不正确的边界框 {bbox!r}: {e}")
            continue

        word_head_pixel_boxes.append((text_content, x1_pixel, y1_pixel, x2_pixel, y2_pixel))
        parsed_output.append(f"{text_content}\t{x1_pixel}\t{y1_pixel}")

    # Column information is not produced for now because the model only
    # returns headword boxes; the prompt and schema would have to ask for it.
    parsed_output.append("\n===页面信息===")
    parsed_output.append(f"图片大小: 宽度={image_width}, 高度={image_height}")

    final_output = "\n".join(parsed_output)

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(final_output)

    # Derive the annotated image path from the text output path.
    output_image_path_base, _ext = os.path.splitext(output_path)
    output_image_with_boxes_path = f"{output_image_path_base}_boxes.png"

    draw_bounding_boxes_on_image(
        image_path,
        output_image_with_boxes_path,
        word_head_pixel_boxes,
        column_pixel_boxes,
    )

    return final_output


if __name__ == "__main__":
    import mimetypes

    # Replace with the path of the image to process, e.g. a PNG or JPG file.
    input_image = r"Z:\1.jpg"  # or r"Z:\2_page01.jpg"
    output_txt = r"Z:\2_page01.txt"
    # Derive the MIME type from the file extension so that it always matches
    # the input (the .jpg above was previously sent as "image/png").
    guessed_mime, _encoding = mimetypes.guess_type(input_image)
    write_ocr_text(input_image, output_txt, mime_type=guessed_mime or "image/png")
