diff --git a/api/GetAverageData.py b/api/GetAverageData.py
index 10d642708878b6479b73d2a9a1d2e9c8e7e94e7d..0a8542d1b09e56947a6445c74d76e57a73903bb6 100644
--- a/api/GetAverageData.py
+++ b/api/GetAverageData.py
@@ -61,6 +61,13 @@ def find_extremum(cursor, years, extremum):
     return result
 
 
+def get_min_max(cursor):
+    years = get_year_columns(cursor)
+    min = find_extremum(cursor, years, 'MIN')
+    max = find_extremum(cursor, years, 'MAX')
+    return min, max
+
+
 # Find n (defined in config) neighbours and return them ordered by distance
 def get_neighbours(cursor, lat, lon, columns):
     values = ''  # Used in second parameter of cursor.execute() (Avoids SQL injection)
diff --git a/api/api.py b/api/api.py
index 11a84d28a5bf9a516daa6852cc6402117e73cf37..399060416505d256d5ec092f64f32400463c0422 100644
--- a/api/api.py
+++ b/api/api.py
@@ -3,8 +3,8 @@ from flask import Flask, jsonify, request
 from flask_cors import cross_origin
 from psycopg2 import sql
 import configparser
-from GetAverageData import get_interpolation_data_for_point, get_year_columns, get_average_of_multiple_years, find_extremum
-from api import write_raster
+from GetAverageData import get_interpolation_data_for_point, get_year_columns, get_average_of_multiple_years, get_min_max
+from write_raster import write_raster
 import SQLPandasTools as s2pTool
 
 cfg = configparser.ConfigParser()
@@ -97,13 +97,10 @@ def getStandardQuery():
 
 
 @app.route('/minmax', methods=['GET'])
-def get_min_max():
+def return_min_max():
     with psycopg2.connect(database=param_postgres["dbName"], user=param_postgres["user"], password=param_postgres["password"], host=param_postgres["host"], port=param_postgres["port"]) as connection:
         with connection.cursor() as cursor:
-            years = get_year_columns(cursor)
-            min = find_extremum(cursor, years, 'MIN')
-            max = find_extremum(cursor, years, 'MAX')
-            print(min, max)
+            min, max = get_min_max(cursor)
             return {'min': str(min), 'max': str(max)}
 
 
@@ -117,7 +114,7 @@ def get_raster():
             with connection.cursor() as cursor:
                 # all_years = get_year_columns(cursor)
                 average_data = get_average_of_multiple_years(cursor, years)
-                write_raster.write_raster(average_data)
+                write_raster(average_data)
         # return send_from_directory('D:/Uni/Master/01_SS2021/Automatisierte_Geodatenprozessierung/temperaturverteilung/dataacquisition/output', filename='myraster.tif', as_attachment=True)
         return 'Läuft, Brudi'
 
diff --git a/api/write_raster.py b/api/write_raster.py
index c8d62ca67c9fb2146affd7dc7664dfc76fd4360b..99770a2d911b91ec929ec0d441508c523e62bb04 100644
--- a/api/write_raster.py
+++ b/api/write_raster.py
@@ -1,31 +1,104 @@
+import configparser
+import math
+
 import numpy as np
+import psycopg2
 from osgeo import gdal, osr
+from GetAverageData import get_min_max
+cfg = configparser.ConfigParser()
+cfg.read('../config.ini')
+assert "POSTGRES" in cfg, "missing POSTGRES in config.ini"
+assert "INTERPOLATION" in cfg, "missing INTERPOLATION in config.ini"
+param_postgres = cfg["POSTGRES"]
+
+ramp = [[255,255,255,1],[255,244,191,1],[255,233,128,1],[255,221,64,1],[255,210,0,1],[243,105,0,1],[230,0,0,1],[153,0,0,1],[77,0,0,1],[0,0,0,1]]
+
+
+def get_class_steps(colour_ramp):
+    with psycopg2.connect(database=param_postgres["dbName"], user=param_postgres["user"], password=param_postgres["password"], host=param_postgres["host"], port=param_postgres["port"]) as connection:
+        with connection.cursor() as cursor:
+            min_max = get_min_max(cursor)
+            classes = len(colour_ramp)
+            temp_range = min_max[1] - min_max[0]
+            steps = temp_range / classes
+            min = min_max[0]
+            # print(min_max)
+            return min, steps, classes
+
+
+def colour_picker(min, steps, classes, colour_ramp, value):
+    # print(min, steps, classes, value)
+    rgba = None
+    for i in range(0, classes + 1):
+
+        minor = math.floor(min + (i * steps))
+        major = math.ceil(min + ((i + 1) * steps))
+        # print('k:', minor, 'wert:', value, 'g:', major)
+        if minor <= value <= major:
+
+            try:
+                rgba = colour_ramp[i]
+            except IndexError:
+                rgba = colour_ramp[-1]
+            # print(i)
+    # print('ramp:', rgba)
+    if not rgba:
+        rgba = [0, 0, 0, 0]
+    return rgba
 
 
 def write_raster(data):
-    pixel_array = []
-    pixel_array_alpha = []
+    min, steps, classes = get_class_steps(ramp)
+
+    pixel_array_r = []
+    pixel_array_g = []
+    pixel_array_b = []
+    pixel_array_a = []
     for j in range(0, 36):
-        row_array = []
-        row_array_alpha = []
+        row_array_r = []
+        row_array_g = []
+        row_array_b = []
+        row_array_a = []
         for i, station_id in enumerate(data):
             if i % 36 == 0:
                 value = data[i + j][1]
                 value = 0 if not value else value
                 value = 0 if str(value) == 'NaN' else value
-                alpha_value = data[i + j][2]
-                alpha_value = 0 if alpha_value else 255
-                row_array.append(value)
-                row_array_alpha.append(alpha_value)
-                np_row_array = np.array(row_array)
-                np_row_array_alpha = np.array(row_array_alpha)
-        pixel_array.append(np_row_array)
-        pixel_array_alpha.append(np_row_array_alpha)
-    np_pixel_array = np.array(pixel_array)
-    np_pixel_array_alpha = np.array(pixel_array_alpha)
+                if not value == 0:
+                    rgba = colour_picker(min, steps, classes, ramp, value)
+                    # print(rgba)
+                    r, g, b, a = rgba[0], rgba[1], rgba[2], rgba[3]
+                else:
+                    r, g, b, a = 0, 0, 0, 0
+                # print('r', r, 'g', g, 'b', b, 'a', a)
+                transparent = data[i + j][2]
+                print(transparent)
+                a = 0 if transparent else a
+                a = 255 if a == 1 else a
+                row_array_r.append(r)
+                row_array_g.append(g)
+                row_array_b.append(b)
+                row_array_a.append(a)
+                np_row_array_r = np.array(row_array_r)
+                np_row_array_g = np.array(row_array_g)
+                np_row_array_b = np.array(row_array_b)
+                np_row_array_a = np.array(row_array_a)
+        pixel_array_r.append(np_row_array_r)
+        pixel_array_g.append(np_row_array_g)
+        pixel_array_b.append(np_row_array_b)
+        pixel_array_a.append(np_row_array_a)
+    np_pixel_array_r = np.array(pixel_array_r)
+    np_pixel_array_g = np.array(pixel_array_g)
+    np_pixel_array_b = np.array(pixel_array_b)
+    np_pixel_array_a = np.array(pixel_array_a)
+
+    r_band = np_pixel_array_r
+    g_band = np_pixel_array_g
+    b_band = np_pixel_array_b
+    a_band = np_pixel_array_a
 
     xmin, ymin, xmax, ymax = [5.01, 47.15, 14.81, 55.33]
-    nrows, ncols = np.shape(np_pixel_array)
+    nrows, ncols = np.shape(r_band)
     xres = (xmax - xmin) / float(ncols)
     yres = (ymax - ymin) / float(nrows)
     geotransform = (xmin, xres, 0, ymax, 0, -yres)
@@ -35,9 +108,11 @@ def write_raster(data):
     srs = osr.SpatialReference()
     srs.ImportFromEPSG(4326)
     output_raster.SetProjection(srs.ExportToWkt())
-    output_raster.GetRasterBand(1).WriteArray(np_pixel_array)
-    # output_raster.GetRasterBand(2).WriteArray(np_pixel_array)
-    # output_raster.GetRasterBand(3).WriteArray(np_pixel_array)
-    output_raster.GetRasterBand(4).WriteArray(np_pixel_array_alpha)
+    output_raster.GetRasterBand(1).WriteArray(r_band)
+    output_raster.GetRasterBand(2).WriteArray(g_band)
+    output_raster.GetRasterBand(3).WriteArray(b_band)
+    output_raster.GetRasterBand(4).WriteArray(a_band)
 
     output_raster.FlushCache()
+
+