"""Data acquisition for DWD global monthly air temperature.

Downloads the DWD CLIMAT station list and per-station monthly mean
temperature series, aggregates them to annual country means, and writes
a shapefile with the 1991-2018 vs. 1961-1990 temperature anomaly.

Author: Peter Morstein
"""

import pickle
import urllib.request
from ftplib import FTP

import numpy as np
import pandas as pd
import geopandas as gpd

stationURL = "https://opendata.dwd.de/climate_environment/CDC/help/stations_list_CLIMAT_data.txt"
dwdFtpServer = "opendata.dwd.de"
dwdFtpUri = "/climate_environment/CDC/observations_global/CLIMAT/monthly/qc/air_temperature_mean/historical/"
countryAnnualTemp = pd.DataFrame([])

# DWD country labels -> country names used by the geopandas
# naturalearth_lowres dataset, so the two tables can be merged.
# BUG FIX (review): the original mapped "Korea, Dem. People's Rep."
# (= North Korea) to 'South Korea'; corrected to 'North Korea'.
# A second, dead rule "Mauretania" -> 'Somalia' (the label was already
# renamed to 'Mauritania' by an earlier rule) was removed, as was a
# duplicate Slovakia rule.  The United Kingdom key had a trailing blank
# that could never match the already-stripped country column.
# NOTE(review): "Somalia" -> 'Somaliland' looks intentional (both names
# exist in naturalearth_lowres) but is worth confirming.
_COUNTRY_RENAMES = {
    "Korea, Dem. People's Rep.": "North Korea",
    "Slovakia (Slovak. Rep.)": "Slovakia",
    "Slowenia": "Slovenia",
    "Russian Federation": "Russia",
    "Bosnia and Herzegowina": "Bosnia and Herz.",
    "Croatia/Hrvatska": "Croatia",
    "Moldova, Rep. Of": "Moldova",
    "United Kingdom of Great Britain and N.-Ireland": "United Kingdom",
    "Czech Republic": "Czechia",
    "Somalia": "Somaliland",
    "Iran (Islamic Rep. of)": "Iran",
    "Mauretania": "Mauritania",
    "Central African Republic": "Central African Rep.",
    "South Sudan": "S. Sudan",
    "Dem. Republic of the Congo": "Dem. Rep. Congo",
    "Syrian Arab Rep.": "Syria",
    "Australien, SW-Pazifik": "Australia",
    "Western-Sahara": "W. Sahara",
}


def loadDWDGauges():
    """Download the DWD station list and join it with the FTP listing
    of historical monthly temperature files.

    Side effect: (re)binds the module-level ``stationList`` DataFrame
    (indexed by station id, with lon/lat/country/file columns).
    """
    global stationList

    # Load station list from DWD.
    # BUG FIX: the original passed encoding="ISO-8859-1 " (trailing
    # blank), which is not a valid codec name.
    stationList = pd.read_csv(
        stationURL, delimiter=";", skiprows=0, usecols=[0, 2, 3, 5],
        names=["id", "lon", "lat", "country"], header=0,
        encoding="ISO-8859-1")
    stationList = stationList.dropna(how="any", axis=0)
    stationList["country"] = stationList["country"].str.strip()
    stationList["lon"] = stationList["lon"].str.strip()
    stationList["lat"] = stationList["lat"].str.strip()

    # Rename countries so they merge with the geopandas world shapes.
    stationList["country"] = stationList["country"].replace(_COUNTRY_RENAMES)

    # List the available climate files on the DWD FTP server.
    dwdFTP = FTP(dwdFtpServer)
    try:
        dwdFTP.login()
        dwdFTP.cwd(dwdFtpUri)
        fileNames = dwdFTP.nlst()
    finally:
        dwdFTP.quit()
    fileList = pd.DataFrame({
        "id": [name.split("_")[0] for name in fileNames],
        "file": fileNames,
    })

    # There are multiple time-series files per station containing the
    # same historical values; keep only the last (longest) file per id.
    # (DataFrame.append is deprecated/removed; tail(1) per group keeps
    # the same "last row of each group" semantics.)
    longestSeries = fileList.groupby("id").tail(1)

    # Attach each station's data-file name to the station list.
    stationList = stationList.set_index("id").join(
        longestSeries.set_index("id"))
    stationList = stationList.dropna(axis=0, how="any")
    stationList = stationList[stationList.country != ""]
    # with open("stationList.pickle", "wb") as pf:
    #     pickle.dump(stationList, pf)


def fillMissingData(annualData):
    """Fill NaN monthly values with the mean of temporally adjacent
    values, then (re)compute the annual mean.

    Neighbours considered for a missing month: same month in the
    previous/next year and the previous/next month (wrapping across the
    year boundary into December/January of the adjacent year).

    ``annualData``: one row per year; the German month abbreviations
    must be the first twelve columns (assumed — confirm with caller).
    Returns the mutated DataFrame.
    """
    months = ["Jan", "Feb", "Mrz", "Apr", "Mai", "Jun",
              "Jul", "Aug", "Sep", "Okt", "Nov", "Dez"]
    for y in range(len(annualData)):
        for m in range(len(months)):
            if not np.isnan(annualData.iloc[y].loc[months[m]]):
                continue
            prevYear = y - 1 if y >= 1 else None
            nextYear = y + 1 if y < len(annualData) - 1 else None
            prevMonth = m - 1
            nextMonth = m + 1
            averageList = []
            if prevYear is not None:
                averageList.append(annualData.iloc[prevYear].loc[months[m]])
            if nextYear is not None:
                averageList.append(annualData.iloc[nextYear].loc[months[m]])
            if prevMonth >= 0:
                averageList.append(annualData.iloc[y].loc[months[prevMonth]])
            elif prevYear is not None:
                # January's predecessor is December of the previous year.
                averageList.append(annualData.iloc[prevYear].loc[months[-1]])
            if nextMonth < len(months):
                averageList.append(annualData.iloc[y].loc[months[nextMonth]])
            elif nextYear is not None:
                # December's successor is January of the next year.
                averageList.append(annualData.iloc[nextYear].loc[months[0]])
            annualData.iat[y, m] = np.round(np.nanmean(averageList), 2)
    # BUG FIX: the original averaged iloc[:, 0:11] (Jan..Nov only),
    # silently dropping December from the annual mean.
    annualData["mean"] = np.round(
        annualData.iloc[:, 0:12].mean(axis=1, skipna=True), 2)
    return annualData


def loadTemperatureFromDWDGauges():
    """Download each station's monthly series, store the annual mean per
    year as new (int-year) columns on ``stationList`` and pickle it."""
    global stationList
    global annualData

    with open("./pickle/stationList.pickle", "rb") as pickleFile:
        stationList = pickle.load(pickleFile)

    for index, gaugeCountry in stationList.groupby("country"):
        print(index, ": ", len(gaugeCountry.country))
        gaugeURLs = "https://" + dwdFtpServer + dwdFtpUri + gaugeCountry.file
        for gid, gurl in zip(gaugeCountry.index, gaugeURLs):
            annualData = pd.read_csv(gurl, delimiter=";").set_index("Jahr")
            annualData["mean"] = annualData.mean(axis=1)
            # annualData = fillMissingData(annualData)
            for dataIndex, annualMean in annualData.iterrows():
                try:
                    stationList.at[gid, dataIndex] = annualMean["mean"]
                except Exception:
                    # Best effort: skip years that cannot be written
                    # (original used a bare except for the same purpose).
                    continue
    with open("./pickle/stationList_temps_missingData.pickle",
              "wb") as pickleFile:
        pickle.dump(stationList, pickleFile)


def buildAverageTimeseries(fromYear, toYear, name):
    """Add column ``name`` holding each station's mean temperature over
    the inclusive year range fromYear..toYear.

    Stations with 5 or fewer yearly values in the range get NaN.
    BUG FIX: the original used range(fromYear, toYear), excluding the
    final year although the column names (e.g. "m1961T1990") denote an
    inclusive range.
    """
    global stationList
    meanAverage = []
    for stationID, station in stationList.iterrows():
        temps = [station[year]
                 for year in range(fromYear, toYear + 1)
                 if not np.isnan(station[year])]
        meanAverage.append(np.mean(temps) if len(temps) > 5 else np.NaN)
    stationList[name] = np.round(meanAverage, 1)


def cleanAverageTimeseries():
    """Blank both reference-period averages for stations that do not
    have both, so anomalies always compare like with like."""
    global stationList
    for stationID, station in stationList.iterrows():
        if np.isnan(station["m1961T1990"]) or np.isnan(station["m1991T2018"]):
            stationList.at[stationID, "m1961T1990"] = None
            stationList.at[stationID, "m1991T2018"] = None


def buildStationListGDP():
    """Load the pickled station temperatures and build a GeoDataFrame
    with both reference-period averages.

    Side effects: rebinds module-level ``stationList`` and ``stationGPD``.
    """
    global stationList
    global stationGPD

    with open("./pickle/stationList_temps_missingData.pickle",
              "rb") as pickleFile:
        stationList = pickle.load(pickleFile)

    del stationList["file"]
    stationList = stationList[stationList.lat != ""]
    stationList = stationList[stationList.lon != ""]
    stationList["lat"] = stationList["lat"].astype(str).astype(float)
    stationList["lon"] = stationList["lon"].astype(str).astype(float)

    buildAverageTimeseries(1961, 1990, "m1961T1990")
    buildAverageTimeseries(1991, 2018, "m1991T2018")
    cleanAverageTimeseries()

    # BUG FIX: points_from_xy(x, y) expects x=longitude, y=latitude;
    # the original passed (lat, lon), mirroring every station position.
    stationGPD = gpd.GeoDataFrame(
        stationList,
        geometry=gpd.points_from_xy(stationList.lon, stationList.lat),
    ).reset_index()
    del stationGPD["lat"]
    del stationGPD["lon"]
    stationGPD.columns = stationGPD.columns.astype(str)
    stationGPD = stationGPD.sort_index(axis=1, ascending=False)
    # stationGPD.to_file("stationList.shp", "ESRI Shapefile")


def buildAnnualCountryTemp():
    """Spatially join stations to world countries, average per country
    and write a shapefile including the warming anomaly column."""
    global stationList
    global countryAnnualTemp

    world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
    countryMeanGPD = gpd.sjoin(world, stationGPD, how="inner",
                               op="intersects")
    countryMeanGPD = countryMeanGPD.groupby("name").mean().reset_index()
    countryMeanGPD["anomalie"] = (countryMeanGPD["m1991T2018"]
                                  - countryMeanGPD["m1961T1990"])
    del countryMeanGPD["pop_est"]
    del countryMeanGPD["gdp_md_est"]
    # Drop the sparse early years from the output.
    for year in range(1873, 1950):
        del countryMeanGPD[str(year)]

    worldGauge = world.set_index("name").join(
        countryMeanGPD.set_index("name"))
    worldGauge.to_file("./output/countryAnnualTemperature.shp",
                       "ESRI Shapefile")


if __name__ == "__main__":
    print("___ DWD Acquisition start___")
    # loadDWDGauges()
    # loadTemperatureFromDWDGauges()
    buildStationListGDP()
    buildAnnualCountryTemp()
    print("___DWD Acquisition finished___")