import configparser
import numpy as np
from psycopg2 import sql
import numpy as np

# Load settings at import time. The relative path '../config.ini' is resolved
# against the current working directory of the process.
cfg = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open, so the section
# checks below are what actually detect a missing or incomplete config file.
cfg.read('../config.ini')
# NOTE: assert statements are removed when Python runs with -O.
assert "POSTGRES" in cfg, "missing POSTGRES in config.ini"
assert "INTERPOLATION" in cfg, "missing INTERPOLATION in config.ini"
# Section proxies used by the functions in this module.
param_postgres = cfg["POSTGRES"]
param_interpol = cfg["INTERPOLATION"]


def get_average_of_multiple_years(cursor, years):
    """Return the per-station mean over an inclusive range of year columns.

    For every station row group the query sums ``AVG("<year>")`` for each
    year in the range and divides by the number of years in the range.

    Args:
        cursor: Database cursor used to execute the query on ``stations``.
        years: Sequence whose first two items are the first and last year;
            both are passed through ``int()`` and both ends are included.

    Returns:
        The result of ``cursor.fetchall()``: rows of
        ``(station_id, rounded mean, transparent)`` ordered by ``station_id``.

    Raises:
        ValueError: If the last year is smaller than the first year.
    """
    first_year = int(years[0])
    last_year = int(years[1])
    if last_year < first_year:
        raise ValueError(
            "years must be (first, last) with first <= last, got {!r}".format(years)
        )

    # The range is inclusive, so it holds (last - first + 1) years. Dividing
    # by (last - first) would skew the mean and divide by zero for one year.
    year_range = range(first_year, last_year + 1)
    # Column names come from int() values only, so they are safe to inline.
    avg_sum = " + ".join('AVG ("{}")'.format(year) for year in year_range)

    query = """SELECT station_id, ROUND(({}) / {}, 1), transparent FROM stations WHERE file IS NULL GROUP BY station_id, transparent ORDER BY station_id ASC;""".format(avg_sum, len(year_range))
    # Echo the generated SQL for debugging.
    print(query)
    cursor.execute(query)
    return cursor.fetchall()


# Getting all available year columns from database
def get_year_columns(cursor):
    """Return the integer-named columns of the public ``stations`` table.

    Column names are read from ``information_schema.columns``; every name
    that ``int()`` accepts is returned as an int, all others are ignored.

    Args:
        cursor: Database cursor used to run the lookup.

    Returns:
        List of ints, in the order the database returned the column names.
    """
    statement = sql.SQL("SELECT column_name FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'stations';")
    cursor.execute(statement)

    year_columns = []
    for row in cursor.fetchall():
        try:
            year = int(row[0])
        except ValueError:
            # Not a year column (e.g. a descriptive column name).
            continue
        year_columns.append(year)
    return year_columns


# Find n (defined in config) neighbours and return them ordered by distance
def get_neighbours(cursor, lat, lon, columns):
    """Fetch the nearest stations to a point as one JSON array.

    Only ``stations`` rows with ``file IS NOT NULL`` are considered. Rows are
    ordered by ``ST_Distance`` to the query point and limited to the
    ``amount_neighbours`` value of the INTERPOLATION config section.

    Args:
        cursor: Database cursor used to execute the query.
        lat: Latitude of the query point (passed as a query parameter).
        lon: Longitude of the query point (passed as a query parameter).
        columns: Column selection inserted into the SELECT list; it is passed
            to ``sql.SQL.format`` and so is presumably a psycopg2 composable.

    Returns:
        First column of the first result row, i.e. the value produced by
        ``array_to_json(array_agg(row_to_json(t)))``; each element carries the
        selected columns plus ``distance``.
    """
    template = sql.SQL("""
                SELECT array_to_json(array_agg(row_to_json(t))) from (
                    SELECT {columns}, ST_Distance(ST_MakePoint(lat, lon), ST_MakePoint({lon}, {lat})) AS distance 
                    FROM stations 
                    WHERE file IS NOT NULL 
                    ORDER BY distance 
                    LIMIT {amount_neighbours}
                ) t;
                    """)
    # The neighbour count comes from config and is inserted as raw SQL.
    neighbour_limit = sql.SQL(param_interpol["amount_neighbours"])
    statement = template.format(
        columns=columns,
        lon=sql.Placeholder(),
        lat=sql.Placeholder(),
        amount_neighbours=neighbour_limit,
    )

    # Both placeholders are positional: the first one in the statement (the
    # "{lon}" field) receives lat and the second receives lon, so the query
    # point is built as (lat, lon), matching ST_MakePoint(lat, lon) on the
    # station side. Passing the values separately avoids SQL injection.
    cursor.execute(statement, (lat, lon))
    first_row = cursor.fetchall()[0]
    return first_row[0]


# Deprecated and unused. Calculating interpolation data just by average. Insufficient statistical method for this use case
def calc_averages(neighbours, years):
    """Return the plain mean of the neighbours' values for each year.

    Args:
        neighbours: Iterable of mappings keyed by ``str(year)``; the string
            ``'NaN'`` marks a missing value and is excluded from the mean.
        years: Iterable of years; each is looked up as ``str(year)``.

    Returns:
        Dict mapping each year (as given) to the mean rounded to 3 decimals.
        Years for which no neighbour has a value are omitted instead of
        raising ``ZeroDivisionError``, matching how ``calc_idw`` skips them.
    """
    averages = {}
    for year in years:
        column = str(year)
        values = [
            neighbour[column]
            for neighbour in neighbours
            if neighbour[column] != 'NaN'
        ]
        if not values:
            # No usable data for this year: skip it rather than divide by zero.
            continue
        averages[year] = round(sum(values) / len(values), 3)
    return averages


# Calculating interpolation data by Inverse Distance Weighted method. Values are decreasingly important with increasing distance
def calc_idw(neighbours, years):
    """Interpolate a value per year using inverse distance weighting.

    For each year the estimate is ``sum(v_i / d_i) / sum(1 / d_i)`` over the
    neighbours that have a value for that year, so closer neighbours count
    more. Weights are normalised over the neighbours actually used, which
    keeps the estimate unbiased when some neighbours have no data.

    Args:
        neighbours: Iterable of mappings with a ``'distance'`` entry and one
            entry per ``str(year)``; ``'NaN'`` or ``None`` marks a missing
            value. ``None`` or an empty iterable yields an empty result.
        years: Iterable of years; each is looked up as ``str(year)``.

    Returns:
        Dict mapping each year (as given) to the estimate rounded to
        3 decimals. Years without any usable value are omitted.
    """
    weighted_values = {}
    if not neighbours:
        return weighted_values

    for year in years:
        column = str(year)
        samples = [
            (neighbour['distance'], neighbour[column])
            for neighbour in neighbours
            if neighbour[column] is not None and neighbour[column] != 'NaN'
        ]
        if not samples:
            # No data for this year among the neighbours.
            continue

        # A neighbour at distance 0 sits on the query point: use its value
        # directly (1 / 0 would otherwise be undefined).
        coincident = [value for distance, value in samples if distance == 0]
        if coincident:
            estimate = sum(coincident) / len(coincident)
        else:
            weights = [1.0 / distance for distance, _ in samples]
            weighted_sum = sum(
                weight * value
                for weight, (_, value) in zip(weights, samples)
            )
            estimate = weighted_sum / sum(weights)

        weighted_values[year] = round(estimate, 3)
    return weighted_values


# Collecting preparation data and execute interpolation
def get_interpolation_data_for_point(lat, lon, columns, cursor):
    """Interpolate yearly values for a point from its nearest stations.

    Args:
        lat: Latitude of the point, forwarded to ``get_neighbours``.
        lon: Longitude of the point, forwarded to ``get_neighbours``.
        columns: Column selection forwarded to ``get_neighbours``. Its
            ``str()`` form decides the year list: if it contains ``'*'`` all
            year columns are read from the database, otherwise the names are
            extracted from the text by stripping ``SQL('``, ``"`` and ``')``
            and splitting on commas.
        cursor: Database cursor used for all queries.

    Returns:
        The dict produced by ``calc_idw`` for the neighbours and year list.
    """
    columns_text = str(columns)
    if '*' in columns_text:
        year_columns = get_year_columns(cursor)
    else:
        # Reduce the textual form of the column selection to bare names.
        names_text = columns_text.replace("""SQL('""", "")
        names_text = names_text.replace('"', '')
        names_text = names_text.replace("')", "")
        year_columns = names_text.split(',')

    nearest = get_neighbours(cursor, lat, lon, columns)
    return calc_idw(nearest, year_columns)


# get_average_data_for_point(52.5, 13.4)

def calcAverageYear(stationList, fromYear, toYear):
    """Add a row-wise mean over a range of year columns and drop empty rows.

    Columns are selected with ``stationList.filter(regex=...)`` using a
    pattern that joins every year from ``fromYear`` to ``toYear`` (inclusive)
    with ``|``. The mean of the selected columns is stored per row in a
    column named ``"anomalie"``.

    Note: the ``"anomalie"`` column is assigned on the passed-in object
    itself; only the row filtering happens on the returned object.

    Args:
        stationList: Object providing ``filter``, item assignment and
            ``dropna`` (presumably a pandas DataFrame with year-named columns).
        fromYear: First year of the range.
        toYear: Last year of the range (included).

    Returns:
        Result of ``dropna`` on ``stationList``: the rows whose ``"anomalie"``
        value is not missing.
    """
    year_labels = np.char.mod('%d', np.arange(fromYear, toYear+1))
    year_pattern = "|".join(year_labels)

    selected_years = stationList.filter(regex=year_pattern)
    stationList["anomalie"] = selected_years.mean(axis=1)

    return stationList.dropna(axis=0, subset=['anomalie'])