import io
import json
import math
import os
import shutil
from datetime import datetime, timezone
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import numpy as np
import pygrib
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
matplotlib.rcParams["path.simplify"] = True
matplotlib.rcParams["path.simplify_threshold"] = 1.0
import matplotlib.cm as cm
import matplotlib.patheffects as pe
from matplotlib.colors import BoundaryNorm, ListedColormap
from PIL import Image
from pyproj import Transformer


# ── PATHS ───────────────────────────────────────────────────────────────────
# Input GRIB2 file and the output tree. output_directory and runs_directory are
# created at import time (exist_ok=True, so reruns are safe).
grib_file = Path(r"/home/ukwf/public_html/pro-maps/ecm9/gribs/latest.grib2")
output_directory = Path(r"/home/ukwf/public_html/pro-maps/ecm9/")
runs_directory = output_directory / "runs"

# Reassigned to the private build folder after the GRIB run time is known.
tiles_directory = output_directory / "tiles"

output_directory.mkdir(parents=True, exist_ok=True)
runs_directory.mkdir(parents=True, exist_ok=True)

# NOTE(review): CLEAN_TILES is not read in this part of the script; presumably it
# controls wiping previously generated tiles — confirm at its use site.
CLEAN_TILES = False
# Latest forecast hour (T+...) accepted; GRIB messages with a later step are
# skipped while reading.
MAX_FORECAST_HOUR = 360


def forecast_hour_allowed(hour, max_hour=None):
    """Return True if ``hour`` is on the ECM9km output schedule.

    Schedule: 3-hourly from T+0 to T+144, then 6-hourly from T+150 up to
    ``max_hour``. (``is_requested_forecast_hour`` encodes the same schedule.)

    Args:
        hour: Forecast lead time in hours (callers pass the int from
            ``get_forecast_hour``).
        max_hour: Last forecast hour to accept. Defaults to the module-level
            ``MAX_FORECAST_HOUR`` so this check and the caller's range check
            share one limit instead of a second hard-coded 360.

    Returns:
        bool: whether the hour should be processed.
    """
    if max_hour is None:
        max_hour = MAX_FORECAST_HOUR
    if hour < 0 or hour > max_hour:
        return False
    if hour <= 144:
        return hour % 3 == 0
    # T+145..T+149 fall between the 3-hourly and 6-hourly segments.
    return hour >= 150 and hour % 6 == 0

# Match the first working precipitation tile setup.
# The HTML can still allow free zooming, but the native generated tiles match the original look.
# add_value_labels() tunes label density for z <= 5, z == 6 and higher zooms.
MIN_ZOOM = 5
MAX_ZOOM = 7
DEFAULT_ZOOM = 5

# Tile edge size (presumably pixels) and render DPI. TILE_PAD_M is presumably
# extra padding in metres around each tile's bounds — TODO confirm at use site.
TILE_SIZE = 256
TILE_DPI = 96
TILE_PAD_M = 60_000

# Slightly brighten overlay graphics before saving tiles.
# 1.00 = no change. 1.12 gives a small lift without washing colours out.
OVERLAY_BRIGHTNESS = 1.12

# CPU parallelism for tile rendering (resolved by get_max_workers()).
# None = auto: use os.cpu_count() (2 if unknown), capped at 8 workers.
# Set to a number like 4, 8, 12, etc if you want to fix the worker count.
MAX_WORKERS = None

# Faster reruns: do not regenerate tiles that already exist.
SKIP_EXISTING_TILES = False

def get_max_workers():
    """Return how many parallel workers tile rendering should use.

    An explicit ``MAX_WORKERS`` setting wins (coerced with ``int()``, never
    below 1). Otherwise the machine's CPU count is used, falling back to 2 when
    ``os.cpu_count()`` cannot tell, and capped at 8.
    """
    configured = MAX_WORKERS
    if configured is None:
        detected = os.cpu_count() or 2
        return max(1, min(detected, 8))
    return max(1, int(configured))



def is_requested_forecast_hour(hour):
    """Report whether ``hour`` belongs to the ECM9 output schedule.

    Schedule: every 3 h from T+0 to T+144, then every 6 h from T+150 up to
    ``MAX_FORECAST_HOUR``. ``hour`` is coerced with ``int()`` first, so
    ``"12"`` and ``12.7`` are both treated as 12; anything ``int()`` rejects
    yields False.

    Shorter runs (e.g. 06z runs ending at T+144) need no special case here:
    this only filters hours, and which hours exist is up to the GRIB.
    """
    try:
        h = int(hour)
    except Exception:
        return False

    if not 0 <= h <= MAX_FORECAST_HOUR:
        return False
    if h > 144:
        # 6-hourly segment; T+145..T+149 are never requested.
        return h >= 150 and h % 6 == 0
    return h % 3 == 0


# ── ORIGINAL / SHARED COLOUR SCHEMES ────────────────────────────────────────
# Colour tables are lists of 0-255 RGB tuples; the matching *_LEVELS lists are
# band edges. build_cmap() turns each pair into a ListedColormap + BoundaryNorm.
# Precipitation-rate palette (its first colours are reused for ACC_PRECIP_COLORS).
PRECIP_RGB_LIST = [
    (150, 210, 250), (120, 185, 250), (80, 165, 245), (60, 150, 245),
    (40, 130, 240), (30, 110, 235), (20, 100, 210),
    (30, 180, 30),  (55, 210, 60),  (80, 240, 80),  (120, 245, 115),
    (150, 245, 140), (180, 250, 170),
    (255, 232, 120), (255, 192, 60), (255, 160, 0), (255, 96, 0),
    (255, 50, 0),  (225, 20, 0),  (192, 0, 0),  (165, 0, 0),
    (200, 60, 60), (230, 80, 80), (230, 112, 112), (230, 140, 140),
    (248, 160, 160), (255, 200, 200), (255, 230, 230),
    (219, 169, 199), (197, 129, 199), (175, 89, 199),
    (153, 49, 199), (131, 10, 199)
]
# Band edges for PRECIP_RGB_LIST, in the precip_rate display units (mm h⁻¹).
# NOTE(review): 33 edges give 32 bands but there are 33 colours; BoundaryNorm
# then spreads the band indices over all 33 colours, so one colour is
# effectively skipped — confirm this is intended.
PRECIP_LEVELS = [
    0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 1.75, 2, 2.5,
    3, 3.5, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 25,
    30, 35, 40, 45, 50, 60, 80
]

# Temperature and dew-point palettes: one RGB tuple per band, coldest band first
# (paired with the ascending TEMP_LEVELS / DEW_LEVELS edges).
TEMP_RGB_VALUES = [
    (250, 0, 250), (225, 0, 238), (200, 0, 226), (175, 0, 214), (150, 0, 201), (125, 0, 189), (100, 0, 177), (75, 0, 164), (50, 0, 152), (25, 0, 140),
    (0, 0, 127), (9, 16, 139), (19, 32, 152), (28, 48, 165), (38, 64, 178), (48, 80, 191), (57, 96, 203), (67, 112, 216), (76, 128, 229), (86, 144, 242),
    (96, 160, 255), (92, 154, 243), (87, 147, 230), (82, 141, 217), (77, 134, 204), (72, 128, 191), (67, 121, 179), (62, 115, 166), (57, 108, 153), (52, 102, 140),
    (47, 95, 127), (43, 111, 139), (38, 127, 151), (33, 143, 163), (29, 159, 176), (24, 175, 188), (19, 191, 200), (15, 207, 213), (10, 223, 225), (5, 239, 237),
    (0, 255, 0), (0, 241, 0), (0, 226, 0), (0, 211, 0), (0, 196, 0), (0, 181, 0), (0, 167, 0), (0, 152, 0), (0, 137, 0), (0, 122, 0),
    (0, 107, 0), (25, 121, 0), (51, 136, 0), (76, 151, 0), (102, 166, 0), (127, 181, 0), (153, 195, 0), (178, 210, 0), (204, 225, 0), (229, 240, 0),
    (255, 255, 0), (255, 230, 0), (253, 207, 0), (250, 184, 0), (248, 161, 0), (245, 138, 0), (243, 115, 0), (240, 92, 0), (238, 69, 0), (235, 46, 0),
    (233, 23, 0), (227, 23, 14), (220, 23, 28), (214, 23, 43), (207, 23, 57), (201, 23, 72), (194, 22, 86), (188, 22, 101), (181, 22, 115), (175, 22, 130),
    (168, 22, 144), (175, 41, 153), (182, 60, 162), (189, 79, 171), (196, 98, 181), (203, 117, 190), (210, 137, 199), (217, 156, 208), (224, 175, 218), (231, 194, 227),
    (238, 213, 236),
]
# Dew point uses the same ramp as temperature; kept as its own list so the two
# palettes can be tuned independently later.
DEW_RGB_VALUES = list(TEMP_RGB_VALUES)
# 850 hPa temperature palette (one entry per 1 °C band; some colours repeat).
T850_RGB_VALUES = [
    (192, 0, 217), (192, 0, 217), (155, 0, 209), (155, 0, 209), (155, 0, 209), (155, 0, 209), (127, 0, 204), (127, 0, 204), (127, 0, 204), (101, 2, 207),
    (101, 2, 207), (86, 2, 206), (86, 2, 206), (72, 0, 207), (60, 0, 215), (50, 0, 221), (30, 0, 230), (6, 0, 239), (2, 40, 245), (0, 76, 253),
    (0, 96, 251), (0, 112, 249), (0, 143, 251), (0, 169, 253), (6, 189, 255), (10, 207, 255), (8, 203, 195), (6, 201, 143), (6, 205, 116), (6, 211, 94),
    (42, 223, 0), (80, 227, 0), (112, 231, 0), (147, 239, 0), (179, 245, 0), (219, 245, 0), (255, 247, 0), (255, 239, 0), (255, 231, 0), (255, 223, 0),
    (255, 213, 0), (255, 191, 0), (253, 167, 0), (255, 143, 8), (255, 122, 16), (255, 104, 16), (255, 90, 16), (255, 78, 16), (255, 66, 16), (253, 58, 8),
    (241, 26, 0), (241, 26, 0), (241, 26, 0), (201, 14, 9), (201, 14, 9), (201, 14, 9), (160, 6, 31), (160, 6, 31), (160, 6, 31), (161, 8, 56),
    (161, 8, 56), (148, 11, 54), (148, 11, 54), (108, 10, 31), (108, 10, 31), (108, 10, 31), (82, 13, 12),
]

# Wind band edges in mph (the wind_gust / wind_speed_10m layers display mph):
# 38 edges -> 37 bands, one WIND_RGB_VALUES colour per band.
WIND_LEVELS = [0, 2, 5, 10, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 91, 94, 97, 100, 105, 110, 115, 120]
WIND_RGB_VALUES = [(150, 210, 250), (120, 185, 250), (80, 165, 245), (60, 150, 245), (40, 130, 240), (30, 110, 235), (15, 160, 15), (30, 180, 30), (55, 210, 60), (120, 245, 115), (150, 245, 140), (180, 250, 170), (255, 232, 120), (255, 192, 60), (255, 160, 0), (255, 96, 0), (255, 50, 0), (225, 20, 0), (192, 0, 0), (165, 0, 0), (200, 60, 60), (230, 80, 80), (230, 112, 112), (230, 140, 140), (248, 160, 160), (255, 200, 200), (255, 230, 230), (219, 169, 199), (197, 129, 199), (175, 89, 199), (153, 49, 199), (131, 10, 199), (145, 145, 145), (167, 167, 167), (190, 190, 190), (212, 212, 212), (235, 235, 235)]


def mpl_rgb(name, n, reverse=False):
    """Sample ``n`` evenly spaced colours from a named Matplotlib colormap.

    Returns a list of ``(r, g, b)`` int tuples in 0-255 (each channel is
    truncated, not rounded). The list runs from the low end of the colormap to
    the high end, or the other way round when ``reverse`` is True.
    """
    sampled_cmap = matplotlib.colormaps.get_cmap(name).resampled(n)
    positions = np.linspace(0, 1, n)
    if reverse:
        positions = positions[::-1]
    # One vectorised lookup gives an (n, 4) RGBA array; alpha is dropped.
    rgba_rows = sampled_cmap(positions)
    return [tuple(int(channel * 255) for channel in row[:3]) for row in rgba_rows]


# Temperature-style band edges: one more edge than colours, so every colour in
# the matching palette gets exactly one band in build_cmap()'s BoundaryNorm.
TEMP_LEVELS = list(np.linspace(-40, 50, len(TEMP_RGB_VALUES) + 1))
TEMP_2M_LEVELS = list(TEMP_LEVELS)
DEW_LEVELS = list(np.linspace(-30, 40, len(DEW_RGB_VALUES) + 1))
# 850 hPa temperature: fixed 1 °C edges from -30 °C to +35 °C inclusive.
T850_LEVELS = list(range(-30, 36))

# Relative humidity: 2.5 % bands from 0 to 100 %.
RH_LEVELS = list(np.linspace(0, 100, 41))
RH_COLORS = mpl_rgb("YlGnBu", len(RH_LEVELS) - 1)

# Cloud cover: 5 % bands from 5 to 100 %.
CLOUD_LEVELS = list(np.arange(5, 105, 5))
CLOUD_COLORS = mpl_rgb("Greys", len(CLOUD_LEVELS) - 1)

# CAPE band edges (J kg⁻¹): 32 edges / 31 colours; the top colours repeat.
CAPE_LEVELS = [10, 25, 50, 100, 200, 250, 300, 350, 400, 450, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000, 2200, 2400, 2600, 2800, 3000, 4000]
CAPE_COLORS = [(150, 210, 250), (120, 185, 250), (80, 165, 245), (60, 150, 245), (40, 130, 240), (30, 110, 235), (15, 160, 15), (30, 180, 30), (55, 210, 60), (80, 240, 80), (120, 245, 115), (150, 245, 140), (180, 250, 170), (210, 255, 200), (255, 255, 120), (255, 230, 100), (255, 200, 80), (255, 170, 60), (255, 140, 40), (255, 110, 20), (255, 80, 0), (230, 60, 0), (200, 40, 0), (170, 20, 0), (170, 20, 0), (170, 20, 0), (170, 20, 0), (170, 20, 0), (170, 20, 0), (170, 20, 0), (170, 20, 0)]


# CIN band edges (J kg⁻¹): every 25 from -950 to 0, on a reversed "jet" ramp.
CIN_LEVELS = [-950, -925, -900, -875, -850, -825, -800, -775, -750, -725,
              -700, -675, -650, -625, -600, -575, -550, -525, -500, -475,
              -450, -425, -400, -375, -350, -325, -300, -275, -250, -225,
              -200, -175, -150, -125, -100, -75, -50, -25, 0]
CIN_COLORS = mpl_rgb("jet", len(CIN_LEVELS) - 1, reverse=True)

# Snow rate: 24 evenly spaced edges from 0.01 to 5 (mm h⁻¹ per the snow layers).
SNOW_RATE_LEVELS = list(np.linspace(0.01, 5, 24))
SNOW_RATE_COLORS = mpl_rgb("PuBuGn", len(SNOW_RATE_LEVELS) - 1)

# Snow depth: 1 cm bands from 1 to 50 cm.
SNOW_DEPTH_LEVELS = list(np.linspace(1.0, 50.0, 50))
SNOW_DEPTH_COLORS = mpl_rgb("BuPu", len(SNOW_DEPTH_LEVELS) - 1)

LIGHTNING_LEVELS = [0.0, 1e-7, 2e-7, 3e-7, 4e-7, 5e-7, 6e-7, 7e-7, 8e-7, 9e-7, 10e-7]
LIGHTNING_COLORS = mpl_rgb("jet", len(LIGHTNING_LEVELS) - 1)

# Isobar levels (hPa): every 2 hPa from 960 to 1040.
PRESSURE_LEVELS = list(np.arange(960, 1042, 2))

# Accumulated precipitation (mm) reuses the first 14 precipitation-rate colours.
ACC_PRECIP_LEVELS = [0.2, 0.5, 1, 2, 4, 6, 8, 10, 15, 20, 30, 40, 60, 80, 100]
ACC_PRECIP_COLORS = PRECIP_RGB_LIST[:len(ACC_PRECIP_LEVELS) - 1]


# Layer definitions, one dict per map layer. Common keys:
#   selector        GRIB keys matched by match_selector(): every key except
#                   "aliases" must match exactly; "aliases" are a text fallback
#                   tried when the exact keys do not all match.
#   scale, offset   displayed value = GRIB value * scale + offset (convert_values()).
#   mask_below, mask_equal  values below / equal to this are masked, and masked
#                   points are drawn transparent (build_cmap() sets bad = clear).
#   levels, colors  band edges and 0-255 RGB colours passed to build_cmap().
#   values          NOTE(review): presumably toggles numeric value labels
#                   (add_value_labels()) — confirm in the render loop.
# render:
#   "raster"         = no contour outlines, best for precip
#   "filled_contour" = smooth shaded weather field
#   "isobars"        = transparent tile with pressure contours only
#   "overlay"        = composite of the layer keys listed in "overlay_of"
# arrows:
#   True = add wind arrows from U/V components
VARIABLES = [
    {"key": "precip_rate", "label": "Precipitation rate", "units": "mm h⁻¹",
     "selector": {"shortName": "tprate", "aliases": ["tprate", "precipitation rate", "total precipitation rate", "rain rate"]}, "scale": 3600.0,
     "levels": PRECIP_LEVELS, "colors": PRECIP_RGB_LIST, "cmap_under": "none",
     "opacity": 0.98, "mask_below": 0.1, "render": "raster"},

    {"key": "accumulated_precip", "label": "Rain accumulation", "units": "mm",
     # Built after the GRIB is read by summing each available precip_rate step through the run.
     # This avoids showing isolated rain-rate data for an accumulation layer.
     "selector": {"shortName": "tp", "level": 0},
     "derived_from_rate": "precip_rate",
     "levels": ACC_PRECIP_LEVELS, "colors": ACC_PRECIP_COLORS, "cmap_under": "none",
     "opacity": 0.95, "mask_below": 0.2, "render": "raster"},

    {"key": "snow_rate", "label": "Snow rate", "units": "mm h⁻¹",
     "selector": {"shortName": "lssfr"}, "scale": 3600.0,
     "levels": SNOW_RATE_LEVELS, "colors": SNOW_RATE_COLORS, "cmap_under": "none",
     "opacity": 0.95, "mask_below": 0.01, "render": "raster"},

    {"key": "snowfall_rate", "label": "Snowfall rate", "units": "mm h⁻¹",
     "selector": {"shortName": "tsrwe", "typeOfLevel": "surface"}, "scale": 3600.0,
     "levels": SNOW_RATE_LEVELS, "colors": SNOW_RATE_COLORS, "cmap_under": "none",
     "opacity": 0.95, "mask_below": 0.01, "render": "raster"},

    {"key": "snow_depth", "label": "Snow depth", "units": "cm",
     "selector": {"shortName": "sd"}, "scale": 100.0,
     "levels": SNOW_DEPTH_LEVELS, "colors": SNOW_DEPTH_COLORS, "cmap_under": "none",
     "opacity": 0.90, "mask_below": 1.0, "render": "filled_contour"},

    # offset -273.15 converts Kelvin GRIB temperatures to the °C display units.
    {"key": "temp_2m", "label": "2 m temperature", "units": "°C",
     "selector": {"name": "2 metre temperature", "aliases": ["2 metre temperature", "2m temperature", "2t"]}, "offset": -273.15,
     "levels": TEMP_2M_LEVELS, "colors": TEMP_RGB_VALUES, "opacity": 0.82,
     "render": "filled_contour", "values": False},

    {"key": "max_temp_2m", "label": "Maximum temperature", "units": "°C",
     "selector": {"name": "Maximum temperature at 2 metres", "aliases": ["maximum temperature at 2 metres", "2 metre maximum temperature", "maximum 2m temperature", "mx2t"]}, "offset": -273.15,
     "levels": TEMP_2M_LEVELS, "colors": TEMP_RGB_VALUES, "opacity": 0.82,
     "render": "filled_contour", "values": True},

    {"key": "dewpoint_2m", "label": "2 m dew point", "units": "°C",
     "selector": {"name": "2 metre dewpoint temperature", "aliases": ["2 metre dewpoint temperature", "2 metre dewpoint", "2m dewpoint", "2d", "dew point", "dewpoint"]}, "offset": -273.15,
     "levels": DEW_LEVELS, "colors": DEW_RGB_VALUES, "opacity": 0.82,
     "render": "filled_contour", "values": False},

    {"key": "surface_temp", "label": "Surface temperature", "units": "°C",
     "selector": {"name": "Temperature", "typeOfLevel": "surface", "level": 0}, "offset": -273.15,
     "levels": TEMP_LEVELS, "colors": TEMP_RGB_VALUES, "opacity": 0.82,
     "render": "filled_contour", "values": True},

    {"key": "temp_850", "label": "850 hPa temperature", "units": "°C",
     "selector": {"name": "Temperature", "typeOfLevel": "isobaricInhPa", "level": 850}, "offset": -273.15,
     "levels": T850_LEVELS, "colors": T850_RGB_VALUES, "opacity": 0.82,
     "render": "filled_contour", "values": False},

    {"key": "relative_humidity_2m", "label": "2 m relative humidity", "units": "%",
     "selector": {"shortName": "2r", "typeOfLevel": "heightAboveGround", "level": 2, "aliases": ["2r", "2 metre relative humidity", "2m relative humidity", "relative humidity"]},
     "levels": RH_LEVELS, "colors": RH_COLORS, "opacity": 0.82, "render": "filled_contour", "values": True},

    # NOTE(review): the alias "cloud cover" is generic; if the exact category/
    # number keys fail, it could also match other cloud-cover parameters
    # (e.g. low/medium/high) — confirm against the GRIB inventory.
    {"key": "total_cloud", "label": "Total cloud cover", "units": "%",
     "selector": {"parameterCategory": 6, "parameterNumber": 1, "aliases": ["total cloud cover", "cloud cover", "tcc"]},
     "levels": CLOUD_LEVELS, "colors": CLOUD_COLORS, "cmap_under": "none",
     "opacity": 0.70, "mask_below": 5, "render": "filled_contour"},

    # scale 2.23694 converts m/s to the mph display units.
    {"key": "wind_gust", "label": "Wind gust", "units": "mph",
     "selector": {"name": "Instantaneous 10 metre wind gust", "aliases": ["instantaneous 10 metre wind gust", "wind gust", "gust", "10fg", "i10fg"]}, "scale": 2.23694,
     "levels": WIND_LEVELS, "colors": WIND_RGB_VALUES, "cmap_under": "none",
     "opacity": 0.88, "mask_below": 2, "render": "filled_contour", "arrows": True, "values": False},

    # scale 0.01 converts Pa to hPa.
    {"key": "pressure_msl", "label": "Mean sea-level pressure", "units": "hPa",
     "selector": {"shortName": "prmsl", "aliases": ["prmsl", "mean sea level pressure", "mean sea-level pressure", "msl", "pressure"]}, "scale": 0.01,
     "levels": PRESSURE_LEVELS, "colors": [(255,255,255)] * (len(PRESSURE_LEVELS) - 1),
     "opacity": 1.0, "render": "isobars"},

    {"key": "temp_850_pressure", "label": "850 hPa temperature + Pressure", "units": "°C / hPa",
     "overlay_of": ["temp_850", "pressure_msl"],
     "levels": T850_LEVELS, "colors": T850_RGB_VALUES, "opacity": 0.82,
     "render": "overlay"},

    {"key": "precip_rate_pressure", "label": "Precipitation rate + Pressure", "units": "mm h⁻¹ / hPa",
     "overlay_of": ["precip_rate", "pressure_msl"],
     "levels": PRECIP_LEVELS, "colors": PRECIP_RGB_LIST, "opacity": 0.98,
     "render": "overlay"},

    {"key": "lightning", "label": "Lightning potential", "units": "index",
     "selector": {"parameterCategory": 17, "parameterNumber": 192, "aliases": ["lightning", "lightning density", "lightning potential"]},
     "levels": LIGHTNING_LEVELS, "colors": LIGHTNING_COLORS, "cmap_under": "none",
     "opacity": 0.95, "mask_below": 1e-7, "render": "filled_contour"},

    {"key": "ml_cape", "label": "ML CAPE", "units": "J kg⁻¹",
     "selector": {"parameterCategory": 7, "parameterNumber": 193, "typeOfLevel": "surface", "aliases": ["ml cape", "mixed layer cape", "convective available potential energy"]},
     "levels": CAPE_LEVELS, "colors": CAPE_COLORS, "cmap_under": "none",
     "opacity": 0.90, "mask_below": 10, "render": "filled_contour"},

    # NOTE(review): mu_cape and sb_cape both list the alias "cape", so the alias
    # fallback cannot tell them apart — confirm the exact keys always match.
    {"key": "mu_cape", "label": "MU CAPE", "units": "J kg⁻¹",
     "selector": {"parameterCategory": 7, "parameterNumber": 192, "typeOfLevel": "surface", "aliases": ["mu cape", "most unstable cape", "cape"]},
     "levels": CAPE_LEVELS, "colors": CAPE_COLORS, "cmap_under": "none",
     "opacity": 0.90, "mask_below": 10, "render": "filled_contour"},

    {"key": "sb_cape", "label": "SB CAPE", "units": "J kg⁻¹",
     "selector": {"parameterCategory": 7, "parameterNumber": 6, "typeOfLevel": "surface", "aliases": ["sb cape", "surface based cape", "cape"]},
     "levels": CAPE_LEVELS, "colors": CAPE_COLORS, "cmap_under": "none",
     "opacity": 0.90, "mask_below": 10, "render": "filled_contour"},

    {"key": "mu_cin", "label": "MU CIN", "units": "J kg⁻¹",
     "selector": {"parameterCategory": 7, "parameterNumber": 194, "typeOfLevel": "surface", "aliases": ["mu cin", "most unstable cin", "convective inhibition"]},
     "levels": CIN_LEVELS, "colors": CIN_COLORS, "cmap_under": "none",
     "opacity": 0.90, "mask_equal": 0, "render": "filled_contour"},

    {"key": "sb_cin", "label": "SB CIN", "units": "J kg⁻¹",
     "selector": {"parameterCategory": 7, "parameterNumber": 7, "typeOfLevel": "surface", "aliases": ["sb cin", "surface based cin", "cin"]},
     "levels": CIN_LEVELS, "colors": CIN_COLORS, "cmap_under": "none",
     "opacity": 0.90, "mask_equal": 0, "render": "filled_contour"},
]

# Layers computed from other fields rather than read from a single GRIB message;
# "kind" names the derivation. NOTE(review): presumably wind speed comes from the
# 10 m U/V messages and wind chill from 2 m temperature plus that speed (see
# calc_wind_chill()) — confirm in the code that builds these layers.
DERIVED_VARIABLES = [
    {"key": "wind_speed_10m", "label": "10 m wind speed", "units": "mph", "kind": "wind_speed",
     "levels": WIND_LEVELS, "colors": WIND_RGB_VALUES, "cmap_under": "none",
     "opacity": 0.88, "mask_below": 2, "render": "filled_contour", "arrows": True, "values": True},
    {"key": "wind_chill", "label": "Wind chill", "units": "°C", "kind": "wind_chill",
     "levels": TEMP_LEVELS, "colors": TEMP_RGB_VALUES, "opacity": 0.82,
     "render": "filled_contour", "values": True},
]


# ── ACTIVE LAYERS ───────────────────────────────────────────────────────────
# Only these requested variables will be generated. This order also sets the
# order of VARIABLES, DERIVED_VARIABLES and COMPOSITE_VARIABLES below.
ACTIVE_LAYER_ORDER = [
    "temp_2m",              # Temperature shaded only; no text values
    "temp_850_pressure",    # 850 hPa temperature shaded with sea-level pressure
    "pressure_msl",         # Sea-level pressure
    "precip_rate",          # Rainfall
    "precip_rate_pressure", # Rainfall and sea-level pressure
    "dewpoint_2m",          # Dewpoint shaded only; no text values
    "wind_gust",            # Wind gusts shaded with wind arrows only
    "total_cloud",          # Total cloud cover
]
ACTIVE_LAYER_KEYS = set(ACTIVE_LAYER_ORDER)
ACTIVE_LAYER_RANK = {key: i for i, key in enumerate(ACTIVE_LAYER_ORDER)}

# Composites must be picked out before VARIABLES is rebound on the next line,
# because that filter drops every "overlay" entry.
# NOTE(review): "temp_850_pressure" overlays "temp_850", but "temp_850" is not in
# ACTIVE_LAYER_ORDER, so it is filtered out of VARIABLES here and never gets a
# data_by_variable entry. Confirm the composite renderer copes with that, or add
# "temp_850" to the active list.
COMPOSITE_VARIABLES = [v for v in VARIABLES if v.get("render") == "overlay" and v["key"] in ACTIVE_LAYER_KEYS]
VARIABLES = [v for v in VARIABLES if v["key"] in ACTIVE_LAYER_KEYS and v.get("render") != "overlay"]
DERIVED_VARIABLES = [v for v in DERIVED_VARIABLES if v["key"] in ACTIVE_LAYER_KEYS]

# Keys missing from ACTIVE_LAYER_RANK (none after the filters above) sort last.
VARIABLES.sort(key=lambda v: ACTIVE_LAYER_RANK.get(v["key"], 999))
DERIVED_VARIABLES.sort(key=lambda v: ACTIVE_LAYER_RANK.get(v["key"], 999))
COMPOSITE_VARIABLES.sort(key=lambda v: ACTIVE_LAYER_RANK.get(v["key"], 999))



# Half-width of the square Web Mercator (EPSG:3857) world extent, in metres.
WEBMERC_MAX = 20037508.342789244
# lon/lat (EPSG:4326) -> Web Mercator; always_xy=True means (lon, lat) argument order.
to_merc = Transformer.from_crs("EPSG:4326", "EPSG:3857", always_xy=True)


def lonlat_to_tile(lon, lat, z):
    """Return the ``(x, y)`` XYZ (slippy-map) tile index containing ``lon, lat``.

    ``lat`` is first clamped to the Web Mercator limit (±85.05112878°), and both
    indices are clamped to ``0 .. 2**z - 1``, so out-of-range coordinates still
    produce a valid tile at zoom ``z``.
    """
    lat_limit = 85.05112878
    clamped_lat = min(max(lat, -lat_limit), lat_limit)
    tiles_per_axis = 2 ** z
    lat_rad = math.radians(clamped_lat)

    col = int((lon + 180.0) / 360.0 * tiles_per_axis)
    row = int((1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * tiles_per_axis)

    last = tiles_per_axis - 1
    return max(0, min(last, col)), max(0, min(last, row))


def tile_bounds_merc(x, y, z):
    """Return ``(minx, miny, maxx, maxy)`` of XYZ tile ``x, y`` at zoom ``z``.

    Bounds are in Web Mercator (EPSG:3857) units. Tile rows count downward from
    the north edge, so ``y == 0`` is the top row.
    """
    tile_span = 2 * WEBMERC_MAX / 2 ** z
    west = -WEBMERC_MAX + x * tile_span
    north = WEBMERC_MAX - y * tile_span
    return west, north - tile_span, west + tile_span, north


def safe_float(v):
    """Return the first element (C order) of ``v`` as a Python float.

    ``v`` may be a scalar, sequence or array; masked-array masks are ignored
    because the value is read through ``np.asarray``.
    """
    first_item = np.asarray(v).flat[0]
    return float(first_item)


def iso_or_none(dt):
    """Return ``dt.isoformat()``, or None if ``dt`` is falsy or cannot be formatted."""
    if dt:
        try:
            return dt.isoformat()
        except Exception:
            pass
    return None

def _grib_date_time_to_datetime(date_value, time_value):
    """Combine GRIB-style ``YYYYMMDD`` date and ``HHMM`` time values into a naive datetime.

    Raises TypeError/ValueError when either value is missing (None) or the
    digits do not form a valid date/time.
    """
    date_int = int(date_value)
    time_int = int(time_value)
    return datetime(
        date_int // 10000,
        (date_int // 100) % 100,
        date_int % 100,
        time_int // 100,
        time_int % 100,
    )


def get_grib_run_datetime(grb):
    """
    Return the model analysis/run datetime from a GRIB message, or None.

    pygrib files can expose this differently depending on centre/template, so
    the sources are tried in order:

    1. ``analDate`` (returned as-is);
    2. ``dataDate`` + ``dataTime``;
    3. ``validityDate`` + ``validityTime`` minus ``forecastTime`` hours.

    Any failure in steps 2-3 falls through to the next step (best effort);
    None is returned when nothing usable is found.
    """
    # timedelta is not among the module-level datetime imports.
    from datetime import timedelta

    anal_date = safe_grib_attr(grb, "analDate", None)
    if anal_date is not None:
        return anal_date

    try:
        return _grib_date_time_to_datetime(
            safe_grib_attr(grb, "dataDate"),
            safe_grib_attr(grb, "dataTime", 0),
        )
    except Exception:
        pass

    try:
        valid_dt = _grib_date_time_to_datetime(
            safe_grib_attr(grb, "validityDate"),
            safe_grib_attr(grb, "validityTime", 0),
        )
        lead = timedelta(hours=int(safe_grib_attr(grb, "forecastTime", 0)))
        return valid_dt - lead
    except Exception:
        pass

    return None






def safe_grib_attr(grb, attr, default=None):
    """Read ``attr`` from a pygrib message without letting lookup errors escape.

    Attribute access is tried first, then item access (``grb[attr]``). pygrib
    can raise RuntimeError for missing keys even through getattr(), so any
    exception from either lookup is swallowed. When both fail, ``default`` is
    returned, which lets unsupported or non-data messages be skipped cleanly.
    """
    lookups = (
        lambda: getattr(grb, attr),
        lambda: grb[attr],
    )
    for lookup in lookups:
        try:
            return lookup()
        except Exception:
            continue
    return default

def get_forecast_hour(grb, default=-999):
    """Return a message's forecast step as an int, or ``int(default)``.

    The rest of the script treats this value as hours (T+...). ``grb.forecastTime``
    is tried first. Some ECM GRIB messages do not expose it, and pygrib can
    raise RuntimeError instead of honouring a getattr default, so the
    ``stepRange`` and then ``endStep`` keys are tried next. For a range such as
    ``"0-6"`` the end of the range is used.
    """
    try:
        return int(grb.forecastTime)
    except Exception:
        pass

    for step_key in ("stepRange", "endStep"):
        try:
            raw = grb[step_key]
        except Exception:
            continue
        if raw is None:
            continue
        try:
            # Keep only the part after the last "-" (end of a range).
            return int(float(str(raw).rsplit("-", 1)[-1]))
        except Exception:
            continue

    return int(default)


def norm_text(value):
    """Return ``value`` as stripped, lower-cased text; falsy values (None, "", 0) give ""."""
    if not value:
        return ""
    return str(value).strip().lower()


def grb_text_blob(grb):
    """Join a message's descriptive keys into one normalised, space-separated string.

    Reads shortName, name, parameterName, cfName and typeOfLevel (a missing key
    contributes an empty string) and passes each through ``norm_text``.
    """
    descriptive_keys = ("shortName", "name", "parameterName", "cfName", "typeOfLevel")
    normalised = [norm_text(safe_grib_attr(grb, key, "")) for key in descriptive_keys]
    return " ".join(normalised)


def match_selector(grb, selector):
    """
    Flexible but safe GRIB matcher: return True if ``grb`` matches ``selector``.

    Every selector key except ``"aliases"`` is an exact attribute test (a
    list/tuple/set value means "any of these"). If all of them pass, the message
    matches. Exact selector attributes are checked before aliases, so
    temperature does not accidentally pick dewpoint just because dewpoint
    includes "temperature".

    Otherwise each alias is compared, case-insensitively, with the message's
    shortName, name, parameterName, cfName and typeOfLevel. An alias only counts
    when it appears as a whole word or phrase inside a single field. Previously
    aliases were plain substrings of all fields joined together, so e.g. the
    alias "2t" matched the shortName "mx2t". A general alias can still match a
    longer phrase, e.g. "cloud cover" inside "low cloud cover".
    """
    import re

    exact_items = [(attr, expected) for attr, expected in selector.items() if attr != "aliases"]

    if exact_items:
        exact_ok = True
        for attr, expected in exact_items:
            actual = safe_grib_attr(grb, attr)
            if isinstance(expected, (list, tuple, set)):
                matched = actual in expected
            else:
                matched = actual == expected
            if not matched:
                exact_ok = False
                break
        if exact_ok:
            return True

    aliases = selector.get("aliases")
    if not aliases:
        return False

    fields = [
        norm_text(safe_grib_attr(grb, attr, ""))
        for attr in ("shortName", "name", "parameterName", "cfName", "typeOfLevel")
    ]
    for alias in aliases:
        needle = norm_text(alias)
        if not needle:
            # An empty alias would otherwise match every message.
            continue
        # Lookarounds instead of \b so aliases that start/end with
        # punctuation are still bounded correctly.
        pattern = re.compile(r"(?<!\w)" + re.escape(needle) + r"(?!\w)")
        if any(pattern.search(field) for field in fields):
            return True

    return False


def write_grib_inventory(grib_path, inventory_path, max_messages=1000):
    """Write a pipe-separated, human-readable listing of a GRIB file's messages.

    Each message (up to ``max_messages``) gets one line with its 1-based index
    and key metadata; keys a message lacks are shown as empty. A read error is
    recorded as the last line instead of being raised, so a partial inventory
    is still written.

    Args:
        grib_path: GRIB file to list.
        inventory_path: Path-like object with ``write_text``; written as UTF-8.
        max_messages: Stop listing after this many messages.
    """
    # GRIB keys shown for each message, in column order.
    columns = (
        "forecastTime", "shortName", "name", "typeOfLevel", "level",
        "parameterCategory", "parameterNumber", "dataDate", "dataTime",
    )
    lines = [
        f"GRIB inventory for: {grib_path}",
        "idx | forecastTime | shortName | name | typeOfLevel | level | paramCat | paramNo | dataDate | dataTime",
        "-" * 140,
    ]

    try:
        with pygrib.open(str(grib_path)) as inv:
            for idx, grb in enumerate(inv, start=1):
                if idx > max_messages:
                    lines.append(f"... stopped at {max_messages} messages")
                    break
                # getattr(grb, key, '') only catches AttributeError, but pygrib
                # can raise RuntimeError for a missing key, which used to abort
                # the whole listing at the first such message.
                cells = [f"{safe_grib_attr(grb, key, '')}" for key in columns]
                lines.append(f"{idx} | " + " | ".join(cells))
    except Exception as exc:
        lines.append(f"Could not read GRIB inventory: {exc}")

    inventory_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"🧾 Wrote GRIB inventory: {inventory_path}", flush=True)


def build_cmap(variable):
    """Build the ``(cmap, norm)`` pair used to shade a layer.

    ``variable["colors"]`` (0-255 RGB tuples) becomes a ListedColormap, and
    ``variable["levels"]`` supplies the BoundaryNorm band edges. Bad/masked
    points are fully transparent. When ``variable["cmap_under"] == "none"``,
    values below the first level are transparent too; otherwise they take the
    first colour. Values above the last level take the last colour.
    """
    colors = variable["colors"]
    cmap = ListedColormap([(r / 255, g / 255, b / 255) for r, g, b in colors])
    if variable.get("cmap_under") == "none":
        cmap.set_under("none", alpha=0)
    cmap.set_bad((0, 0, 0, 0))
    # clip=False is needed for set_under() to take effect: with clip=True,
    # BoundaryNorm clamps out-of-range values into the first/last band, so the
    # "under" colour was never used. Without an explicit under/over colour the
    # colormap falls back to its first/last colour, which is what clip=True
    # produced, so every other layer renders exactly as before.
    norm = BoundaryNorm(variable["levels"], cmap.N, clip=False)
    return cmap, norm


def convert_values(grb, variable, apply_mask=True):
    """Return a message's data grid converted to the layer's display units.

    Values are cast to float, multiplied by ``variable["scale"]`` (default 1.0)
    and shifted by ``variable["offset"]`` (default 0.0). With ``apply_mask``,
    points below ``variable["mask_below"]`` and/or equal to
    ``variable["mask_equal"]`` (when those keys are present) are masked with
    ``np.ma.masked_where``.
    """
    scale = variable.get("scale", 1.0)
    offset = variable.get("offset", 0.0)
    converted = grb.values.astype(float) * scale + offset
    if not apply_mask:
        return converted
    if "mask_below" in variable:
        converted = np.ma.masked_where(converted < variable["mask_below"], converted)
    if "mask_equal" in variable:
        converted = np.ma.masked_where(converted == variable["mask_equal"], converted)
    return converted


def field_from_msg(grb, variable):
    """Package one GRIB message as a per-hour field dict.

    The dict holds the forecast hour, the ISO valid time (None when
    unavailable), the lat/lon grids, the converted values with the layer's
    masks applied ("values") and without them ("rawValues"), plus the
    message's shortName / name / typeOfLevel / level for reference.
    """
    lats, lons = grb.latlons()
    unmasked = convert_values(grb, variable, apply_mask=False)
    field = {
        "forecastTime": get_forecast_hour(grb),
        "validTime": iso_or_none(safe_grib_attr(grb, "validDate", None)),
        "lats": lats,
        "lons": lons,
        "values": convert_values(grb, variable, apply_mask=True),
        "rawValues": unmasked,
    }
    source_keys = (
        ("sourceShortName", "shortName"),
        ("sourceName", "name"),
        ("sourceTypeOfLevel", "typeOfLevel"),
        ("sourceLevel", "level"),
    )
    for field_key, grib_key in source_keys:
        field[field_key] = safe_grib_attr(grb, grib_key, None)
    return field


def calc_wind_chill(temp_c, wind_mph):
    """Return wind chill in °C from air temperature (°C) and wind speed (mph).

    Uses the North American (Environment Canada / US NWS) wind chill index,
    whose formula works in km/h. Wind below 4.8 km/h is raised to 4.8 km/h
    before the formula is applied. Works element-wise on scalars or NumPy
    arrays.
    """
    speed_kmh = np.maximum(wind_mph * 1.609344, 4.8)
    wind_term = np.power(speed_kmh, 0.16)
    return 13.12 + 0.6215 * temp_c - 11.37 * wind_term + 0.3965 * temp_c * wind_term


def add_value_labels(ax, mx, my, values, minx, miny, maxx, maxy, z, variable):
    """Add zoom-aware numeric value labels to one tile.

    Labels go on a thinned grid: every ``step``-th row and column, starting
    ``step // 3`` in from the grid edge. A point is labelled only when its
    coordinates lie inside ``minx..maxx`` / ``miny..maxy`` and its value is
    finite; masked points count as missing. The text is the value rounded to
    a whole number, followed by "%" when ``variable["units"] == "%"``.

    ``step`` and font size depend on zoom (``z <= 5``, ``z == 6``, higher) and
    on the layer:

    * relative humidity - white text with a thin grey outline;
    * temperature-like  - plain black text (densest labels at high zoom);
    * wind              - small white text with a black outline;
    * anything else     - white text with a thicker black outline.

    Args:
        ax: Matplotlib axes to draw on.
        mx, my: 2-D coordinate grids indexed like ``values``.
        values: 2-D value grid (plain or NumPy masked array).
        minx, miny, maxx, maxy: Tile bounds in the coordinates of ``mx``/``my``.
        z: Tile zoom level.
        variable: Layer definition; its ``"key"`` and ``"units"`` are read.
    """
    # np.asarray() silently discards a masked array's mask, which meant points
    # hidden by mask_below/mask_equal were still labelled. Filling masked points
    # with NaN lets the isfinite() check below skip them.
    arr = np.ma.filled(np.ma.asarray(values, dtype=float), np.nan)

    key = variable.get("key", "")
    temperature_keys = {"temp_2m", "dewpoint_2m", "surface_temp", "temp_850", "wind_chill"}
    wind_keys = {"wind_speed_10m", "wind_gust"}

    if key == "relative_humidity_2m":
        style = "humidity"
    elif key in temperature_keys:
        style = "temperature"
    elif key in wind_keys:
        style = "wind"
    else:
        style = "default"

    # (step, font_size) at zoom <= 5, zoom 6 and higher zooms.
    # Smaller step = denser labels.
    spacing_by_style = {
        "humidity": ((30, 5.8), (18, 6.1), (12, 6.3)),
        "temperature": ((30, 6.8), (18, 7.2), (7, 7.8)),
        "wind": ((34, 5.6), (24, 5.8), (16, 6.0)),
        "default": ((34, 7.0), (26, 7.5), (20, 8.0)),
    }
    zoom_slot = 0 if z <= 5 else (1 if z == 6 else 2)
    step, font_size = spacing_by_style[style][zoom_slot]

    # Text styling is identical for every label in the tile, so build it once.
    if style == "humidity":
        color = "white"
        effects = [pe.withStroke(linewidth=0.65, foreground="#777777")]
    elif style == "temperature":
        color = "black"
        effects = None
    else:
        color = "white"
        stroke_width = 1.2 if style == "wind" else 1.8
        effects = [pe.withStroke(linewidth=stroke_width, foreground="black")]

    text_kwargs = {
        "color": color,
        "fontsize": font_size,
        "fontweight": "bold",
        "fontfamily": "DejaVu Sans",
        "ha": "center",
        "va": "center",
        "antialiased": True,
    }
    if effects is not None:
        text_kwargs["path_effects"] = effects

    suffix = "%" if variable.get("units") == "%" else ""
    start = max(0, step // 3)

    for i in range(start, arr.shape[0], step):
        for j in range(start, arr.shape[1], step):
            x = mx[i, j]
            y = my[i, j]
            if x < minx or x > maxx or y < miny or y > maxy:
                continue

            val = arr[i, j]
            if not np.isfinite(val):
                continue

            ax.text(x, y, f"{val:.0f}{suffix}", **text_kwargs)


def add_wind_arrows(ax, field, mx, my, minx, miny, maxx, maxy, step=18):
    """Draw white, black-outlined wind arrows from ``field["u"]`` / ``field["v"]``.

    The vector components are subsampled every ``step`` grid points along both
    axes. Only arrows whose anchor (``mx``/``my``) lies inside
    ``[minx, maxx] x [miny, maxy]`` and whose components are valid are drawn.

    Masked entries (pygrib ``.values`` may return ``numpy.ma`` arrays) are
    treated as missing. A plain ``np.asarray`` would strip the mask and expose
    the raw fill values, which pass the finite check and would be drawn as
    spurious arrows.

    Returns None. Does nothing when either component is absent or no arrow
    survives the filtering.
    """
    u = field.get("u")
    v = field.get("v")
    if u is None or v is None:
        return
    # Plain float arrays with masked points as NaN, so the isfinite test drops them.
    u = np.ma.filled(np.ma.asarray(u, dtype=float), np.nan)
    v = np.ma.filled(np.ma.asarray(v, dtype=float), np.nan)
    xs = mx[::step, ::step]
    ys = my[::step, ::step]
    us = u[::step, ::step]
    vs = v[::step, ::step]
    mask = (xs >= minx) & (xs <= maxx) & (ys >= miny) & (ys <= maxy) & np.isfinite(us) & np.isfinite(vs)
    if not np.any(mask):
        return
    ax.quiver(
        xs[mask], ys[mask], us[mask], vs[mask],
        color="white",
        angles="xy",
        scale_units="xy",
        # Higher scale value = shorter/smaller arrows.
        scale=0.00075,
        width=0.0032,
        headwidth=2.5,
        headlength=3.0,
        headaxislength=2.7,
        alpha=0.92,
        path_effects=[pe.withStroke(linewidth=0.55, foreground="black")]
    )


# Fail fast when the input GRIB is missing; nothing below can run without it.
if not grib_file.exists():
    raise SystemExit(f"❌ GRIB file not found: {grib_file}")

# Every configured layer (native, derived, composite) starts with an empty
# {forecast_hour: field} dict; layers still empty later are treated as unavailable.
all_variables = VARIABLES + DERIVED_VARIABLES + COMPOSITE_VARIABLES
data_by_variable = {v["key"]: {} for v in all_variables}

# Raw GRIB messages per forecast hour, consumed later by the wind gust vectors
# and the derived wind speed / wind chill layers.
u10_by_hour = {}
v10_by_hour = {}
t2m_by_hour = {}
first_message = None
# Run time detected on every message; the most common value is used later.
run_time_candidates = []

print(f"🔎 Reading GRIB file: {grib_file}")


# Single pass over every message in the GRIB file.
with pygrib.open(str(grib_file)) as grbs:
    for grb in grbs:
        if first_message is None:
            first_message = grb

        detected_run = get_grib_run_datetime(grb)
        if detected_run is not None:
            run_time_candidates.append(detected_run)

        # Ignore steps outside the published schedule (see forecast_hour_allowed).
        hour = get_forecast_hour(grb)
        if hour < 0 or hour > MAX_FORECAST_HOUR or not forecast_hour_allowed(hour):
            continue

        # parameterNumber 192/193 at level 10 are stored as the 10 m U/V wind
        # components; setdefault keeps the first matching message per hour.
        # NOTE(review): 192/193 fall in the local-use range and discipline /
        # parameterCategory are not checked — confirm against the feed inventory.
        if safe_grib_attr(grb, "parameterNumber") == 192 and safe_grib_attr(grb, "level") == 10:
            u10_by_hour.setdefault(hour, grb)
        if safe_grib_attr(grb, "parameterNumber") == 193 and safe_grib_attr(grb, "level") == 10:
            v10_by_hour.setdefault(hour, grb)
        if safe_grib_attr(grb, "name") == "2 metre temperature":
            t2m_by_hour.setdefault(hour, grb)

        # Native layers: the first message matching a layer's selector wins for
        # each hour; later matches for the same hour are skipped.
        for variable in VARIABLES:
            key = variable["key"]
            if hour in data_by_variable[key]:
                continue
            if match_selector(grb, variable["selector"]):
                data_by_variable[key][hour] = field_from_msg(grb, variable)

def build_cumulative_precip_from_rate(data_by_variable):
    """Replace the ``accumulated_precip`` layer with a running total built from ``precip_rate``.

    Each ``precip_rate`` step (presumably mm/h after scaling) contributes
    ``rate * hours_since_previous_step`` to the total. The first step is
    measured from T+0, so a GRIB without a T+0 rate message still accumulates
    from the start of the run. The field stored at T+24 is therefore the rain
    total from the start of the run through hour 24.

    Masked rates are filled with 0, and non-finite or negative rates count as 0.
    Totals below 0.2 are masked in ``values`` but kept in ``rawValues``.

    Nothing happens when ``accumulated_precip`` is not a configured key or
    there are no ``precip_rate`` fields. Otherwise the layer is overwritten:
    the running total is preferred over any native accumulation messages,
    because it matches the viewer's "total through the run" expectation.
    """
    if "accumulated_precip" not in data_by_variable:
        return
    rate_fields = data_by_variable.get("precip_rate", {})
    if not rate_fields:
        return

    accumulated = {}
    running_total = None
    prior_hour = 0  # T+0 baseline: the first step spans from the start of the run

    for step_hour in sorted(rate_fields):
        source = rate_fields[step_hour]
        fallback_values = source["values"]
        step_rate = np.ma.filled(source.get("rawValues", fallback_values), 0.0).astype(float)
        step_rate = np.where(np.isfinite(step_rate), np.maximum(step_rate, 0.0), 0.0)

        span_hours = max(0, int(step_hour) - int(prior_hour))
        if running_total is None:
            running_total = np.zeros_like(step_rate, dtype=float)
        running_total = running_total + (step_rate * span_hours)

        accumulated[step_hour] = {
            "forecastTime": int(step_hour),
            "validTime": source.get("validTime"),
            "lats": source["lats"],
            "lons": source["lons"],
            "values": np.ma.masked_where(running_total < 0.2, running_total.copy()),
            "rawValues": running_total.copy(),
            "sourceShortName": "tprate cumulative",
            "sourceName": "Rain accumulation from precipitation rate",
            "sourceTypeOfLevel": source.get("sourceTypeOfLevel"),
            "sourceLevel": source.get("sourceLevel"),
        }
        prior_hour = step_hour

    data_by_variable["accumulated_precip"] = accumulated



# Build the cumulative rain layer from the available precipitation-rate time steps.
# This overwrites accumulated_precip (when that layer is configured) with a
# running total integrated from the precip_rate fields read above.
build_cumulative_precip_from_rate(data_by_variable)

def build_overlay_variable(data_by_variable, composite_variable):
    """Populate a composite layer from its ``overlay_of = [base_key, overlay_key]`` pair.

    For every forecast hour present in both source layers, the composite field
    is a shallow copy of the base field plus three extra keys:
    ``overlayField`` (the overlay layer's field for that hour), ``baseKey`` and
    ``overlayKey``. The renderer then draws the base and puts the overlay's
    pressure contours on top.

    The composite entry in ``data_by_variable`` is left untouched when either
    source layer is missing or empty.
    """
    base_key, overlay_key = composite_variable.get("overlay_of", [None, None])
    bases = data_by_variable.get(base_key, {})
    overlays = data_by_variable.get(overlay_key, {})
    if bases and overlays:
        shared_hours = sorted(set(bases).intersection(overlays))
        data_by_variable[composite_variable["key"]] = {
            hour: {
                **bases[hour],
                "overlayField": overlays[hour],
                "baseKey": base_key,
                "overlayKey": overlay_key,
            }
            for hour in shared_hours
        }


# Attach 10 m wind vectors (converted to mph) to gust fields where both
# components exist for the same forecast hour, so gust tiles can draw arrows.
for hour, field in data_by_variable.get("wind_gust", {}).items():
    if hour in u10_by_hour and hour in v10_by_hour:
        field["u"] = u10_by_hour[hour].values.astype(float) * 2.23694
        field["v"] = v10_by_hour[hour].values.astype(float) * 2.23694

# Derived wind speed.
# Only build this if wind_speed_10m is enabled in ACTIVE_LAYER_KEYS.
if "wind_speed_10m" in data_by_variable:
    for hour in sorted(set(u10_by_hour) & set(v10_by_hour)):
        u = u10_by_hour[hour]
        v = v10_by_hour[hour]
        lats, lons = u.latlons()
        u_mph = u.values.astype(float) * 2.23694
        v_mph = v.values.astype(float) * 2.23694
        speed_mph = np.sqrt(u_mph ** 2 + v_mph ** 2)
        # Hide near-calm speeds (< 2 mph) in the rendered field.
        speed_mph = np.ma.masked_where(speed_mph < 2, speed_mph)
        data_by_variable["wind_speed_10m"][hour] = {
            "forecastTime": hour,
            "validTime": iso_or_none(getattr(u, "validDate", None)),
            "lats": lats,
            "lons": lons,
            "values": speed_mph,
            "u": u_mph,
            "v": v_mph,
        }

# Derived wind chill.
# Only build this if wind_chill is enabled in ACTIVE_LAYER_KEYS.
if "wind_chill" in data_by_variable:
    for hour in sorted(set(u10_by_hour) & set(v10_by_hour) & set(t2m_by_hour)):
        u = u10_by_hour[hour]
        v = v10_by_hour[hour]
        t = t2m_by_hour[hour]
        lats, lons = t.latlons()
        wind_mph = np.sqrt((u.values.astype(float) * 2.23694) ** 2 + (v.values.astype(float) * 2.23694) ** 2)
        temp_c = t.values.astype(float) - 273.15
        wc = calc_wind_chill(temp_c, wind_mph)
        data_by_variable["wind_chill"][hour] = {
            "forecastTime": hour,
            "validTime": iso_or_none(getattr(t, "validDate", None)),
            "lats": lats,
            "lons": lons,
            "values": wc,
        }

# Build composite (base layer + pressure overlay) layers only after the gust
# vectors and the derived layers above exist. Composites take shallow copies
# of their base fields, so building them earlier left composites of derived
# layers empty and gave gust-based composites no u/v vectors.
for composite_variable in COMPOSITE_VARIABLES:
    build_overlay_variable(data_by_variable, composite_variable)

# Keep only layers that ended up with at least one forecast hour, ordered by
# ACTIVE_LAYER_RANK (unranked layers last).
available_variables = [v for v in all_variables if data_by_variable[v["key"]]]
available_variables.sort(key=lambda v: ACTIVE_LAYER_RANK.get(v["key"], 999))

if not available_variables:
    inventory_path = output_directory / "grib_inventory.txt"
    write_grib_inventory(grib_file, inventory_path)
    raise SystemExit(
        "❌ None of the configured variables were found in the GRIB file. "
        f"Inventory written to: {inventory_path}"
    )

# Use the most common detected run time across all messages.
# This avoids stale/odd first-message metadata.
if run_time_candidates:
    from collections import Counter
    run_time = Counter(run_time_candidates).most_common(1)[0][0]
else:
    # Fallback: ask the first message directly (None when the file had no messages).
    run_time = get_grib_run_datetime(first_message) if first_message else None

# ISO timestamp for the manifest and an "HHz" label (e.g. "00z") for display.
run_time_iso = iso_or_none(run_time)
run_label = f"{run_time.hour:02d}z" if run_time else "Unknown run"

print(f"🕒 Model run detected for manifest: {run_label} ({run_time_iso})")
if run_time_candidates:
    # Show at most six distinct run times; more than one hints at a mixed-run GRIB.
    unique_runs = sorted(set(run_time_candidates))
    print("🕒 GRIB run times seen:", ", ".join(dt.isoformat() for dt in unique_runs[:6]), flush=True)
print("📈 Variables found:")
for variable in available_variables:
    hours = sorted(data_by_variable[variable["key"]])
    # Source metadata is taken from the earliest hour; layers built in this
    # script without source* keys (derived wind layers) print None here.
    first_field = data_by_variable[variable["key"]][hours[0]]
    print(
        f"   {variable['label']}: T+{min(hours)} → T+{max(hours)} "
        f"from shortName={first_field.get('sourceShortName')} "
        f"name={first_field.get('sourceName')} "
        f"level={first_field.get('sourceTypeOfLevel')}/{first_field.get('sourceLevel')}",
        flush=True
    )


# Union of forecast hours across all available layers, reported in the log line below.
selected_forecast_hours = sorted({
    int(hour)
    for variable in available_variables
    for hour in data_by_variable.get(variable["key"], {})
})
print(
    "⏱ Forecast hours selected for manifest/tiles: "
    + (", ".join(f"T+{h}" for h in selected_forecast_hours) if selected_forecast_hours else "none"),
    flush=True
)

# ── RUN-SAFE OUTPUT FOLDER ─────────────────────────────────────────────────
# Build the new model run in a private folder first. The live viewer keeps using
# the previous root metadata.json until every tile and the manifest are complete.
build_started_utc = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
if run_time:
    # Run id such as "20240101_00z".
    run_id_base = f"{run_time:%Y%m%d}_{run_time.hour:02d}z"
else:
    run_id_base = f"unknown_{build_started_utc}"

# Never reuse an already-published run folder: append the build timestamp instead.
run_public_dir = runs_directory / run_id_base
if run_public_dir.exists():
    run_public_dir = runs_directory / f"{run_id_base}_{build_started_utc}"

# Hidden, PID-suffixed staging folder; a leftover one with the same name is wiped.
run_build_dir = runs_directory / f".building_{run_public_dir.name}_{os.getpid()}"
if run_build_dir.exists():
    shutil.rmtree(run_build_dir)

# Rebind the module-level tiles_directory so the renderers write into staging,
# while the manifest's relative tilesBaseUrl points at the final published folder.
tiles_directory = run_build_dir / "tiles"
tiles_directory.mkdir(parents=True, exist_ok=True)
tiles_base_url = f"runs/{run_public_dir.name}/tiles"

print(f"📁 Building private run folder: {run_build_dir}", flush=True)
print(f"📁 Will publish completed run as: {run_public_dir}", flush=True)

# Map bounds come from the earliest hour of the highest-ranked available layer,
# padded by 0.15 degrees. Latitude is clamped to ±85 degrees, presumably to stay
# inside the Web Mercator range.
first_variable = available_variables[0]
first_hour = min(data_by_variable[first_variable["key"]])
sample = data_by_variable[first_variable["key"]][first_hour]

bounds = {
    "south": max(-85.0, safe_float(np.nanmin(sample["lats"])) - 0.15),
    "west": safe_float(np.nanmin(sample["lons"])) - 0.15,
    "north": min(85.0, safe_float(np.nanmax(sample["lats"])) + 0.15),
    "east": safe_float(np.nanmax(sample["lons"])) + 0.15,
}
center = [(bounds["south"] + bounds["north"]) / 2, (bounds["west"] + bounds["east"]) / 2]

# Live-upload safe:
# Tiles render into the private build folder above. The public viewer only
# switches to this folder after metadata.json is replaced at the end.



def render_first_file_precip_tile(variable, field, z, x_tile, y_tile, out_path):
    """
    Render one precipitation XYZ tile in the dedicated "first file" style.

    Drawing matches the first uploaded working file:
      - same mx/my transform
      - same 256px tile figure
      - same contourf call (PRECIP_RGB_LIST colours on PRECIP_LEVELS)
      - no pcolormesh
      - no collection edge edits
      - no additional smoothing/masking tricks

    The saved PNG then gets a small alpha clean-up. Pixels with alpha below 42
    become fully transparent, which removes the faint anti-aliased fringe. The
    remaining semi-transparent pixels (alpha 42-229) are made 18% more opaque.

    ``variable`` is not used here; it keeps the signature parallel to
    ``render_tile``.

    Returns True when the tile was written. Returns False when the tile was
    skipped (already present with SKIP_EXISTING_TILES), when the field does not
    reach the tile (plus TILE_PAD_M), or when contouring failed (the failure
    is printed).
    """
    if SKIP_EXISTING_TILES and out_path.exists() and out_path.stat().st_size > 0:
        return False

    # Remove only very faint edge pixels. Main precipitation remains untouched.
    # Increase to 50-70 if you still see a halo; lower to 20-30 if edges look too clipped.
    halo_alpha_cutoff = 42

    lats = field["lats"]
    lons = field["lons"]
    data = field["values"]

    mx, my = to_merc.transform(lons, lats)
    minx, miny, maxx, maxy = tile_bounds_merc(x_tile, y_tile, z)

    # Skip tiles the field cannot reach, allowing TILE_PAD_M of slack.
    if (
        np.nanmax(mx) < minx - TILE_PAD_M or np.nanmin(mx) > maxx + TILE_PAD_M or
        np.nanmax(my) < miny - TILE_PAD_M or np.nanmin(my) > maxy + TILE_PAD_M
    ):
        return False

    cmap = ListedColormap([(r / 255, g / 255, b / 255) for r, g, b in PRECIP_RGB_LIST])
    cmap.set_under("none", alpha=0)
    norm = BoundaryNorm(PRECIP_LEVELS, cmap.N, clip=True)

    fig = plt.figure(figsize=(TILE_SIZE / TILE_DPI, TILE_SIZE / TILE_DPI), dpi=TILE_DPI)
    # try/finally guarantees the figure is closed even if saving fails.
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.set_xlim(minx, maxx)
        ax.set_ylim(miny, maxy)
        ax.axis("off")
        fig.patch.set_alpha(0)
        ax.patch.set_alpha(0)

        try:
            ax.contourf(
                mx, my, data,
                levels=PRECIP_LEVELS,
                cmap=cmap,
                norm=norm,
                extend="max",
                antialiased=True
            )
        except Exception as e:
            print(
                f"⚠ Failed FIRST-FILE precip tile "
                f"h={field['forecastTime']} z={z} x={x_tile} y={y_tile}: {e}"
            )
            return False

        out_path.parent.mkdir(parents=True, exist_ok=True)

        # Render to memory first so the alpha fringe can be cleaned before saving.
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=TILE_DPI, transparent=True, pad_inches=0)
    finally:
        plt.close(fig)
    buf.seek(0)

    arr = np.array(Image.open(buf).convert("RGBA"))

    faint = arr[:, :, 3] < halo_alpha_cutoff
    arr[faint, 3] = 0

    # Slightly firm up remaining semi-transparent rain pixels so it does not look dull.
    semi = (arr[:, :, 3] >= halo_alpha_cutoff) & (arr[:, :, 3] < 230)
    arr[semi, 3] = np.minimum(255, (arr[semi, 3].astype(np.float32) * 1.18)).astype(np.uint8)

    Image.fromarray(arr, "RGBA").save(out_path)
    return True


def render_tile(variable, field, z, x_tile, y_tile, out_path):
    """Render a single XYZ tile for ``variable`` / ``field`` straight to ``out_path``.

    Precipitation (``render == "first_file_precip"`` or key ``precip_rate``) is
    delegated to ``render_first_file_precip_tile``. The other render modes are:

    - ``overlay``: base layer plus labelled PRESSURE_LEVELS contours taken from
      ``field["overlayField"]``.
    - ``isobars``: labelled contours of the field itself.
    - ``raster``: a gouraud pcolormesh.
    - ``original_precip``: a plain contourf honouring ``variable["extend"]``
      (default "max").
    - Default: filled contours extended at both ends.

    Only the contour labels returned by ``clabel`` get the black stroke. Value
    labels drawn by ``add_value_labels`` therefore keep their own styling.

    Returns True when the PNG was written. Returns False when the tile was
    skipped, lies outside the field's extent, or rendering failed (the failure
    is printed).
    """
    if SKIP_EXISTING_TILES and out_path.exists() and out_path.stat().st_size > 0:
        return False

    if variable.get("render") == "first_file_precip" or variable.get("key") == "precip_rate":
        return render_first_file_precip_tile(variable, field, z, x_tile, y_tile, out_path)

    lats = field["lats"]
    lons = field["lons"]
    data = field["values"]

    mx, my = to_merc.transform(lons, lats)
    minx, miny, maxx, maxy = tile_bounds_merc(x_tile, y_tile, z)

    # Skip tiles the field cannot reach, allowing TILE_PAD_M of slack.
    if (
        np.nanmax(mx) < minx - TILE_PAD_M or np.nanmin(mx) > maxx + TILE_PAD_M or
        np.nanmax(my) < miny - TILE_PAD_M or np.nanmin(my) > maxy + TILE_PAD_M
    ):
        return False

    fig = plt.figure(figsize=(TILE_SIZE / TILE_DPI, TILE_SIZE / TILE_DPI), dpi=TILE_DPI)
    ax = fig.add_axes([0, 0, 1, 1])
    ax.set_xlim(minx, maxx)
    ax.set_ylim(miny, maxy)
    ax.axis("off")
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    try:
        render_mode = variable.get("render", "filled_contour")

        if render_mode == "overlay":
            base_key = field.get("baseKey")
            if base_key == "precip_rate":
                cmap = ListedColormap([(r / 255, g / 255, b / 255) for r, g, b in PRECIP_RGB_LIST])
                cmap.set_under("none", alpha=0)
                norm = BoundaryNorm(PRECIP_LEVELS, cmap.N, clip=True)
                ax.contourf(mx, my, data, levels=PRECIP_LEVELS, cmap=cmap, norm=norm, extend="max", antialiased=True)
            else:
                cmap, norm = build_cmap(variable)
                ax.contourf(mx, my, data, levels=variable["levels"], cmap=cmap, norm=norm, extend="both", antialiased=True)
                if variable.get("values"):
                    # Label with the base layer's style (add_value_labels chooses it by key).
                    add_value_labels(ax, mx, my, data, minx, miny, maxx, maxy, z, {**variable, "key": base_key, "values": True})

            overlay = field.get("overlayField")
            if overlay is not None:
                omx, omy = to_merc.transform(overlay["lons"], overlay["lats"])
                cs = ax.contour(omx, omy, overlay["values"], levels=PRESSURE_LEVELS, colors="white", linewidths=0.65, antialiased=True)
                labels = ax.clabel(cs, fmt="%.0f", fontsize=9, inline=True, colors="white", inline_spacing=3, manual=False)
                # Stroke only the contour labels; value labels keep their own style.
                for txt in labels:
                    txt.set_path_effects([pe.withStroke(linewidth=2.2, foreground="black")])

        elif render_mode == "isobars":
            cs = ax.contour(
                mx, my, data,
                levels=variable["levels"],
                colors="white",
                linewidths=0.65,
                antialiased=True
            )
            labels = ax.clabel(
                cs,
                fmt="%.0f",
                fontsize=9,
                inline=True,
                colors="white",
                inline_spacing=3,
                manual=False
            )
            for txt in labels:
                txt.set_path_effects([pe.withStroke(linewidth=2.2, foreground="black")])

        else:
            cmap, norm = build_cmap(variable)

            if render_mode == "raster":
                ax.pcolormesh(
                    mx, my, data,
                    cmap=cmap,
                    norm=norm,
                    shading="gouraud",
                    edgecolors="none",
                    linewidth=0,
                    antialiased=True,
                    rasterized=True
                )

            elif render_mode == "original_precip":
                # EXACT original rain/precipitation tile style from ukv_multi_variable_tiles_webmap.py:
                # plain contourf, same levels/cmap/norm/extend behaviour, antialiased=True,
                # and no pcolormesh, no collection edge edits, no seam suppression.
                ax.contourf(
                    mx, my, data,
                    levels=variable["levels"],
                    cmap=cmap,
                    norm=norm,
                    extend=variable.get("extend", "max"),
                    antialiased=True
                )

            else:
                ax.contourf(
                    mx, my, data,
                    levels=variable["levels"],
                    cmap=cmap,
                    norm=norm,
                    extend="both",
                    antialiased=True
                )

            if variable.get("arrows"):
                add_wind_arrows(ax, field, mx, my, minx, miny, maxx, maxy)

            if variable.get("values"):
                add_value_labels(ax, mx, my, data, minx, miny, maxx, maxy, z, variable)

    except Exception as e:
        plt.close(fig)
        print(f"⚠ Failed tile {variable['key']} T+{field['forecastTime']} z={z} x={x_tile} y={y_tile}: {e}")
        return False

    # Close the figure even if creating the folder or saving fails.
    try:
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=TILE_DPI, transparent=True, pad_inches=0, facecolor=(0, 0, 0, 0))
    finally:
        plt.close(fig)
    return True



def _fast_precip_contourf(ax, mx, my, data):
    """Fill ``data`` with the dedicated precipitation palette (PRECIP_RGB_LIST on PRECIP_LEVELS)."""
    cmap = ListedColormap([(r / 255, g / 255, b / 255) for r, g, b in PRECIP_RGB_LIST])
    # Under-range colour is fully transparent.
    cmap.set_under("none", alpha=0)
    norm = BoundaryNorm(PRECIP_LEVELS, cmap.N, clip=True)
    ax.contourf(
        mx, my, data,
        levels=PRECIP_LEVELS,
        cmap=cmap,
        norm=norm,
        extend="max",
        antialiased=True
    )


def _fast_labelled_contours(ax, mx, my, data, levels):
    """Draw white contours of ``data`` at ``levels`` with black-stroked integer labels.

    Only the Text objects returned by ``clabel`` are stroked, so any value
    labels already on the axes keep their own styling.
    """
    cs = ax.contour(
        mx, my, data,
        levels=levels,
        colors="white",
        linewidths=0.65,
        antialiased=True
    )
    labels = ax.clabel(
        cs,
        fmt="%.0f",
        fontsize=9,
        inline=True,
        colors="white",
        inline_spacing=3,
        manual=False
    )
    for txt in labels:
        txt.set_path_effects([pe.withStroke(linewidth=2.2, foreground="black")])


def _fast_postprocess_image(arr, key):
    """Brighten and, for precipitation layers, de-halo an RGBA uint8 array in place.

    Returns the processed array as a PIL RGBA image.
    """
    # Slightly brighten overlay colours while preserving alpha/transparency.
    if OVERLAY_BRIGHTNESS != 1.0:
        rgb = arr[:, :, :3].astype(np.float32)
        visible = arr[:, :, 3] > 0
        rgb[visible] = np.clip(rgb[visible] * OVERLAY_BRIGHTNESS, 0, 255)
        arr[:, :, :3] = rgb.astype(np.uint8)

    # Precipitation halo cleanup, kept from the previous good version.
    if key in {"precip_rate", "precip_rate_pressure"}:
        halo_alpha_cutoff = 42
        faint = arr[:, :, 3] < halo_alpha_cutoff
        arr[faint, 3] = 0
        semi = (arr[:, :, 3] >= halo_alpha_cutoff) & (arr[:, :, 3] < 230)
        arr[semi, 3] = np.minimum(255, (arr[semi, 3].astype(np.float32) * 1.18)).astype(np.uint8)

    return Image.fromarray(arr, "RGBA")


def render_zoom_tiles_fast(variable, field, z):
    """
    Fast path: render the whole variable/hour/zoom image once, then crop into XYZ tiles.

    This avoids doing contourf/text rendering separately for every tile, which is the
    main reason CPU usage looked low and processing took so long.

    The canvas covers every tile that intersects ``bounds`` at zoom ``z``.
    Render modes follow ``render_tile``: ``precip_rate`` uses the dedicated
    precip palette, and ``overlay``, ``isobars``, ``raster``, ``original_precip``
    (honouring ``variable["extend"]``, default "max") and default filled
    contours are handled as there.

    The image is then brightened by OVERLAY_BRIGHTNESS, and precipitation
    layers get the alpha halo clean-up. Each non-empty 256px crop is written
    to ``tiles_directory/<key>/<hour>/<z>/<x>/<y>.png``.

    Returns the number of tiles written. Returns 0 when all tiles were skipped,
    the field is out of range, or rendering failed (the failure is printed).
    """
    key = variable["key"]
    hour = field["forecastTime"]

    x_min, y_max = lonlat_to_tile(bounds["west"], bounds["south"], z)
    x_max, y_min = lonlat_to_tile(bounds["east"], bounds["north"], z)

    tile_paths = [
        tiles_directory / key / f"{hour}" / f"{z}" / f"{x}" / f"{y}.png"
        for x in range(x_min, x_max + 1)
        for y in range(y_min, y_max + 1)
    ]

    if SKIP_EXISTING_TILES and tile_paths and all(p.exists() and p.stat().st_size > 0 for p in tile_paths):
        return 0

    nx_tiles = x_max - x_min + 1
    ny_tiles = y_max - y_min + 1
    width_px = nx_tiles * TILE_SIZE
    height_px = ny_tiles * TILE_SIZE

    # Mercator extent of the whole tile block (top-left tile to bottom-right tile).
    minx_all, _, _, maxy_all = tile_bounds_merc(x_min, y_min, z)
    _, miny_all, maxx_all, _ = tile_bounds_merc(x_max, y_max, z)

    lats = field["lats"]
    lons = field["lons"]
    data = field["values"]
    mx, my = to_merc.transform(lons, lats)

    if (
        np.nanmax(mx) < minx_all - TILE_PAD_M or np.nanmin(mx) > maxx_all + TILE_PAD_M or
        np.nanmax(my) < miny_all - TILE_PAD_M or np.nanmin(my) > maxy_all + TILE_PAD_M
    ):
        return 0

    fig = plt.figure(figsize=(width_px / TILE_DPI, height_px / TILE_DPI), dpi=TILE_DPI)
    ax = fig.add_axes([0, 0, 1, 1])
    ax.set_xlim(minx_all, maxx_all)
    ax.set_ylim(miny_all, maxy_all)
    ax.axis("off")
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    try:
        render_mode = variable.get("render", "filled_contour")

        if key == "precip_rate":
            # Keep the dedicated first-file precipitation look.
            _fast_precip_contourf(ax, mx, my, data)

        elif render_mode == "overlay":
            base_key = field.get("baseKey")
            if base_key == "precip_rate":
                _fast_precip_contourf(ax, mx, my, data)
            else:
                cmap, norm = build_cmap(variable)
                ax.contourf(
                    mx, my, data,
                    levels=variable["levels"],
                    cmap=cmap,
                    norm=norm,
                    extend="both",
                    antialiased=True
                )
                if variable.get("values"):
                    # Label with the base layer's style (add_value_labels chooses it by key).
                    add_value_labels(ax, mx, my, data, minx_all, miny_all, maxx_all, maxy_all, z, {**variable, "key": base_key, "values": True})

            overlay = field.get("overlayField")
            if overlay is not None:
                omx, omy = to_merc.transform(overlay["lons"], overlay["lats"])
                _fast_labelled_contours(ax, omx, omy, overlay["values"], PRESSURE_LEVELS)

        elif render_mode == "isobars":
            _fast_labelled_contours(ax, mx, my, data, variable["levels"])

        else:
            cmap, norm = build_cmap(variable)

            if render_mode == "raster":
                ax.pcolormesh(
                    mx, my, data,
                    cmap=cmap,
                    norm=norm,
                    shading="gouraud",
                    edgecolors="none",
                    linewidth=0,
                    antialiased=True,
                    rasterized=True
                )
            else:
                # "original_precip" keeps its configured extend (as in render_tile);
                # every other filled layer extends at both ends.
                extend = variable.get("extend", "max") if render_mode == "original_precip" else "both"
                ax.contourf(
                    mx, my, data,
                    levels=variable["levels"],
                    cmap=cmap,
                    norm=norm,
                    extend=extend,
                    antialiased=True
                )

            if variable.get("arrows"):
                add_wind_arrows(ax, field, mx, my, minx_all, miny_all, maxx_all, maxy_all)

            if variable.get("values"):
                add_value_labels(ax, mx, my, data, minx_all, miny_all, maxx_all, maxy_all, z, variable)

        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=TILE_DPI, transparent=True, pad_inches=0, facecolor=(0, 0, 0, 0))
        plt.close(fig)
        buf.seek(0)

        img = _fast_postprocess_image(np.array(Image.open(buf).convert("RGBA")), key)

        written = 0
        for x in range(x_min, x_max + 1):
            for y in range(y_min, y_max + 1):
                out_path = tiles_directory / key / f"{hour}" / f"{z}" / f"{x}" / f"{y}.png"

                if SKIP_EXISTING_TILES and out_path.exists() and out_path.stat().st_size > 0:
                    continue

                left = (x - x_min) * TILE_SIZE
                upper = (y - y_min) * TILE_SIZE
                tile = img.crop((left, upper, left + TILE_SIZE, upper + TILE_SIZE))

                # Avoid saving fully transparent blank tiles.
                if not tile.getbbox():
                    continue

                out_path.parent.mkdir(parents=True, exist_ok=True)
                tile.save(out_path)
                written += 1

        return written

    except Exception as e:
        plt.close(fig)
        print(f"⚠ Failed FAST render {variable['key']} T+{hour} z={z}: {e}", flush=True)
        return 0


def clean_legend_levels(levels, max_decimals=2):
    """Round legend values so the HTML legend shows short, faithful labels.

    Whole numbers become ``int``. Other values are rounded to the fewest
    decimals (1 up to ``max_decimals``) that reproduce them within 1e-9. This
    strips float noise such as ``0.30000000000000004`` -> ``0.3`` while keeping
    real levels like ``0.25`` or ``0.05`` intact. Always rounding to one
    decimal would turn them into ``0.2`` / ``0.1``, so the legend would
    disagree with the drawn contours. Values needing more precision are
    rounded to ``max_decimals``.

    Args:
        levels: iterable of numbers convertible with ``float``.
        max_decimals: most decimals kept for non-integers; values below 1 are
            treated as 1. ``max_decimals=1`` reproduces the old behaviour.

    Returns:
        list of ``int`` / ``float`` in the input order.
    """
    max_decimals = max(1, int(max_decimals))
    cleaned = []
    for raw in levels:
        value = float(raw)
        if abs(value - round(value)) < 1e-9:
            cleaned.append(int(round(value)))
            continue
        for decimals in range(1, max_decimals + 1):
            rounded = round(value, decimals)
            if abs(rounded - value) < 1e-9:
                break
        cleaned.append(rounded)
    return cleaned


# Render every available layer hour by hour, collecting the per-layer manifest
# entries that the viewer reads.
metadata_variables = []
valid_times_global = {}
total_tiles = 0

for variable_index, variable in enumerate(available_variables, start=1):
    key = variable["key"]
    hours = sorted(data_by_variable[key])

    variable_valid_times = {
        str(hour): data_by_variable[key][hour].get("validTime")
        for hour in hours
        if data_by_variable[key][hour].get("validTime")
    }

    # First layer to report a valid time for an hour wins in the global map.
    for hour, vt in variable_valid_times.items():
        valid_times_global.setdefault(hour, vt)

    metadata_variables.append({
        "key": key,
        "label": variable["label"],
        "units": variable["units"],
        "hours": hours,
        "validTimes": variable_valid_times,
        "opacity": variable.get("opacity", 0.85),
        "levels": clean_legend_levels(variable["levels"]),
        "colors": variable["colors"],
    })

    print(f"🧱 Rendering {variable_index}/{len(available_variables)}: {variable['label']}", flush=True)

    # NOTE(review): max_workers is not used below; tiles render sequentially.
    max_workers = get_max_workers()

    for hour in hours:
        field = data_by_variable[key][hour]
        print(f"   ⏱ {variable['label']} T+{hour:02d}", flush=True)

        made_this_hour = 0

        # Fast renderer: render once per zoom, then crop tiles.
        for z in range(MIN_ZOOM, MAX_ZOOM + 1):
            written = render_zoom_tiles_fast(variable, field, z)
            total_tiles += written
            made_this_hour += written
            print(
                f"      z{z} complete ({written} tiles written)",
                flush=True
            )

        print(
            f"      hour complete ({made_this_hour} tiles written)",
            flush=True
        )

    print(f"✅ Finished {variable['label']}", flush=True)

# render_zoom_tiles_fast prints and swallows rendering errors, returning 0.
# If nothing at all was written, publishing would switch the live viewer to an
# empty run, so stop here and keep the previous completed run live.
if total_tiles == 0:
    shutil.rmtree(run_build_dir, ignore_errors=True)
    raise SystemExit(
        f"❌ No tiles were written; not publishing {run_public_dir.name}. "
        "The previous run remains live."
    )

metadata = {
    "minZoom": MIN_ZOOM,
    "maxZoom": MAX_ZOOM,
    "defaultZoom": DEFAULT_ZOOM,
    "bounds": [[bounds["south"], bounds["west"]], [bounds["north"], bounds["east"]]],
    "center": center,
    "runId": run_public_dir.name,
    "runTime": run_time_iso,
    "runLabel": run_label,
    "tilesBaseUrl": tiles_base_url,
    "completedAt": datetime.now(timezone.utc).isoformat(),
    "validTimes": valid_times_global,
    "variables": metadata_variables,
}
metadata_json = json.dumps(metadata, indent=2)

# Keep a copy of the manifest inside the completed run folder, then move the
# staging folder to its public name.
(run_build_dir / "metadata.json").write_text(metadata_json, encoding="utf-8")
run_build_dir.replace(run_public_dir)

# Finally update the root manifest atomically. Until this replace happens, the
# viewer keeps using the previous completed run.
metadata_path = output_directory / "metadata.json"
metadata_tmp_path = output_directory / "metadata.json.tmp"
metadata_tmp_path.write_text(metadata_json, encoding="utf-8")
metadata_tmp_path.replace(metadata_path)

print(f"✅ Tile generation complete: {total_tiles:,} tiles")
print(f"✅ Published completed run: {run_public_dir}")
print(f"✅ Wrote metadata: {metadata_path}")
print(f"▶ Tiles written to: {run_public_dir / 'tiles'}")
print(f"▶ Metadata written to: {metadata_path}")
