import os
import platform
import subprocess
from datetime import datetime

import cv2
import numpy as np

# ---------- CONFIG ----------
# Full-screen output surface, in pixels.
DISPLAY_W = 1920
DISPLAY_H = 1080

# Live camera preview on the left: a 3:2 region, centred vertically.
CAM_W = 1500
CAM_H = 1000
CAM_X = 0
CAM_Y = (DISPLAY_H - CAM_H) // 2

# Instruction panel on the right, drawn at exactly the artwork's size,
# placed immediately after the preview's width and centred vertically.
IMG_W = 420
IMG_H = 840
IMG_X = CAM_W
IMG_Y = (DISPLAY_H - IMG_H) // 2

# Right-panel artwork: IMG1 while idle, IMG2 while reviewing a shot.
IMG1_PATH = "img1.png"
IMG2_PATH = "img2.png"

# Where captures and print-ready pages are written (created at start-up).
CAPTURE_DIR = "captures"
os.makedirs(CAPTURE_DIR, exist_ok=True)

# Print canvas size: A5 (148 x 210 mm) at 300 dpi.
A5_W = 1748
A5_H = 2480

MIRROR_PREVIEW = False  # flip the camera image horizontally (mirror view)
COUNTDOWN = 3           # seconds counted down before a photo is taken

DEBUG = False  # set True to print raw key codes as you press keys

# ---------- HELPERS ----------
def center_crop_to_aspect(img, aw=3, ah=2):
    """Return a centred view of *img* trimmed to the aspect ratio aw:ah.

    Only one axis is trimmed: columns when the image is wider than the
    target ratio, otherwise rows (an exact match trims nothing). The result
    is a slice of *img*, not a copy.
    """
    height, width = img.shape[:2]
    ratio = aw / ah
    if width / height <= ratio:
        # Too tall (or already exact): drop rows equally from top and bottom.
        keep_h = int(width / ratio)
        top = (height - keep_h) // 2
        return img[top:top + keep_h, :]
    # Too wide: drop columns equally from left and right.
    keep_w = int(height * ratio)
    left = (width - keep_w) // 2
    return img[:, left:left + keep_w]

def load_fixed(path, w, h):
    """Load the image at *path* resized to exactly (w, h), or a black placeholder.

    Args:
        path: image file path; None or "" is accepted and treated as missing.
        w, h: output width and height in pixels.

    Returns:
        A ``(h, w, 3)`` uint8 array. Missing or unreadable files (and an
        empty/None *path*) yield an all-black array, so callers can always
        paste the result into the screen buffer.
    """
    # `not path` must come first: os.path.exists(None) raises TypeError rather
    # than returning False. The existence check also keeps missing files away
    # from cv2.imread, which would only return None for them anyway.
    if not path or not os.path.exists(path):
        return np.zeros((h, w, 3), np.uint8)
    img = cv2.imread(path)
    if img is None:
        return np.zeros((h, w, 3), np.uint8)
    return cv2.resize(img, (w, h))

def add_border_and_text(img):
    """Frame *img* and stamp the site watermark in its bottom-right corner.

    An 8 px white mat is added first, then a 4 px black outer frame. The
    text is drawn twice at the same spot: a thick white stroke underneath
    and a thinner black stroke on top, so it stays legible on any photo.
    Its size scales with the width of the original *img*.

    Returns:
        A new, larger image; *img* itself is not modified.
    """
    width = img.shape[1]
    framed = cv2.copyMakeBorder(img, 8, 8, 8, 8, cv2.BORDER_CONSTANT, value=(255, 255, 255))
    framed = cv2.copyMakeBorder(framed, 4, 4, 4, 4, cv2.BORDER_CONSTANT, value=(0, 0, 0))

    label = "256bits.com.au"
    face = cv2.FONT_HERSHEY_SIMPLEX
    scale = max(0.8, width / 1000.0) / 2
    halo = max(2, int(scale * 4))  # thickness of the white outline pass
    ink = max(1, int(scale * 2))   # thickness of the black foreground pass
    # Measure with the thicker stroke so the outline also clears the right margin.
    (label_w, _label_h), _baseline = cv2.getTextSize(label, face, scale, halo)
    origin = (framed.shape[1] - label_w - 20, framed.shape[0] - 20)
    for colour, thickness in (((255, 255, 255), halo), ((0, 0, 0), ink)):
        cv2.putText(framed, label, origin, face, scale, colour, thickness, cv2.LINE_AA)
    return framed

def save_capture(img):
    """Write *img* into CAPTURE_DIR as cap_<YYYYmmdd_HHMMSS>.png.

    Returns:
        The path written to. The cv2.imwrite result is not checked.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = os.path.join(CAPTURE_DIR, f"cap_{stamp}.png")
    cv2.imwrite(target, img)
    return target

def make_a5(imgpath):
    """Lay the image at *imgpath* out on a white A5 page and save it for printing.

    The image is rotated 90 degrees clockwise when that lets it fill more of
    the A5_W x A5_H canvas. It is then scaled to fit while keeping its aspect
    ratio, centred on the page, and written to CAPTURE_DIR as
    print_<YYYYmmdd_HHMMSS>.png.

    Args:
        imgpath: path of the source image (e.g. a saved capture).

    Returns:
        The path of the written page. Returns None if *imgpath* is empty or
        None, cannot be read, or the page could not be written.
    """
    if not imgpath:
        return None
    img = cv2.imread(imgpath)
    if img is None:
        return None
    ih, iw = img.shape[:2]
    scale_orig = min(A5_W / iw, A5_H / ih)
    scale_rot = min(A5_W / ih, A5_H / iw)
    if scale_rot > scale_orig:
        img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        scale = scale_rot
    else:
        scale = scale_orig
    new_w = max(1, int(img.shape[1] * scale))
    new_h = max(1, int(img.shape[0] * scale))
    # INTER_AREA is the right filter for shrinking. When enlarging (the usual
    # case for a camera frame on a 300 dpi page) it behaves like
    # nearest-neighbour and prints blocky, so use bicubic instead.
    interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC
    resized = cv2.resize(img, (new_w, new_h), interpolation=interp)
    canvas = np.full((A5_H, A5_W, 3), 255, np.uint8)
    x0 = (A5_W - new_w) // 2
    y0 = (A5_H - new_h) // 2
    canvas[y0:y0 + new_h, x0:x0 + new_w] = resized
    # Include the date, as save_capture does, so pages from different days
    # never overwrite each other.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out = os.path.join(CAPTURE_DIR, f"print_{stamp}.png")
    if not cv2.imwrite(out, canvas):
        return None  # caller skips printing on a falsy result
    return out

def print_file(path):
    """Send the file at *path* to the default printer.

    On Windows this uses the shell "print" verb via os.startfile. Everywhere
    else it runs ``lpr <path>``.

    Returns:
        True if the job was handed off. False on failure: a message is
        printed instead of raising, so the booth keeps running.
    """
    try:
        # Exact match: the old substring test `'win' in name` also matched
        # "darwin" (macOS), where os.startfile does not exist.
        if platform.system() == "Windows":
            os.startfile(path, "print")
            return True
        # Argument list, no shell: paths with quotes or spaces pass through intact.
        result = subprocess.run(["lpr", path], check=False)
    except OSError as err:  # e.g. lpr not installed, no print handler registered
        print(f"Printing {path!r} failed: {err}")
        return False
    if result.returncode != 0:
        print(f"lpr exited with status {result.returncode} for {path!r}")
        return False
    return True

def get_key(delay=1):
    """Wait up to *delay* ms for a key press in an OpenCV window and return its code.

    cv2.waitKeyEx is preferred because it reports full extended codes, which
    numpad keys with NumLock off need. OpenCV builds without it fall back to
    cv2.waitKey. The result is -1 when no key was pressed. It is returned as
    a plain signed int for easy printing and comparison.
    """
    # Only a missing waitKeyEx warrants the fallback. The previous bare
    # `except:` also swallowed KeyboardInterrupt and real errors.
    wait = getattr(cv2, "waitKeyEx", None) or cv2.waitKey
    return int(wait(delay))

# ---------- KEY MAPPINGS ----------
# Best-effort mapping that accepts:
#  - ASCII digits ('1','2','3','9'), compared on the low byte only so extra
#    flag bits above it are ignored. This is what arrives when NumLock is ON
#    (top-row and numpad digits can't be told apart).
#  - known X11/OpenCV extended keypad codes (NumLock OFF) such as 335, 262, 338, 339.
#  - Windows VK_NUMPAD codes (96..105), but ONLY on Windows. Those integers are
#    also the ASCII codes of '`' and 'a'..'i'. Elsewhere they would let ordinary
#    letter keys trigger actions ('a' -> print, 'i' -> start a capture).
#
# If DEBUG=True, press keys and note the printed integer codes, then patch the
# sets below if your Pi uses different codes.

# Known extended codes seen on many Linux setups (numpad when NumLock OFF).
KP_EXT = {
    1: {335, 65436},   # KP_End / other variations
    2: {262, 65433},   # KP_Down
    3: {338, 65435},   # KP_PgDn
    9: {339, 65434},   # KP_PgUp
}

# Windows VK_NUMPAD0..VK_NUMPAD9.
VK_NUMPAD = set(range(96, 106))

# VK codes collide with lowercase ASCII letters, so they are honoured only on Windows.
ON_WINDOWS = platform.system() == "Windows"


def is_numpad_key(k, n):
    """Return True if key code *k* (from get_key) should count as digit key *n*.

    Matches the ASCII digit on the low byte, the Windows VK_NUMPAD code
    (Windows only), or one of the extended keypad codes in KP_EXT.
    """
    if k is None:
        return False
    # ASCII digit (covers top-row OR numpad when NumLock ON)
    if (k & 0xFF) == ord(str(n)):
        return True
    # Windows virtual-key numpad (96..105); elsewhere these are letter codes.
    if ON_WINDOWS and k in VK_NUMPAD and (k - 96) == n:
        return True
    # Known extended keypad codes (NumLock OFF)
    return k in KP_EXT.get(n, set())


def is_key_1(k): return is_numpad_key(k, 1)
def is_key_2(k): return is_numpad_key(k, 2)
def is_key_3(k): return is_numpad_key(k, 3)
def is_key_9(k): return is_numpad_key(k, 9)

# ---------- CAMERA ----------
# Prefer the V4L2 backend (Linux / Raspberry Pi). Fall back to OpenCV's default
# backend so the booth also starts on hosts without V4L2. Fail loudly instead
# of opening a window that immediately closes because no frames arrive.
cap = cv2.VideoCapture(0, cv2.CAP_V4L2)
if not cap.isOpened():
    cap.release()
    cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise SystemExit("Photo Booth: could not open camera 0")
# Requested capture size; the driver may pick the nearest mode it supports.
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)

# WINDOW_NORMAL (same value the old WND_PROP_FULLSCREEN flag passed) keeps the
# window resizable so the fullscreen property below can take effect.
cv2.namedWindow("Photo Booth", cv2.WINDOW_NORMAL)
cv2.setWindowProperty("Photo Booth", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)

# UI state machine: "main" (live preview) -> "countdown" -> "review".
mode = "main"
right_img = IMG1_PATH           # artwork currently shown in the right panel
last_cap_path = None            # most recent saved capture
countdown_timer = 0             # seconds left in the current countdown
post_countdown_mode = "review"  # mode entered once the countdown finishes

# ---------- MAIN LOOP ----------
# Right-panel artwork is decoded once per path instead of being re-read from
# disk every frame. Only IMG1_PATH / IMG2_PATH are requested, so this stays tiny.
panel_cache = {}
# The review screen re-decodes the capture only when last_cap_path changes,
# so memory holds one capture at a time rather than one per photo taken.
review_path = None
review_view = None
countdown_start = cv2.getTickCount()


def _panel(path):
    """Return the right-panel image for *path*, loading it from disk only once."""
    if path not in panel_cache:
        panel_cache[path] = load_fixed(path, IMG_W, IMG_H)
    return panel_cache[path]


try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Mirror only the on-screen preview. `frame` stays un-mirrored, so the
        # saved photo keeps the camera's real orientation.
        preview = cv2.flip(frame, 1) if MIRROR_PREVIEW else frame

        screen = np.zeros((DISPLAY_H, DISPLAY_W, 3), np.uint8)
        cam_view = cv2.resize(center_crop_to_aspect(preview, 3, 2), (CAM_W, CAM_H))

        k = get_key(1)
        if DEBUG and k != -1:
            print("KEY:", k)

        if mode == "main":
            screen[CAM_Y:CAM_Y + CAM_H, CAM_X:CAM_X + CAM_W] = cam_view
            screen[IMG_Y:IMG_Y + IMG_H, IMG_X:IMG_X + IMG_W] = _panel(right_img)
            if is_key_9(k):  # start a capture
                mode = "countdown"
                countdown_timer = COUNTDOWN
                countdown_start = cv2.getTickCount()
                post_countdown_mode = "review"

        elif mode == "countdown":
            # Decrement once per elapsed second.
            elapsed = (cv2.getTickCount() - countdown_start) / cv2.getTickFrequency()
            if elapsed >= 1:
                countdown_timer -= 1
                countdown_start = cv2.getTickCount()
            screen[CAM_Y:CAM_Y + CAM_H, CAM_X:CAM_X + CAM_W] = cam_view
            screen[IMG_Y:IMG_Y + IMG_H, IMG_X:IMG_X + IMG_W] = _panel(right_img)
            if countdown_timer > 0:
                label = str(countdown_timer)
                (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 8, 12)
                # putText anchors at the text's bottom-left, so offset by its
                # size to centre the digit on the screen.
                org = ((DISPLAY_W - tw) // 2, (DISPLAY_H + th) // 2)
                cv2.putText(screen, label, org, cv2.FONT_HERSHEY_SIMPLEX, 8,
                            (255, 255, 255), 12, cv2.LINE_AA)
            else:
                raw_cap = center_crop_to_aspect(frame, 3, 2)
                bordered = add_border_and_text(raw_cap)
                last_cap_path = save_capture(bordered)
                right_img = IMG2_PATH if post_countdown_mode == "review" else IMG1_PATH
                mode = post_countdown_mode

        elif mode == "review":
            if review_path != last_cap_path:
                review_view = load_fixed(last_cap_path, CAM_W, CAM_H)
                review_path = last_cap_path
            screen[CAM_Y:CAM_Y + CAM_H, CAM_X:CAM_X + CAM_W] = review_view
            screen[IMG_Y:IMG_Y + IMG_H, IMG_X:IMG_X + IMG_W] = _panel(right_img)

            if is_key_1(k):  # print
                p = make_a5(last_cap_path)
                if p:
                    print_file(p)
                right_img = IMG1_PATH
                mode = "main"
            elif is_key_2(k):  # retake
                mode = "countdown"
                countdown_timer = COUNTDOWN
                countdown_start = cv2.getTickCount()
                post_countdown_mode = "review"
            elif is_key_3(k):  # discard
                right_img = IMG1_PATH
                mode = "main"

        cv2.imshow("Photo Booth", screen)
        if (k & 0xFF) == 27:  # ESC
            break
finally:
    # Release the camera and close the window even if the loop raises.
    cap.release()
    cv2.destroyAllWindows()
