import configparser
import math

import numpy as np
import psycopg2
from osgeo import gdal, osr
from GetAverageData import get_min_max
cfg = configparser.ConfigParser()
# NOTE: '../config.ini' is resolved against the current working directory,
# not against the location of this file.
if not cfg.read('../config.ini'):
    # ConfigParser.read silently skips files it cannot open and returns the
    # list of files it did parse, so an empty list means nothing was loaded.
    raise FileNotFoundError("could not read ../config.ini")
# Explicit checks instead of ``assert`` so the validation survives ``python -O``.
if "POSTGRES" not in cfg:
    raise KeyError("missing POSTGRES in config.ini")
if "INTERPOLATION" not in cfg:
    raise KeyError("missing INTERPOLATION in config.ini")
# Connection settings used by get_class_steps.
param_postgres = cfg["POSTGRES"]

# Colour ramp with one [R, G, B, A] entry per class, ordered from white
# (lowest class) to black (highest class). Alpha is stored as 1 for every entry.
ramp = [
    [255, 255, 255, 1],
    [255, 244, 191, 1],
    [255, 233, 128, 1],
    [255, 221, 64, 1],
    [255, 210, 0, 1],
    [243, 105, 0, 1],
    [230, 0, 0, 1],
    [153, 0, 0, 1],
    [77, 0, 0, 1],
    [0, 0, 0, 1],
]


def get_class_steps(colour_ramp):
    """Split the database value range into one equal-width class per colour.

    Opens a PostgreSQL connection using the ``[POSTGRES]`` config section,
    asks ``get_min_max`` for the range and divides it by the number of
    colours in ``colour_ramp``.

    :param colour_ramp: sequence of colours; only its length is used.
    :return: ``(min, steps, classes)`` where ``min`` is ``min_max[0]``,
        ``steps`` is ``(min_max[1] - min_max[0]) / classes`` and ``classes``
        is ``len(colour_ramp)``. Assumes ``get_min_max`` returns
        ``(minimum, maximum)``; verify against GetAverageData.
    :raises ZeroDivisionError: if ``colour_ramp`` is empty.
    """
    connection = psycopg2.connect(
        database=param_postgres["dbName"],
        user=param_postgres["user"],
        password=param_postgres["password"],
        host=param_postgres["host"],
        port=param_postgres["port"],
    )
    try:
        # psycopg2's ``with connection`` only commits/rolls back the
        # transaction; it does NOT close the connection, hence the finally.
        with connection:
            with connection.cursor() as cursor:
                min_max = get_min_max(cursor)
    finally:
        connection.close()

    classes = len(colour_ramp)
    lower = min_max[0]
    steps = (min_max[1] - lower) / classes
    return lower, steps, classes


def colour_picker(min, steps, classes, colour_ramp, value):
    """Return the colour-ramp entry for ``value``.

    The range starting at ``min`` is split into ``classes`` classes of width
    ``steps``. Class ``i`` covers ``[min + i * steps, min + (i + 1) * steps)``,
    and the top edge ``min + classes * steps`` belongs to the last class.

    A tolerance window around the range is still accepted: values from
    ``floor(min)`` up to ``ceil(min + (classes + 1) * steps)`` are clamped
    onto the first/last colour. Values outside that window (and NaN) get the
    fully transparent ``[0, 0, 0, 0]``.

    :param min: lower bound of the range (shadows the builtin ``min``; the
        name is kept for keyword callers).
    :param steps: width of one class. ``<= 0`` is a degenerate range, in
        which case every accepted value maps to the last colour.
    :param classes: number of classes.
    :param colour_ramp: sequence of RGBA entries, one per class.
    :param value: number to classify (int, float or Decimal).
    :return: an entry of ``colour_ramp``, or ``[0, 0, 0, 0]``.
    """
    if not colour_ramp:
        return [0, 0, 0, 0]

    lower = math.floor(min)
    upper = math.ceil(min + (classes + 1) * steps)
    # NaN fails both comparisons and therefore ends up transparent as well.
    if not lower <= value <= upper:
        return [0, 0, 0, 0]

    # Highest usable index: one colour per class, never past the ramp's end.
    n_colours = len(colour_ramp) if len(colour_ramp) < classes else classes
    last = max(n_colours - 1, 0)

    if steps > 0:
        # float() lets Decimal bounds (e.g. from PostgreSQL NUMERIC) mix
        # with float values, which plain subtraction would reject.
        index = math.floor((float(value) - float(min)) / float(steps))
    else:
        index = last

    # Clamp values in the slack window below ``min`` / at or above the top edge.
    if index < 0:
        index = 0
    elif index > last:
        index = last

    rgba = colour_ramp[index]
    return rgba if rgba else [0, 0, 0, 0]


_DEFAULT_RASTER_PATH = (
    'D:/Uni/Master/01_SS2021/Automatisierte_Geodatenprozessierung/'
    'temperaturverteilung/dataacquisition/output/myraster.tif'
)


def _pixel_rgba(record, minimum, steps, classes):
    """Return the ``(r, g, b, a)`` pixel for one ``data`` record.

    ``record[1]`` is the value to colour and a truthy ``record[2]`` forces
    the pixel transparent. Empty/zero values and NaN become ``(0, 0, 0, 0)``.
    """
    value = record[1]
    # str() check catches Decimal('NaN') ('NaN') and float nan ('nan').
    if not value or str(value).lower() == 'nan':
        r, g, b, a = 0, 0, 0, 0
    else:
        rgba = colour_picker(minimum, steps, classes, ramp, value)
        r, g, b, a = rgba[0], rgba[1], rgba[2], rgba[3]
    if record[2]:
        a = 0
    # Ramp entries store alpha as 1; scale it to the 0-255 band range.
    if a == 1:
        a = 255
    return r, g, b, a


def write_raster(data, output_path=_DEFAULT_RASTER_PATH):
    """Render ``data`` as a 4-band (R, G, B, A) GeoTIFF in EPSG:4326.

    ``data`` is laid out column by column, 36 records per column: record
    ``col * 36 + row`` becomes pixel ``(row, col)``. Colours come from
    ``ramp`` via ``colour_picker``, with class bounds from
    ``get_class_steps`` (which queries the database). The raster spans the
    hard-coded box lon 5.01-14.81, lat 47.15-55.33, row 0 at the top.

    :param data: indexable sequence of records; see ``_pixel_rgba``.
    :param output_path: target file; defaults to the original hard-coded path.
    :raises ValueError: if ``data`` is empty or not a multiple of 36 records.
    :raises RuntimeError: if GDAL cannot create the output file.
    """
    n_rows = 36
    if len(data) == 0 or len(data) % n_rows:
        raise ValueError(
            f"expected a non-empty multiple of {n_rows} records, got {len(data)}")
    n_cols = len(data) // n_rows

    minimum, steps, classes = get_class_steps(ramp)

    # One (4, rows, cols) array instead of per-band Python lists.
    bands = np.zeros((4, n_rows, n_cols), dtype=np.float32)
    for index, record in enumerate(data):
        col, row = divmod(index, n_rows)
        bands[:, row, col] = _pixel_rgba(record, minimum, steps, classes)

    xmin, ymin, xmax, ymax = 5.01, 47.15, 14.81, 55.33
    xres = (xmax - xmin) / n_cols
    yres = (ymax - ymin) / n_rows
    # Origin at the top-left corner with a negative row step (north-up).
    geotransform = (xmin, xres, 0, ymax, 0, -yres)

    output_raster = gdal.GetDriverByName('GTiff').Create(
        output_path, n_cols, n_rows, 4, gdal.GDT_Float32)
    # Without gdal.UseExceptions(), Create signals failure by returning None.
    if output_raster is None:
        raise RuntimeError(f"could not create raster {output_path}")
    try:
        output_raster.SetGeoTransform(geotransform)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        output_raster.SetProjection(srs.ExportToWkt())
        for band_number, band in enumerate(bands, start=1):
            output_raster.GetRasterBand(band_number).WriteArray(band)
        output_raster.FlushCache()
    finally:
        # Dropping the last reference closes the dataset in the GDAL bindings.
        output_raster = None