from flask import Flask, render_template, request, redirect, url_for, flash, jsonify, Response
from flask_cors import CORS
import cv2
import mediapipe as mp
import numpy as np
import json
from inference_sdk import InferenceHTTPClient
import time
import os
import base64
import tempfile
from io import BytesIO
from PIL import Image
import requests
from flask import current_app
from datetime import datetime
from transformers import pipeline
from ultralytics import YOLO

# Flask application object; the routes further down in this file attach to it.
app = Flask(__name__)

# Enable cross-origin requests on the app using the flask_cors defaults.
CORS(app)

# Constants
REAL_CARD_WIDTH_MM = 85.6  # card width in millimetres; divisor for the px-per-mm scale
MM_TO_INCH = 0.0393701  # multiply a millimetre value by this to get inches

# Models
# Both models are instantiated once, at import time, and shared by all requests.
depth_pipe = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
yolo_model = YOLO("yolo/my_model.pt")

# Google reCAPTCHA Secret Key
# NOTE(review): this secret is committed in source; consider rotating it and
# loading it from the environment or a secrets store instead.
app.config['RECAPTCHA_SECRET_KEY'] = '6LdjVmgoAAAAAPmmWqOfoISokqReHB1joD8Ygf4_'

def verify_recaptcha(token, action='ipd_checker', threshold=0.5):
    """Validate a score-based reCAPTCHA token with Google's siteverify endpoint.

    Args:
        token: reCAPTCHA response token supplied by the client.
        action: Action name the token is expected to carry.
        threshold: Minimum score accepted; lower scores are rejected.

    Returns:
        A ``(ok, message)`` tuple. ``ok`` is True only when Google reports
        success, the action matches and the score reaches ``threshold``.
        Network or parsing failures are returned as
        ``(False, "reCAPTCHA error: ...")`` instead of being raised.
    """
    # The secret used to be passed to os.getenv() as the variable *name*, which
    # always produced None. Allow an environment override and otherwise fall
    # back to the value stored in the Flask config.
    secret_key = os.getenv('RECAPTCHA_SECRET_KEY') or app.config.get('RECAPTCHA_SECRET_KEY')
    if not secret_key:
        return False, "reCAPTCHA error: secret key is not configured"

    url = 'https://www.google.com/recaptcha/api/siteverify'
    data = {'secret': secret_key, 'response': token}

    try:
        # A timeout keeps a stalled verification call from hanging the request.
        response = requests.post(url, data=data, timeout=10)
        result = response.json()
        print("Google reCAPTCHA response:", result)

        if not result.get('success'):
            return False, f"Invalid reCAPTCHA: {result.get('error-codes')}"

        if result.get('action') != action:
            return False, f"Action mismatch: expected '{action}', got '{result.get('action')}'"

        if result.get('score', 0) < threshold:
            return False, f"Low score: {result.get('score')}"

        return True, "reCAPTCHA valid"

    except Exception as e:
        # Boundary handler: any transport/JSON failure becomes a failed check.
        return False, f"reCAPTCHA error: {str(e)}"

# ---------- error messages ----------
# Maps an error code to the user-facing message returned in error payloads.
# "GENERAL" doubles as the fallback for codes that are not listed here
# (see make_err).
ERROR_MSGS = {
    "EYES_NOT_DETECTED": "Eyes not detected – Try again with better lighting and keep both eyes visible.",
    "CARD_NOT_DETECTED": "Card not detected – Hold a standard card fully in frame, flat, and in good light.",
    "IPD_TOO_LOW": "Invalid IPD, try again, IPD measured less than 50mm.",
    "IPD_TOO_HIGH": "Invalid IPD, try again, IPD measured greater than 80mm.",
    "UNREALISTIC_IPD": "Center your face, keep eyes open, and ensure the whole mag strip is clearly visible.",
    "GENERAL": "Something went wrong. Please retake the photo in good lighting, keep your face centered, and hold the card flat with the whole magnetic strip clearly visible.",
}

def make_err(code: str, extra: str = "") -> dict:
    """Assemble the standard failure payload for ``code``.

    The message is looked up in ``ERROR_MSGS``; codes that are not listed get
    the ``"GENERAL"`` message. ``extra`` is passed through as ``details``.
    """
    message = ERROR_MSGS[code] if code in ERROR_MSGS else ERROR_MSGS["GENERAL"]
    payload = {"success": False, "error_code": code}
    payload["error"] = message
    payload["details"] = extra
    return payload

def ipd_is_unrealistic(ipd_mm: float) -> bool:
    """Return True when ``ipd_mm`` is not a plausible IPD.

    Plausible means inside the closed range 50.0-80.0 mm. NaN and values that
    cannot be compared with a number (for example ``None``) are treated as
    unrealistic.
    """
    try:
        # Written as a negated range test so NaN (for which every comparison
        # is False) lands on the "unrealistic" side instead of slipping through.
        return not (50.0 <= ipd_mm <= 80.0)
    except (TypeError, ValueError):
        return True
    

def process_image_from_bytes(image_bytes):
    """Measure the interpupillary distance (IPD) from an encoded image.

    Pipeline: decode the bytes with OpenCV, locate the iris centers with
    MediaPipe FaceMesh, detect the card strip through a Roboflow workflow,
    derive a px-per-mm scale from the strip's top edge, convert the pixel IPD
    to millimetres and draw an annotated overlay.

    Args:
        image_bytes: Encoded image data accepted by ``cv2.imdecode``.

    Returns:
        dict. On success ``success`` is True and the dict carries formatted
        measurement strings, the annotated image as base64 JPEG, timings, the
        raw Roboflow response and a list of warnings. On failure it is the
        standardized error payload (``success`` False, ``error_code``,
        ``error``, ``details``). Exceptions are caught and reported as a
        ``GENERAL`` error rather than raised.
    """
    
    # Local copy of the message table, used only by the fallback in _make_err.
    ERROR_MSGS_LOCAL = {
        "EYES_NOT_DETECTED": "Eyes not detected – Try again with better lighting and keep both eyes visible.",
        "CARD_NOT_DETECTED": "Card not detected – Hold a standard card fully in frame, flat, and in good light.",
        "IPD_TOO_LOW": "Invalid IPD, try again, IPD measured less than 50mm.",
        "IPD_TOO_HIGH": "Invalid IPD, try again, IPD measured greater than 80mm.",
        "UNREALISTIC_IPD": "Center your face, keep eyes open, and ensure the whole mag strip is clearly visible.",
        "GENERAL": "Something went wrong. Please retake the photo in good lighting, keep your face centered, and hold the card flat with the whole magnetic strip clearly visible.",
    }

    def _make_err(code: str, extra: str = "") -> dict:
        # Prefer the module-level make_err; build the same payload locally if
        # that call raises.
        try:
            return make_err(code, extra)
        except Exception:
            return {
                "success": False,
                "error_code": code,
                "error": ERROR_MSGS_LOCAL.get(code, ERROR_MSGS_LOCAL["GENERAL"]),
                "details": extra,
            }

    def _ipd_is_unrealistic(ipd_mm: float) -> tuple:
        # Returns (is_unrealistic, error_code_or_None); NaN and any failure
        # during the check map to "UNREALISTIC_IPD".
        try:
            if np.isnan(ipd_mm):
                return True, "UNREALISTIC_IPD"
            elif ipd_mm < 50.0:
                return True, "IPD_TOO_LOW"
            elif ipd_mm > 80.0:
                return True, "IPD_TOO_HIGH"
            else:
                return False, None
        except Exception:
            return True, "UNREALISTIC_IPD"

    try:
        # Convert bytes to numpy array
        nparr = np.frombuffer(image_bytes, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            return _make_err("GENERAL", "Could not decode image")

        h, w, _ = image.shape
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Initialize variables
        mp_face_mesh = mp.solutions.face_mesh
        ipd_px = 0.0
        left_center = np.array([0.0, 0.0], dtype=float)
        right_center = np.array([0.0, 0.0], dtype=float)
        nose_bridge_top_pt = None
        # NOTE(review): assigned below but never read in this function.
        face_landmarks_list_for_drawing = None

        # MediaPipe FaceMesh
        facemesh_start = time.time()
        with mp_face_mesh.FaceMesh(
            static_image_mode=True, refine_landmarks=True, max_num_faces=1, min_detection_confidence=0.5
        ) as face_mesh:
            results = face_mesh.process(rgb_image)
            if not results.multi_face_landmarks:
                return _make_err("EYES_NOT_DETECTED")

            face_landmarks_from_mp = results.multi_face_landmarks[0].landmark
            face_landmarks_list_for_drawing = face_landmarks_from_mp

            # Indices 468 and 473 are read below, so at least 474 landmarks are required.
            if len(face_landmarks_from_mp) < 474:
                return _make_err("EYES_NOT_DETECTED", "insufficient landmarks")

            # Landmarks are normalized; scale by image width/height to get pixels.
            left_center = np.array(
                [face_landmarks_from_mp[473].x * w, face_landmarks_from_mp[473].y * h], dtype=float
            )
            right_center = np.array(
                [face_landmarks_from_mp[468].x * w, face_landmarks_from_mp[468].y * h], dtype=float
            )

            print('-----------------------------------------')
            print("Left Center:", left_center)
            print("Right Center:", right_center)

            ipd_px = float(np.linalg.norm(left_center - right_center))

            # Always true after the length check above; landmark 168 is used as the nose-bridge point.
            if len(face_landmarks_from_mp) > 168:
                nose_bridge_landmark = face_landmarks_from_mp[168]
                nose_bridge_top_pt = np.array([nose_bridge_landmark.x * w, nose_bridge_landmark.y * h], dtype=float)

        facemesh_time = time.time() - facemesh_start

        # ------- Roboflow Credit Card Detection -------
        # NOTE(review): the API key is hard-coded in source; consider loading it from configuration.
        CLIENT = InferenceHTTPClient(api_url="https://serverless.roboflow.com", api_key="tg7xheRe5OzaW3ddGyPs")
        YOUR_CARD_CLASS_NAME = "Mag-Strip-OxwU"

        # Create temporary file for Roboflow API
        # delete=False keeps the file after the handle closes; it is removed in
        # the finally block below.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
            temp_path = temp_file.name
            cv2.imwrite(temp_path, image)

        try:
            roboflow_start = time.time()
            roboflow_response = CLIENT.run_workflow(
                workspace_name="lumadent",
                workflow_id="card-detection-workflow",
                images={"image": temp_path},
                use_cache=True
            )
            roboflow_time = time.time() - roboflow_start
            print("=== ROBOFLOW API CALLED ===")
            try:
                print("Roboflow response:\n", json.dumps(roboflow_response, indent=2))
            except (TypeError, ValueError) as e:
                # Response is not JSON-serializable; print its repr instead.
                print(f"Roboflow response (raw): {roboflow_response}")
            print(f"Roboflow API processing time: {roboflow_time:.4f}s")
            print("=== END ROBOFLOW RESPONSE ===")
            print("Roboflow API called: True")
            
        except Exception as e:
            print(f"=== ROBOFLOW API FAILED ===")
            print(f"Error: {e}")
            print("=== END ROBOFLOW ERROR ===")
            
            # Return appropriate error when API fails
            return _make_err("GENERAL", f"Card detection service unavailable: {str(e)}")
            
        finally:
            # Best-effort cleanup of the temporary image; failures are ignored.
            try:
                os.unlink(temp_path)
            except Exception:
                pass

        def extract_predictions(obj):
            # Recursively walk nested dicts/lists and collect every dict that
            # has both a 'class' and a 'confidence' key.
            if isinstance(obj, dict):
                # Roboflow nested outputs often carry 'class' & 'confidence' at leaves
                if 'class' in obj and 'confidence' in obj:
                    return [obj]
                out = []
                for v in obj.values():
                    out.extend(extract_predictions(v))
                return out
            elif isinstance(obj, list):
                out = []
                for item in obj:
                    out.extend(extract_predictions(item))
                return out
            return []

        detections = extract_predictions(roboflow_response)

        # Find Card Detection
        # First detection of the strip class with confidence above 0.5.
        card_detection = next(
            (det for det in detections
             if det.get('class') == YOUR_CARD_CLASS_NAME and float(det.get('confidence', 0)) > 0.5),
            None
        )
        if not card_detection:
            return _make_err("CARD_NOT_DETECTED")

        # ------- Width Detection from Points -------
        warning_messages = []
        final_ref_pt1, final_ref_pt2 = None, None
        card_px_width = 0.0
        draw_top_edge_pts = draw_bottom_edge_pts = None
        draw_left_edge_pts = draw_right_edge_pts = None

        points = card_detection.get('points', [])
        if not points or len(points) < 4:
            return _make_err("CARD_NOT_DETECTED", "insufficient polygon points")

        pts_array = np.array([[float(p['x']), float(p['y'])] for p in points], dtype=float)
        min_x = np.min(pts_array[:, 0])
        max_x = np.max(pts_array[:, 0])

        # Keep only the vertices whose x equals the extreme x on each side.
        # NOTE(review): when a single vertex holds the extreme x, the "top" and
        # "bottom" points selected below are the same point.
        left_points = pts_array[pts_array[:, 0] == min_x]
        right_points = pts_array[pts_array[:, 0] == max_x]
        if len(left_points) == 0 or len(right_points) == 0:
            return _make_err("CARD_NOT_DETECTED", "edge extraction failed")

        # Smallest y is treated as "top", largest y as "bottom".
        left_top    = left_points[np.argmin(left_points[:, 1])]
        left_bottom = left_points[np.argmax(left_points[:, 1])]
        right_top   = right_points[np.argmin(right_points[:, 1])]
        right_bottom= right_points[np.argmax(right_points[:, 1])]

        # Integer pixel tuples, as used by the cv2 drawing calls further down.
        left_top_pt     = tuple(left_top.astype(int))
        left_bottom_pt  = tuple(left_bottom.astype(int))
        right_top_pt    = tuple(right_top.astype(int))
        right_bottom_pt = tuple(right_bottom.astype(int))

        # Calculate angle of the bottom edge with respect to horizontal
        bottom_vec = np.array(right_bottom_pt, dtype=float) - np.array(left_bottom_pt, dtype=float)
        if np.linalg.norm(bottom_vec) == 0:
            return _make_err("CARD_NOT_DETECTED", "degenerate bottom edge")
        horizontal_vec = np.array([1.0, 0.0], dtype=float)
        cos_theta = np.dot(bottom_vec, horizontal_vec) / (np.linalg.norm(bottom_vec) * np.linalg.norm(horizontal_vec))
        angle_rad = np.arccos(np.clip(cos_theta, -1.0, 1.0))
        angle_deg = float(np.degrees(angle_rad))

        # More than 5 degrees away from horizontal (in either direction) adds a warning.
        if 5 < angle_deg < 175:
            # warning_messages.append("Card not level; try to keep the strip horizontal.")
            warning_messages.append("Something went wrong. Please retake the photo in good lighting, keep your face centered, and hold the card flat with the whole magnetic strip clearly visible.")

        # For width, use the top edge
        card_px_width = float(np.linalg.norm(np.array(right_top_pt, dtype=float) - np.array(left_top_pt, dtype=float)))
        if card_px_width <= 0:
            return _make_err("CARD_NOT_DETECTED", "zero card width")

        final_ref_pt1, final_ref_pt2 = left_top_pt, right_top_pt

        draw_top_edge_pts = (left_top_pt, right_top_pt)
        draw_bottom_edge_pts = (left_bottom_pt, right_bottom_pt)
        draw_left_edge_pts = (left_top_pt, left_bottom_pt)
        draw_right_edge_pts = (right_top_pt, right_bottom_pt)

        # ------- IPD Conversion -------
        # Scale: strip width in pixels divided by the reference width in mm.
        card_mm_width = 85.60
        px_per_mm = (card_px_width / card_mm_width) if card_px_width else 0.0
        if px_per_mm <= 0:
            return _make_err("CARD_NOT_DETECTED", "invalid scale (px_per_mm)")

        ipd_mm = float(ipd_px / px_per_mm) if px_per_mm > 0 else 0.0

        # Extract Roboflow points array for pseudo blue point (fallback to midpoint of eyes)
        # NOTE(review): index 44 of the polygon is used as the center reference;
        # its meaning is not evident from this file - confirm.
        rf_points = card_detection.get('points', [])
        if isinstance(rf_points, list) and len(rf_points) > 44 and 'x' in rf_points[44] and 'y' in rf_points[44]:
            blue_x = float(rf_points[44]['x'])
            blue_y = float(rf_points[44]['y'])
        else:
            blue_x = float((left_center[0] + right_center[0]) / 2.0)
            blue_y = float((left_center[1] + right_center[1]) / 2.0)
        blue_point = np.array([blue_x, blue_y], dtype=float)

        # Calculate distances
        # These two values serve as the fallback if the helper call below raises;
        # otherwise they are overwritten by its results.
        red_to_blue_mm = float(np.linalg.norm(left_center - blue_point) / px_per_mm)
        blue_to_green_mm = float(np.linalg.norm(blue_point - right_center) / px_per_mm)

        # Now calculate true IPD segments from actual Blue (uses your helper)
        # NOTE(review): total_ipd_mm is not read again after this try/except.
        try:
            red_to_blue_mm, blue_to_green_mm, total_ipd_mm = calculate_true_ipd_segments(
                left_center, blue_point, right_center, px_per_mm
            )
        except Exception:
            # Fallback: keep previously computed values
            total_ipd_mm = ipd_mm

        print("================================================================")
        print(" ")
        # NOTE(review): the 1.028 factor appears only in this log line; the
        # returned ipd_mm is not scaled by it.
        print("IPD (MM): ", ipd_mm * 1.028)
        print(f"IPD (MM): {ipd_mm:.2f}")
        print(" ")
        print("================================================================")

        # Plausibility checks
        # Eye centers still at their (0, 0) initial value, or a zero IPD, mean
        # the eyes were not located.
        if (
            np.allclose(left_center, [0.0, 0.0]) or
            np.allclose(right_center, [0.0, 0.0]) or
            ipd_mm == 0.0
        ):
            return _make_err("EYES_NOT_DETECTED")

        is_unrealistic, error_code = _ipd_is_unrealistic(ipd_mm)
        if is_unrealistic:
            return _make_err(error_code)

        # Horizontal distance (x only) from each eye center to the nose-bridge point, in mm.
        left_offset_mm = right_offset_mm = 0.0
        if nose_bridge_top_pt is not None and px_per_mm > 0:
            left_offset_mm = float(abs(left_center[0] - nose_bridge_top_pt[0]) / px_per_mm)
            right_offset_mm = float(abs(right_center[0] - nose_bridge_top_pt[0]) / px_per_mm)

        card_px_width_top = float(np.linalg.norm(np.array(right_top_pt, dtype=float) - np.array(left_top_pt, dtype=float)))
        card_px_width_bottom = float(np.linalg.norm(np.array(right_bottom_pt, dtype=float) - np.array(left_bottom_pt, dtype=float)))

        card_mm_width_top = (card_px_width_top / px_per_mm) if px_per_mm else 0.0
        card_mm_width_bottom = (card_px_width_bottom / px_per_mm) if px_per_mm else 0.0

        # A top/bottom width difference above 3 mm adds a warning.
        if abs(card_mm_width_top - card_mm_width_bottom) > 3.0:
            warning_messages.append("Width of the card is not consistent; try holding it more parallel to the camera.")

        # Calculate all edge angles
        def edge_angle(pt1, pt2):
            # Angle in degrees between pt1->pt2 and the +x axis; negated when
            # the vector's y component is negative. Zero-length edges yield 0.0.
            vec = np.array(pt2, dtype=float) - np.array(pt1, dtype=float)
            hv = np.array([1.0, 0.0], dtype=float)
            if np.linalg.norm(vec) == 0:
                return 0.0
            cos_th = np.dot(vec, hv) / (np.linalg.norm(vec) * np.linalg.norm(hv))
            ang = float(np.degrees(np.arccos(np.clip(cos_th, -1.0, 1.0))))
            if vec[1] < 0:
                ang = -ang
            return ang

        angle_top = edge_angle(left_top_pt, right_top_pt)
        angle_bottom = edge_angle(left_bottom_pt, right_bottom_pt)
        angle_left = edge_angle(left_top_pt, left_bottom_pt)
        angle_right = edge_angle(right_top_pt, right_bottom_pt)

        # Create annotated image
        # Drawn on a copy so the decoded input image stays untouched.
        annotated_image = image.copy()
        cv2.circle(annotated_image, tuple(left_center.astype(int)), 5, (0, 255, 0), -1)
        cv2.circle(annotated_image, tuple(right_center.astype(int)), 5, (0, 0, 255), -1)
        cv2.line(annotated_image, tuple(left_center.astype(int)), tuple(right_center.astype(int)), (255, 0, 0), 2)
        if nose_bridge_top_pt is not None:
            cv2.circle(annotated_image, tuple(nose_bridge_top_pt.astype(int)), 5, (255, 255, 0), -1)
        if final_ref_pt1 and final_ref_pt2:
            cv2.line(annotated_image, final_ref_pt1, final_ref_pt2, (0, 255, 255), 3)
        if draw_top_edge_pts:
            cv2.line(annotated_image, draw_top_edge_pts[0], draw_top_edge_pts[1], (255, 165, 0), 2)
        if draw_bottom_edge_pts:
            cv2.line(annotated_image, draw_bottom_edge_pts[0], draw_bottom_edge_pts[1], (200, 0, 200), 2)
        if draw_left_edge_pts:
            cv2.line(annotated_image, draw_left_edge_pts[0], draw_left_edge_pts[1], (0, 255, 255), 2)
        if draw_right_edge_pts:
            cv2.line(annotated_image, draw_right_edge_pts[0], draw_right_edge_pts[1], (0, 255, 255), 2)

        # Convert image to base64 for web display
        # NOTE(review): the success flag returned by cv2.imencode is ignored here.
        _, buffer = cv2.imencode('.jpg', annotated_image)
        img_base64 = base64.b64encode(buffer).decode('utf-8')

        # Logs only the first 500 characters of the encoded image.
        print("Annotated base64 image:")
        print(img_base64[:500])

        # Note: 'left_center' / 'right_center' carry the eye-to-center distances
        # in mm, not the pixel coordinates held by the variables of the same name.
        return {
            'success': True,
            'annotated_image': img_base64,
            'left_center': f"{red_to_blue_mm:.2f}",
            'right_center': f"{blue_to_green_mm:.2f}",
            'chosen_card_width': f"{card_px_width:.2f}px",
            'scale': f"{px_per_mm:.2f} px/mm",
            'ipd_px': f"{ipd_px:.2f}",
            'ipd_mm': f"{ipd_mm:.2f}",
            'left_offset': f"{left_offset_mm:.2f} mm",
            'right_offset': f"{right_offset_mm:.2f} mm",
            'angle_top': f"{angle_top:.2f} degrees",
            'angle_bottom': f"{angle_bottom:.2f} degrees",
            'angle_left': f"{angle_left:.2f} degrees",
            'angle_right': f"{angle_right:.2f} degrees",
            'using_top_edge_width': f"{card_px_width:.2f}px",
            'nose_bridge_location': f"({nose_bridge_top_pt[0]:.0f}, {nose_bridge_top_pt[1]:.0f})" if nose_bridge_top_pt is not None else "Not detected",
            'facemesh_time': f"{facemesh_time:.4f}s",
            'roboflow_time': f"{roboflow_time:.4f}s",
            'roboflow_response': roboflow_response,
            'warnings': warning_messages
        }

    except Exception as e:
        # Top-level boundary: any unexpected failure becomes a GENERAL error payload.
        return _make_err("GENERAL", str(e))


def process_distance_from_bytes(image_bytes):
    """Estimate the face-to-camera distance, in inches, from an encoded image.

    The iris centers come from MediaPipe FaceMesh and the card strip from the
    Roboflow workflow. A focal length in pixels is estimated from the strip
    width, then the distance follows from triangle similarity using a fixed
    reference IPD.

    Args:
        image_bytes: Encoded image data accepted by ``cv2.imdecode``.

    Returns:
        dict. On success ``{'success': True, 'distance': <inches>}`` with the
        value clamped to 5.0-100.0 and rounded to one decimal. On failure the
        standardized error payload from ``make_err`` plus ``'distance': None``.
        Exceptions are caught and reported as a ``GENERAL`` error.
    """

    def _make_err(code: str, extra: str = "") -> dict:
        """Build the shared error payload and add a ``distance`` key."""
        # Success payloads carry "distance", so error payloads carry it too
        # (as None); callers can then read the key unconditionally.
        err = make_err(code, extra)
        err["distance"] = None
        return err

    def _extract_predictions(obj):
        """Collect every nested dict that has 'class' and 'confidence' keys."""
        if isinstance(obj, dict):
            if 'class' in obj and 'confidence' in obj:
                return [obj]
            children = obj.values()
        elif isinstance(obj, list):
            children = obj
        else:
            return []
        return [pred for child in children for pred in _extract_predictions(child)]

    try:
        nparr = np.frombuffer(image_bytes, np.uint8)
        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if frame is None:
            return _make_err("GENERAL", "Could not decode image")

        h, w, _ = frame.shape
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # ---- Iris centers via MediaPipe FaceMesh ----
        mp_face_mesh = mp.solutions.face_mesh
        with mp_face_mesh.FaceMesh(
            static_image_mode=True, refine_landmarks=True, max_num_faces=1, min_detection_confidence=0.5
        ) as face_mesh:
            results = face_mesh.process(rgb)
            if not results.multi_face_landmarks:
                return _make_err("EYES_NOT_DETECTED")

            landmarks = results.multi_face_landmarks[0].landmark
            # Indices 468 and 473 are read below.
            if len(landmarks) < 474:
                return _make_err("EYES_NOT_DETECTED", "insufficient landmarks")

            # Landmarks are normalized; scale to pixel coordinates.
            left_eye = np.array([landmarks[473].x * w, landmarks[473].y * h], dtype=float)
            right_eye = np.array([landmarks[468].x * w, landmarks[468].y * h], dtype=float)
            ipd_px = float(np.linalg.norm(left_eye - right_eye))
            if ipd_px <= 0 or np.isnan(ipd_px):
                return _make_err("EYES_NOT_DETECTED", "ipd_px invalid")

        # ---- Roboflow card detection ----
        # NOTE(review): the API key is hard-coded in source; consider loading it from configuration.
        CLIENT = InferenceHTTPClient(api_url="https://serverless.roboflow.com", api_key="tg7xheRe5OzaW3ddGyPs")

        # Reserve a temp path, then write and upload inside the try so the
        # file is removed even when writing or the API call fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tf:
            temp_path = tf.name

        try:
            cv2.imwrite(temp_path, frame)
            rf_response = CLIENT.run_workflow(
                workspace_name="lumadent",
                workflow_id="card-detection-workflow",
                images={"image": temp_path},
                use_cache=True
            )
        finally:
            # Best-effort cleanup; a failed delete must not mask the real outcome.
            try:
                os.unlink(temp_path)
            except OSError:
                pass

        detections = _extract_predictions(rf_response)
        card_detection = next(
            (det for det in detections if det.get('class') == "Mag-Strip-OxwU" and float(det.get('confidence', 0)) > 0.5),
            None
        )
        if not card_detection:
            return _make_err("CARD_NOT_DETECTED")

        points = card_detection.get('points', [])
        if not points or len(points) < 4:
            return _make_err("CARD_NOT_DETECTED", "insufficient polygon points")

        pts_array = np.array([[float(p['x']), float(p['y'])] for p in points], dtype=float)
        min_x = np.min(pts_array[:, 0])
        max_x = np.max(pts_array[:, 0])

        # Vertices sitting at the extreme x on each side. The selections are
        # empty only if the extremes are NaN, hence the explicit check.
        left_points = pts_array[pts_array[:, 0] == min_x]
        right_points = pts_array[pts_array[:, 0] == max_x]
        if len(left_points) == 0 or len(right_points) == 0:
            return _make_err("CARD_NOT_DETECTED", "edge extraction failed")

        # Top-most (smallest y) vertex on each side defines the strip width.
        left_top = left_points[np.argmin(left_points[:, 1])]
        right_top = right_points[np.argmin(right_points[:, 1])]
        card_px_width = float(np.linalg.norm(right_top - left_top))
        if card_px_width <= 0 or np.isnan(card_px_width):
            return _make_err("CARD_NOT_DETECTED", "zero/invalid card width")

        # ---- Distance via triangle similarity / focal length estimate ----
        real_ipd_mm = 63.0  # fixed reference IPD used for the estimate
        card_real_width_mm = 85.6
        # NOTE(review): presumably the card-to-camera distance assumed when
        # estimating the focal length - confirm.
        known_card_distance_mm = 350.0
        focal_length_px = (card_px_width * known_card_distance_mm) / card_real_width_mm
        if focal_length_px <= 0 or np.isnan(focal_length_px):
            return _make_err("GENERAL", "invalid focal length computed")

        distance_mm = (real_ipd_mm * focal_length_px) / ipd_px
        if np.isnan(distance_mm) or distance_mm <= 0:
            return _make_err("GENERAL", "invalid distance computed")

        distance_inch = float(distance_mm / 25.4)
        # clamp to a reasonable range
        distance_inch = min(max(distance_inch, 5.0), 100.0)

        return {'success': True, 'distance': round(distance_inch, 1)}

    except Exception as e:
        # Top-level boundary: any unexpected failure becomes a GENERAL error payload.
        return _make_err("GENERAL", str(e))



def calculate_true_ipd_segments(left_eye, blue_point, right_eye, px_per_mm):
    """Convert the eye/center pixel distances to millimetres.

    Returns a 3-tuple: left eye to ``blue_point``, ``blue_point`` to right
    eye, and left eye to right eye, each divided by ``px_per_mm``. The three
    values are also printed.
    """
    def _span_mm(start, end):
        # Euclidean pixel distance scaled to millimetres.
        return np.linalg.norm(start - end) / px_per_mm

    left_segment = _span_mm(left_eye, blue_point)
    right_segment = _span_mm(blue_point, right_eye)
    full_span = _span_mm(left_eye, right_eye)

    print("================= Eye to Center =================")
    print("Left Eye to Center: ", left_segment)
    print("Right Eye to Center: ", right_segment)
    print("IPD MM: ", full_span)
    print("====================================================")

    return left_segment, right_segment, full_span


def create_face_mask(image_bytes):
    """Build an annotated overlay plus quick depth/IPD readouts.

    Steps: YOLO locates the card/strip (scale reference), the depth pipeline
    produces a depth map, and MediaPipe FaceMesh supplies the iris centers.

    Args:
        image_bytes: Encoded image data accepted by ``cv2.imdecode``.

    Returns:
        A 3-tuple ``(data, error_message, error_code)``. On success ``data``
        is a dict (base64 JPEG ``image``, ``depth_mm``, ``depth_inches``,
        ``ipd_mm``, ``left_mm``, ``right_mm``, ``warnings``) and the other two
        items are None. On failure ``data`` is None and the other two items
        describe the error. Exceptions are caught and reported as ``GENERAL``.
    """

    def _err_tuple(code: str, extra: str = None):
        """Return the ``(None, message, code)`` failure tuple for ``code``."""
        msg = ERROR_MSGS.get(code, ERROR_MSGS["GENERAL"])
        if extra:
            # The detail used to be dropped silently (both branches of the old
            # conditional returned msg). The user-facing message is unchanged;
            # the detail now goes to the log for diagnosis.
            print(f"create_face_mask error [{code}]: {extra}")
        return None, msg, code

    def _norm(label):
        """Normalize a class name: lower-case, '_' and '-' become spaces."""
        return str(label).strip().lower().replace("_", " ").replace("-", " ")

    # Acceptable aliases for the mag strip/card class (already normalized).
    strip_aliases = {
        "magnetic strip", "magnetic strips", "mag strip oxwu",
        "strip", "mag strip", "credit card", "card"
    }

    try:
        # ---- Decode image ----
        nparr = np.frombuffer(image_bytes, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            return _err_tuple("GENERAL", "Could not decode image")

        h, w, _ = image.shape
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mesh_image = image.copy()  # overlays are drawn on a copy

        warning_messages = []

        # ---------- Step 1: YOLO for card/strip detection ----------
        detection = yolo_model(image)[0]
        names = detection.names

        # Take the first box whose class name matches one of the aliases.
        card_box = None
        for box in detection.boxes:
            cls_id = int(box.cls[0])
            if _norm(names.get(cls_id, str(cls_id))) in strip_aliases:
                card_box = box.xyxy[0].cpu().numpy()
                break

        # If no card/strip -> standardized CARD_NOT_DETECTED error
        if card_box is None:
            return _err_tuple("CARD_NOT_DETECTED")

        x1, y1, x2, y2 = map(int, card_box)
        # Width is floored at 1 px, so the scale below is always positive.
        card_width_px = max(1, x2 - x1)
        px_per_mm = card_width_px / float(REAL_CARD_WIDTH_MM)
        cv2.rectangle(mesh_image, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # ---------- Step 2: Depth estimation ----------
        depth_output = depth_pipe(Image.fromarray(rgb_image))
        depth_map = np.array(depth_output["depth"])

        # ---------- Step 3: FaceMesh for eyes ----------
        mp_face_mesh = mp.solutions.face_mesh
        with mp_face_mesh.FaceMesh(
            static_image_mode=True, refine_landmarks=True, max_num_faces=1, min_detection_confidence=0.3
        ) as face_mesh:
            face_results = face_mesh.process(rgb_image)
            if not face_results.multi_face_landmarks:
                return _err_tuple("EYES_NOT_DETECTED")

            lm = face_results.multi_face_landmarks[0].landmark
            # Index 473 is read below.
            if len(lm) <= 473:
                return _err_tuple("EYES_NOT_DETECTED")

            # Iris centers (MediaPipe Iris)
            # NOTE(review): process_image_from_bytes labels landmark 473 as
            # "left" and 468 as "right", the opposite of this function;
            # confirm which convention callers expect.
            left_eye_lm = lm[468]
            right_eye_lm = lm[473]
            center_lm = lm[6]  # face center-ish landmark

            # Convert normalized landmarks to integer pixel coordinates.
            left_px = (int(left_eye_lm.x * w), int(left_eye_lm.y * h))
            right_px = (int(right_eye_lm.x * w), int(right_eye_lm.y * h))
            center_px = (int(center_lm.x * w), int(center_lm.y * h))

            # Warn if an iris sits within one pixel of the image border.
            for label, (x, y) in (("left", left_px), ("right", right_px)):
                if x <= 1 or x >= (w - 1) or y <= 1 or y >= (h - 1):
                    warning_messages.append(f"{label.capitalize()} iris near image edge; results may be off.")

            # Draw overlays
            cv2.circle(mesh_image, left_px, 5, (0, 0, 255), -1)
            cv2.circle(mesh_image, right_px, 5, (0, 0, 255), -1)
            cv2.circle(mesh_image, center_px, 5, (255, 0, 255), -1)
            cv2.line(mesh_image, left_px, right_px, (0, 255, 255), 2)

            # Region for depth near eyes: eye bounding box padded by 30 px,
            # clipped to the image.
            x_min = max(0, min(left_px[0], right_px[0]) - 30)
            x_max = min(w, max(left_px[0], right_px[0]) + 30)
            y_min = max(0, min(left_px[1], right_px[1]) - 30)
            y_max = min(h, max(left_px[1], right_px[1]) + 30)

            roi_depth = depth_map[y_min:y_max, x_min:x_max]
            # NOTE(review): the x1000 treats the depth map as metres; confirm
            # the pipeline's "depth" output is metric and not relative depth.
            mean_depth_mm = float(np.nan_to_num(np.mean(roi_depth), nan=0.0) * 1000.0)
            mean_depth_in = round(mean_depth_mm * MM_TO_INCH, 2)
            mean_depth_mm = round(mean_depth_mm, 2)

            # IPD in px and mm (scale from card)
            ipd_px = float(np.hypot(right_px[0] - left_px[0], right_px[1] - left_px[1]))
            ipd_mm = float(ipd_px / px_per_mm)

            # Horizontal (x only) distance from each iris to the center landmark.
            left_dist_mm = float(abs(left_px[0] - center_px[0]) / px_per_mm)
            right_dist_mm = float(abs(right_px[0] - center_px[0]) / px_per_mm)

        # Encode result image
        ok, buffer = cv2.imencode('.jpg', mesh_image)
        if not ok:
            return _err_tuple("GENERAL", "Failed to encode result image")
        base64_img = base64.b64encode(buffer).decode('utf-8')

        return {
            "image": base64_img,
            "depth_mm": mean_depth_mm,
            "depth_inches": mean_depth_in,
            "ipd_mm": round(ipd_mm, 2),
            "left_mm": round(left_dist_mm, 2),
            "right_mm": round(right_dist_mm, 2),
            "warnings": warning_messages
        }, None, None

    except Exception as e:
        # Top-level boundary: any unexpected failure becomes a GENERAL error.
        return _err_tuple("GENERAL", str(e))



# ============= API ROUTES =============

@app.route('/')
def index():
    """Root route: return a small JSON status payload."""
    payload = {"status": True, "message": "IPD Measurement API Started"}
    return jsonify(payload)

@app.route('/api/ipd-checker', methods=['POST'])
def api_analyze_all():
    """
    Unified IPD + distance endpoint with standardized error messages.

    Request (multipart form):
      - ``g_recaptcha_response`` form field or ``X-Recaptcha-Token`` header
      - ``image`` file upload

    Response: JSON ``{"status": bool, "data": {...}}``. Failures always carry
    ``message``, ``error``, ``ipd_mm`` (None), ``distance`` and ``warnings``;
    HTTP status is 400 (missing token / image), 403 (reCAPTCHA rejected),
    422 (image could not be measured) or 500 (unexpected exception).

    User-facing messages:
      - Eyes not detected  -> "Eyes not detected – Try again with better lighting and keep both eyes visible."
      - Card not detected  -> "Card not detected – Hold a standard card fully in frame, flat, and in good light."
      - IPD TOO Low        -> "Invalid IPD, try again, IPD measured less than 50mm."
      - IPD TOO HIGH       -> "Invalid IPD, try again, IPD measured greater than 80mm."
      - Unrealistic IPD    -> "Center your face, keep eyes open, and ensure the whole mag strip is clearly visible."
      - General error      -> "Something went wrong. Please retake the photo in good lighting, keep your face centered, and hold the card flat with the whole magnetic strip clearly visible."
    """

    ERROR_MSGS_LOCAL = {
        "EYES_NOT_DETECTED": "Eyes not detected – Try again with better lighting and keep both eyes visible.",
        "CARD_NOT_DETECTED": "Card not detected – Hold a standard card fully in frame, flat, and in good light.",
        "IPD_TOO_LOW": "Invalid IPD, try again, IPD measured less than 50mm.",
        "IPD_TOO_HIGH": "Invalid IPD, try again, IPD measured greater than 80mm.",
        "UNREALISTIC_IPD": "Center your face, keep eyes open, and ensure the whole mag strip is clearly visible.",
        "GENERAL": "Something went wrong. Please retake the photo in good lighting, keep your face centered, and hold the card flat with the whole magnetic strip clearly visible.",
        "MISSING_IMAGE": "No image uploaded",
        "MISSING_RECAPTCHA": "Missing reCAPTCHA token",
        "RECAPTCHA_FAIL": "reCAPTCHA verification failed",
    }

    def _msg(key: str) -> str:
        """Resolve an error key, preferring the module-level table over the local one."""
        try:
            return ERROR_MSGS.get(key, ERROR_MSGS_LOCAL.get(key, ERROR_MSGS_LOCAL["GENERAL"]))
        except Exception:
            return ERROR_MSGS_LOCAL.get(key, ERROR_MSGS_LOCAL["GENERAL"])

    def _fail(message, http_status, *, error=None, **extra):
        """Build the standard failure envelope; ``extra`` adds or overrides data fields."""
        payload = {
            "message": message,
            "error": message if error is None else error,
            "ipd_mm": None,
            "distance": None,
            "warnings": [],
        }
        payload.update(extra)
        return jsonify({"status": False, "data": payload}), http_status

    # Bound before the try so the outer handler can always read it, even when
    # the failure happens before the image has been processed.
    ipd_data = {}

    try:
        token = request.form.get('g_recaptcha_response') or request.headers.get('X-Recaptcha-Token')
        print(f"reCAPTCHA Token: {token}")

        if not token:
            return _fail(_msg("MISSING_RECAPTCHA"), 400, error="reCAPTCHA token required")

        # remote_addr may be None (e.g. some test/WSGI setups); treat that as non-local.
        # NOTE(review): behind a reverse proxy remote_addr is the proxy's address, so
        # this bypass could apply to every request — confirm deployment topology.
        remote_addr = request.remote_addr or ""
        if remote_addr.startswith("192.168.") or remote_addr == "127.0.0.1":
            print("Bypassing reCAPTCHA for local testing from:", request.remote_addr)
            is_valid = True
        else:
            is_valid, _ = verify_recaptcha(token, action='ipd_checker', threshold=0.5)

        if not is_valid:
            return _fail(_msg("RECAPTCHA_FAIL"), 403)

        if 'image' not in request.files or request.files['image'].filename == '':
            return _fail(_msg("MISSING_IMAGE"), 400, error="Image file missing or empty")

        image_bytes = request.files['image'].read()

        ipd_data = process_image_from_bytes(image_bytes)
        distance_data = process_distance_from_bytes(image_bytes)
        face_mask_data, face_error, _ = create_face_mask(image_bytes)

        if face_error:
            return _fail(face_error, 422, annotated_image=ipd_data.get('annotated_image'))

        annotated_image = ipd_data.get('annotated_image')
        distance = distance_data.get('distance') if distance_data.get('success') else None

        if not ipd_data.get('success', False):
            err_text = ipd_data.get('error') or _msg("GENERAL")
            code = ipd_data.get('error_code')
            if code in ("EYES_NOT_DETECTED", "CARD_NOT_DETECTED", "UNREALISTIC_IPD", "GENERAL"):
                err_text = _msg(code)
            return _fail(err_text, 422, annotated_image=annotated_image, distance=distance)

        # From here on the IPD measurement succeeded, so no separate
        # "both measurements failed" check is needed.
        left_mm = face_mask_data.get("left_mm")
        right_mm = face_mask_data.get("right_mm")
        new_ipd = face_mask_data.get("ipd_mm")

        # Any warning from the IPD stage is treated as a failed measurement.
        warnings = ipd_data.get('warnings', [])
        if warnings:
            return _fail(
                _msg("GENERAL"),
                422,
                distance=distance,
                left_center=ipd_data.get('left_center'),
                right_center=ipd_data.get('right_center'),
                annotated_image=annotated_image,
                warnings=warnings,
            )

        try:
            # 1.028 is a fixed multiplicative correction applied to the measured IPD.
            # NOTE(review): presumably an empirical calibration factor — confirm origin.
            ipd_mm_value = float(ipd_data.get('ipd_mm')) * 1.028
        except (TypeError, ValueError):
            return _fail(
                _msg("UNREALISTIC_IPD"), 422, annotated_image=annotated_image, distance=distance
            )

        # NaN compares false against both bounds and would otherwise be
        # reported as a successful measurement of "nan".
        if np.isnan(ipd_mm_value):
            return _fail(
                _msg("UNREALISTIC_IPD"), 422, annotated_image=annotated_image, distance=distance
            )

        if ipd_mm_value < 50.0:
            range_error = ERROR_MSGS_LOCAL["IPD_TOO_LOW"]
        elif ipd_mm_value > 80.0:
            range_error = ERROR_MSGS_LOCAL["IPD_TOO_HIGH"]
        else:
            range_error = None

        if range_error:
            return _fail(range_error, 422, annotated_image=annotated_image, distance=distance)

        return jsonify({
            "status": True,
            "data": {
                "message": "success",
                "error": "",
                "ipd_mm": f"{ipd_mm_value:.2f}",
                "distance": distance,
                # Left/right are deliberately crossed here, as in the existing API contract.
                # NOTE(review): presumably compensates for a mirrored image — confirm.
                "left_center": right_mm,
                "right_center": left_mm,
                "annotated_image": annotated_image,
                "new_left_mm": right_mm,
                "new_right_mm": left_mm,
                "new_ipd_mm": new_ipd,
                "warnings": []
            }
        })

    except Exception:
        # Record the traceback server-side; the client only sees the generic message.
        app.logger.exception("Unhandled error in /api/ipd-checker")
        annotated = ipd_data.get('annotated_image') if isinstance(ipd_data, dict) else None
        return _fail(_msg("GENERAL"), 500, annotated_image=annotated)

if __name__ == '__main__':
    # The Werkzeug debugger lets anyone who can reach the server execute
    # arbitrary code, and this server binds to all interfaces. Keep debug off
    # unless it is explicitly requested through the FLASK_DEBUG env variable.
    debug_enabled = os.getenv('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled, host='0.0.0.0', port=5000)