import os
import sys
import numpy as np
import matplotlib.pyplot as pyplot
from matplotlib.image import imread
import imageio


def gra2png(filename, output=None, show_image=True):
    """Decode a 16-bit colour .gra file and optionally save and/or display it.

    The file is read as a 4-byte header followed by big-endian 16-bit pixels
    in row-major order. Only the fourth header byte is used (the image
    height); the width is derived from the amount of pixel data. Each pixel
    holds red in bits 15-11, green in bits 10-6 and blue in bits 4-0 (bit 5
    is ignored). Every 5-bit channel is multiplied by 8 to spread it over
    the 8-bit range.

    Args:
        filename: Path of the .gra file to read.
        output: Destination path handed to ``imageio.imwrite``. Nothing is
            written when ``None``. Save failures are reported on stdout and
            do not raise.
        show_image: When true, display the decoded image with matplotlib.
            When false no plotting is done at all, so batch conversions do
            not pile images onto the global pyplot figure.

    Returns:
        The decoded ``(height, width, 3)`` ``uint8`` array, or ``None`` when
        the file is shorter than the header or declares a height of 0.
    """
    with open(filename, "rb") as gra:
        raw = gra.read()

    # Same guards (and messages) as gra2png_mono: without them a truncated
    # file or a zero height byte ends in a ZeroDivisionError below.
    if len(raw) < 4:
        print(f"Error: File '{filename}' is too small.")
        return None

    height = raw[3]
    if height == 0:
        print(f"Error: File '{filename}' has a height of 0.")
        return None

    # Two bytes per pixel; trailing bytes that do not fill a column are ignored.
    width = (len(raw) - 4) // (2 * height)
    print(f"Size: {height}x{width}")

    words = np.frombuffer(raw, dtype=">u2", count=height * width, offset=4)
    words = words.astype(np.uint16).reshape(height, width)
    # Shifts select the red, green and blue 5-bit fields respectively.
    channels = [(words >> shift) & 0x1F for shift in (11, 6, 0)]
    rgb = (np.stack(channels, axis=-1) * 8).astype(np.uint8)

    print(f"Image data shape: {rgb.shape}, dtype: {rgb.dtype}")
    if output is not None:
        print(f"Attempting to save image to: {output}")
        try:
            imageio.imwrite(output, rgb)
            print("Image saved successfully.")
        except Exception as e:
            # Best effort: a failed save must not prevent the preview below.
            print(f"Error saving image: {e}")
    if show_image:
        pyplot.imshow(rgb, interpolation="nearest")
        pyplot.set_cmap("hot")
        pyplot.axis("off")
        pyplot.show()
    return rgb


def gra2png_mono(filename, output=None, show_image=True):
    """Decode a 1-bit-per-pixel .gra file and optionally save and/or display it.

    The file is read as a 4-byte header followed by packed pixel rows, most
    significant bit first. The second header byte is treated as a width hint
    and the fourth as the height. Two layouts are distinguished, for
    reporting purposes only:

    * "Type A": the width hint equals the width implied by the file size and
      each row occupies an even number of bytes (16-bit words).
    * "Type B": anything else; the width is derived from the file size.

    A big-endian 16-bit word read MSB-first yields the same pixel order as
    its two bytes read MSB-first, so both layouts share one decoder.

    Args:
        filename: Path of the .gra file to read.
        output: Destination path handed to ``imageio.imwrite`` (pixels are
            scaled to 0/255). Nothing is written when ``None``. Save failures
            are reported on stdout and do not raise.
        show_image: When true, display the decoded image with matplotlib.
            When false no plotting is done at all.

    Returns:
        The decoded ``(height, width)`` ``uint8`` array of 0/1 values, or
        ``None`` when the file is too small, declares a height of 0, or its
        pixel data cannot be split into whole-byte rows.
    """
    with open(filename, "rb") as gra:
        raw = gra.read()

    filesize = len(raw)
    if filesize < 4:
        print(f"Error: File '{filename}' is too small.")
        return None

    header_width = raw[1]
    height = raw[3]
    if height == 0:
        print(f"Error: File '{filename}' has a height of 0.")
        return None

    data_size = filesize - 4
    calculated_width = data_size * 8 // height
    bytes_per_row, leftover = divmod(data_size, height)
    rows_are_whole_bytes = leftover == 0

    # Heuristic: Type A has width in header and even bytes per row.
    is_type_a = (
        header_width == calculated_width
        and rows_are_whole_bytes
        and bytes_per_row % 2 == 0
    )

    if is_type_a:
        print(f"-> Detected 'Type A' monochrome format (e.g., ICON24.GRA).")
        width = header_width
    else:
        print(f"-> Detected 'Type B' monochrome format (e.g., LABEL.GRA).")
        width = calculated_width

    print(f"Size: {height}x{width}")

    # Rows must start on byte boundaries. Decoding a file whose data does not
    # divide evenly into rows would shift every row after the first, so it is
    # rejected instead of producing a scrambled image.
    if not rows_are_whole_bytes:
        print("  -> Error: Width does not match data size for Type B. Aborting.")
        return None

    packed = np.frombuffer(
        raw, dtype=np.uint8, count=height * bytes_per_row, offset=4
    )
    # unpackbits emits the most significant bit of each byte first.
    bitmap = np.unpackbits(packed).reshape(height, bytes_per_row * 8)

    print(f"Image data shape: {bitmap.shape}, dtype: {bitmap.dtype}")
    if output is not None:
        print(f"Attempting to save image to: {output}")
        try:
            imageio.imwrite(output, bitmap * 255)
            print("Image saved successfully.")
        except Exception as e:
            # Best effort: a failed save must not prevent the preview below.
            print(f"Error saving image: {e}")
    if show_image:
        pyplot.imshow(bitmap, interpolation="nearest", cmap="gray")
        pyplot.axis("off")
        pyplot.show()
    return bitmap


def png2gra(filename, output=None):
    """Convert an image into a 480x164 16-bit .gra file, then preview it.

    The top-left 480x164 region of the image is quantised to 5 bits per
    channel and written as big-endian 16-bit pixels (red in bits 15-11,
    green in bits 10-6, blue in bits 4-0) after a fixed 4-byte header. Any
    alpha channel is ignored. The written file is then decoded and shown via
    ``gra2png``.

    Args:
        filename: Path of the source image, loaded with matplotlib's
            ``imread``.
        output: Destination .gra path. When omitted or empty, the source
            path with its extension replaced by ``.gra`` is used.

    Raises:
        ValueError: If the image is not a colour image of at least 480x164
            pixels. Raised before the output file is created, so no
            truncated .gra file is left behind.
    """
    gra_height, gra_width = 164, 480

    png = imread(filename)
    if (
        png.ndim != 3
        or png.shape[0] < gra_height
        or png.shape[1] < gra_width
        or png.shape[2] < 3
    ):
        raise ValueError(
            f"'{filename}' must be a colour image of at least "
            f"{gra_width}x{gra_height} pixels, got array shape {png.shape}"
        )

    if not output:
        # splitext only strips the final extension, so dots elsewhere in the
        # path (e.g. "./images/pic.png") are left alone.
        output = os.path.splitext(filename)[0] + ".gra"

    pixels = png[:gra_height, :gra_width, :3]
    if np.issubdtype(pixels.dtype, np.integer):
        # Integer images are assumed to already hold 0-255 values.
        levels = pixels.astype(np.uint16)
    else:
        # Float images hold 0.0-1.0. Round to nearest (not truncate) so float
        # rounding error cannot pull a value down across a quantisation step.
        levels = np.clip(np.rint(pixels * 255), 0, 255).astype(np.uint16)
    levels >>= 3  # 8-bit -> 5-bit per channel

    packed = (levels[..., 0] << 11) | (levels[..., 1] << 6) | levels[..., 2]

    with open(output, "wb") as gra:
        gra.write(b"\x21\xe0\x00\xa4")
        gra.write(packed.astype(">u2").tobytes())
    gra2png(output)


def batch_process_gra(input_dir, output_dir):
    """Convert every .gra file under *input_dir* into PNGs under *output_dir*.

    The directory tree is walked recursively and mirrored below
    *output_dir*. Each file whose name ends in ``.gra`` (case-insensitive) is
    decoded twice, once with ``gra2png`` into ``<name>.color.png`` and once
    with ``gra2png_mono`` into ``<name>.mono.png``, without displaying
    anything. A failure in either conversion is printed and does not stop
    the batch.

    Args:
        input_dir: Root directory to search for .gra files.
        output_dir: Root directory that receives the converted images;
            sub-directories are created as needed.
    """
    print("Starting batch processing...")
    print(f"Input directory: {input_dir}")
    print(f"Output directory: {output_dir}")

    for folder, _subdirs, names in os.walk(input_dir):
        for name in names:
            if not name.lower().endswith(".gra"):
                continue

            source = os.path.join(folder, name)
            # Mirror the source folder's position relative to the input root.
            target_dir = os.path.join(output_dir, os.path.relpath(folder, input_dir))
            os.makedirs(target_dir, exist_ok=True)
            stem, _extension = os.path.splitext(name)

            # Colour pass.
            color_target = os.path.join(target_dir, f"{stem}.color.png")
            print(f"Processing (color): {source} -> {color_target}")
            try:
                gra2png(source, color_target, show_image=False)
            except Exception as err:
                print(f"  -> Error processing as color: {err}")

            # Monochrome pass.
            mono_target = os.path.join(target_dir, f"{stem}.mono.png")
            print(f"Processing (mono): {source} -> {mono_target}")
            try:
                gra2png_mono(source, mono_target, show_image=False)
            except Exception as err:
                print(f"  -> Error processing as mono: {err}")

    print("Batch processing complete.")


if __name__ == "__main__":
    # Command-line entry point.
    #   batch <input_directory> <output_directory>   (exactly two arguments)
    #   gra2png | gra2png_mono | png2gra <input_file> [output_file]
    # Any usage error or unknown command exits with status 1.
    argv = sys.argv
    if len(argv) < 2:
        print("Usage: python gra_reader.py <command> ...")
        print("Commands: batch, gra2png, gra2png_mono, png2gra")
        sys.exit(1)

    command = argv[1]
    # Converters that share the "<input_file> [output_file]" calling shape.
    single_file_commands = {
        "gra2png": gra2png,
        "gra2png_mono": gra2png_mono,
        "png2gra": png2gra,
    }

    if command == "batch":
        if len(argv) != 4:
            print(
                "Usage: python gra_reader.py batch <input_directory> <output_directory>"
            )
            sys.exit(1)
        batch_process_gra(argv[2], argv[3])
    elif command in single_file_commands:
        if len(argv) < 3:
            print(f"Usage: python gra_reader.py {command} <input_file> [output_file]")
            sys.exit(1)
        # The output path is optional; arguments beyond it are ignored.
        output_file = argv[3] if len(argv) > 3 else None
        single_file_commands[command](argv[2], output_file)
    else:
        print(f"Unknown command: {command}")
        sys.exit(1)
