#!/usr/bin/env python3
"""
@author: Juraj Jonak
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from scipy.interpolate import griddata
import os


class DynamicSpectrum:
    """Build and save a dynamic spectrum: flux vs. wavelength/velocity vs. time or phase.

    Observations are registered with add_LST().  plot_ds() orders them (by
    date, or by orbital phase when a period is set), interpolates every
    spectrum onto a common x grid and saves a pcolormesh plot to outname.
    """

    def __init__(self, outname='dynamic_spectrum.png'):
        """Create an empty dynamic spectrum that will be saved to outname."""
        self.outname = outname
        # Raw observations in the order they were added.  order() never
        # modifies these, so ordering/phasing can be repeated safely.
        self._added_files = []
        self._added_dates = []
        # Public views: insertion order until order() runs, then sorted.
        self.files = np.array([], dtype=str)
        self.dates = np.array([], dtype=float)
        self.init_colormaps()
        self.set_defaults()

    def add_LST(self, name, pref, folder="./", col=-2, headerrows=8):
        """Register the observations listed in an LST file.

        name: LST file name inside folder; pref: prefix of the spectrum file
        names; folder: directory holding the LST and spectrum files; col:
        column of the observation time; headerrows: header lines to skip.
        """
        spectrafiles, hjds = self.read_lst(name, pref, folder, col, headerrows)
        self._added_files.extend(spectrafiles)
        self._added_dates.extend(np.atleast_1d(hjds).tolist())
        self.files = np.array(self._added_files, dtype=str)
        self.dates = np.array(self._added_dates, dtype=float)

    def set_toRV(self, value=True):
        """Use an x grid equidistant in radial velocity (km/s) instead of wavelength."""
        self.to_RV = value

    def set_wlmid(self, value):
        """Set the central wavelength."""
        self.wlmid = value

    def set_xvalrange(self, value):
        """Set the half-width of the x interval around the central line."""
        self.xvalrange = value

    def set_xvalstep(self, value):
        """Set the x step the spectra are interpolated to."""
        self.xvalstep = value

    def set_nrlevels(self, value):
        """Set the number of colour levels."""
        self.nrlevels = value

    def set_period(self, value, digits=-1):
        """Set the orbital period (falsy disables phasing); digits > 0 prints it on the plot."""
        self.period = value
        self.perioddigits = digits

    def set_T0(self, value, digits=-1):
        """Set the time of phase 0; values >= 2.4e6 are reduced by 2.4e6."""
        self.T0 = value if value < 2.4e+6 else value - 2.4e+06
        self.t0digits = digits

    def set_cmap(self, value):
        """Select a colormap by name; print a hint if it is not registered in self.colormaps."""
        self.cmap = value
        if value not in self.colormaps:
            print("{0} is not implemented. You may add it using "
                  "add_cmap('{0}', 'Sequential' or 'Diverging')".format(value))

    def add_cmap(self, cmap, cmclass):
        """Register colormap cmap as class cmclass ('Sequential'/'Diverging') and select it."""
        self.colormaps.update({cmap: cmclass})
        self.set_cmap(cmap)

    def set_title(self, value):
        """Set the figure title - use a raw string for greek letters."""
        self.title = value

    def set_emptyColor(self, value):
        """Set the colour used for empty (NaN) space."""
        self.emptyColor = value

    def set_zeroPhaseColor(self, value):
        """Set the colour of the line marking phase 0."""
        self.zeroPhaseColor = value

    def set_midpoint(self, value):
        """Set the centre value used by diverging colormaps."""
        self.midpoint = value

    def set_defaults(self):
        """Apply default settings (around H alpha)."""
        self.set_toRV(False)
        self.set_wlmid(6562.81)
        self.set_xvalrange(10)
        self.set_xvalstep(.1)
        self.set_nrlevels(25)
        self.set_period(False)
        self.set_T0(0.)
        self.set_cmap("Reds")
        self.set_emptyColor("k")
        self.set_zeroPhaseColor("k")
        self.set_midpoint(1)
        self.set_title(None)

    def init_colormaps(self):
        """Register the initially supported colormaps and their class."""
        self.colormaps = {
            "Reds": "Sequential", "Reds_r": "Sequential",
            "Blues": "Sequential", "Blues_r": "Sequential",
            "Greys": "Sequential", "Greys_r": "Sequential",
            "RdBu": "Diverging", "RdBu_r": "Diverging",
        }

    def order(self):
        """Sort the registered observations by time, or by phase if a period is set.

        Sets self.files and self.dates (phases when a period is set).  It
        always starts from the raw observations, so repeated calls give the
        same result.
        """
        raw_dates = np.asarray(self._added_dates, dtype=float)
        times = self.calc_phase(raw_dates) if self.period else raw_dates
        idx = np.argsort(times)
        self.files = np.asarray(self._added_files, dtype=str)[idx]
        self.dates = times[idx]

    def get_fluxes(self):
        """Interpolate every spectrum in self.files onto a common x grid.

        Builds self.xvals (velocities around 0 if to_RV, else wavelengths
        around wlmid) and stores the spectra as rows of the 2-D array
        self.fluxes, in the order of self.files.
        """
        # An integer step count avoids np.arange's unreliable end point with
        # a non-integer step; the tiny tolerance absorbs rounding in the ratio.
        nsteps = int(np.floor(2. * self.xvalrange / self.xvalstep + 1.e-9))
        self.xvals = -self.xvalrange + self.xvalstep * np.arange(nsteps + 1)
        if not self.to_RV:
            self.xvals = self.xvals + self.wlmid
        fluxes = np.empty((np.size(self.files), np.size(self.xvals)))
        for i, fname in enumerate(self.files):
            fluxes[i] = self.read_spectrafile(fname)
        self.fluxes = fluxes

    def read_spectrafile(self, name):
        """Read one spectrum file and interpolate its flux onto self.xvals.

        The first line is skipped; columns 0 and 1 are wavelength and flux.
        With to_RV the velocity grid is first converted to wavelengths around
        wlmid.  Grid points outside the file's wavelength range become NaN
        (griddata's default fill value).
        """
        wls, fluxes = np.loadtxt(name, usecols=(0, 1), skiprows=1, unpack=True)
        if self.to_RV:
            wls_interp = self.calc_doppler(self.wlmid, self.xvals)
        else:
            wls_interp = self.xvals
        return griddata(wls, fluxes, wls_interp, method='cubic')

    @staticmethod
    def read_lst(lst, pref, foldername, col, headerrows):
        """Read one LST file.

        Returns (list of spectrum file paths, 1-D array of observation times).
        Column 0 holds the spectrum number; the path is pref + 5-digit number
        + '.asc', or '.ASC' when the lower-case file does not exist.  Times
        above 2e6 are reduced by 2.4e6.
        """
        lstpath = os.path.join(foldername, lst)
        # ndmin=1 keeps a single-observation file as a 1-element array.
        hjdlist = np.loadtxt(lstpath, skiprows=headerrows, usecols=col, ndmin=1)
        hjdlist[hjdlist > 2.e+06] -= 2.4e+06
        numbers = np.loadtxt(lstpath, skiprows=headerrows, usecols=0, ndmin=1)
        spec = []
        for number in numbers:
            fpath = os.path.join(foldername, '{}{:05.0f}.asc'.format(pref, number))
            if not os.path.isfile(fpath):
                fpath = os.path.join(foldername, '{}{:05.0f}.ASC'.format(pref, number))
            spec.append(fpath)
        return spec, hjdlist

    def calc_phase(self, dates):
        """Return the orbital phase of dates for self.period and self.T0."""
        return ((dates - self.T0) % self.period) / self.period

    @staticmethod
    def calc_doppler(wlmid, vels):
        """Return wavelengths Doppler-shifted from wlmid by velocities vels (km/s)."""
        c = 299792.458
        return wlmid * (vels / c + 1.)

    @staticmethod
    def _wrap_phases(phases, data):
        """Repeat the phased rows one cycle earlier so the plot covers phases -1..1.

        The first row is also repeated at phase 1 to close the cycle.
        """
        phases = np.append(phases, 1.)
        data = np.vstack((data, data[:1]))
        return np.concatenate((phases - 1., phases)), np.vstack((data, data))

    @staticmethod
    def _insert_gaps(y, data):
        """Insert NaN rows where neighbouring y values differ by more than 5 median steps.

        Each NaN row is placed 5 median steps after the row preceding the gap,
        creating an empty space in the plot.  y is expected to be sorted.
        """
        if np.size(y) < 2:
            return y, data
        diff = np.diff(y)
        gap = 5 * np.median(diff)
        y_gap = y[:-1][diff > gap] + gap
        if y_gap.size == 0:
            return y, data
        gap_rows = np.full((y_gap.size, data.shape[1]), np.nan)
        y = np.concatenate((y, y_gap))
        data = np.concatenate((data, gap_rows), axis=0)
        idx = np.argsort(y)
        return y[idx], data[idx]

    def plot_ds(self):
        """Read all spectra and save the dynamic spectrum plot to self.outname."""
        self.prepare_ds()
        x, y, data = self.xvals, self.dates, self.fluxes
        if self.period:
            y, data = self._wrap_phases(y, data)
        # NaN marks grid points outside a spectrum's coverage; ignore it for the colour range.
        datamin, datamax = np.nanmin(data), np.nanmax(data)
        y, data = self._insert_gaps(y, data)
        X, Y = np.meshgrid(x, y)
        fig, ax = plt.subplots()
        cmap = plt.get_cmap(self.cmap, self.nrlevels)
        cmap.set_bad(self.emptyColor)
        if self.colormaps[self.cmap] == "Diverging":
            # TwoSlopeNorm needs vmin < vcenter < vmax, hence the small margins.
            cnorm = mcolors.TwoSlopeNorm(vmin=min(datamin, self.midpoint - 1.e-3),
                                         vcenter=self.midpoint,
                                         vmax=max(datamax, self.midpoint + 1.e-3))
            spectrum = ax.pcolormesh(X, Y, data, norm=cnorm, cmap=cmap, shading='nearest')
        else:
            spectrum = ax.pcolormesh(X, Y, data, cmap=cmap, shading='nearest')
        if self.period:
            ax.plot([x[0], x[-1]], [0, 0], linewidth=1, ls='-', color=self.zeroPhaseColor)
        ax.invert_yaxis()
        ax.set_xlabel('v [km/s]' if self.to_RV else r"$\lambda$ [$\AA$]")
        ax.set_ylabel("RJD" if np.max(y) > 1. else "Phase")
        if self.period:
            ax.set_ylim([1., -1.])
        if self.perioddigits > 0 and self.t0digits > 0:
            ax.text(x[-1], -1., 'Period: {}\nHJD$_0$: {}'.format(round(self.period, self.perioddigits),
                                                               round(self.T0, self.t0digits)),
                    fontsize=10, verticalalignment='bottom')
        fig.colorbar(spectrum, ax=ax)
        if self.title:
            fig.suptitle(self.title)
        fig.savefig(self.outname)
        # Close (not just clear) so repeated calls do not accumulate pyplot figures.
        plt.close(fig)

    def prepare_ds(self):
        """Order the observations and read/interpolate their spectra."""
        self.order()
        self.get_fluxes()

    def get_files(self):
        """Return the spectrum file names (sorted after order()) - used for debug."""
        return self.files

    def get_dates(self):
        """Return the observation times or phases (sorted after order()) - used for debug."""
        return self.dates

    def do_debug(self):
        """Print (file, date, max flux, min flux) per observation - look out for NaN values."""
        self.prepare_ds()
        for it in zip(self.get_files(), self.get_dates(),
                      [np.max(x) for x in self.fluxes], [np.min(x) for x in self.fluxes]):
            print(it)




