# -*- coding: utf-8 -*-
"""
Created on Mon Nov 22 15:26:34 2021

@author: georg
"""



import numpy as np
from matplotlib import pyplot as plt
from scipy.linalg import expm    




    
    
###############################################################################
###############################################################################
###############################################################################



def cumulantsField0(liouvillain, paramCons, paramVar, dChi):
    """Finite-difference cumulants of counting field 0 and d(cur)/d(paramVar).

    ``liouvillain(paramCons, params, chi_ar)`` must return a square matrix;
    its eigenvalue with the largest real part, lambda(chi), is used as the
    cumulant generating function.  Only ``chi_ar[0]`` is varied here.
    lambda(0) is taken to be 0 rather than computed (NOTE(review): this
    presumes the generator has a zero eigenvalue at chi = 0 -- confirm for
    the Liouvillian in use).

    Parameters
    ----------
    liouvillain : callable
        ``liouvillain(paramCons, params, chi_ar) -> square ndarray``.
    paramCons : object
        Fixed parameters, forwarded unchanged.
    paramVar : array_like
        Variable parameters.  Entries 0 and 1 are perturbed (relative step
        1e-3, or an absolute step 1e-4 when the entry is exactly 0) to obtain
        the derivatives of the current.
    dChi : float
        Counting-field step.

    Returns
    -------
    cur : float
        ``-Re(lambda(dChi) / dChi / 1j)`` (forward difference).
    var : float
        ``-Re(lambda(dChi) + lambda(-dChi)) / dChi**2`` (central difference).
    curDer : ndarray, shape (2,)
        Forward-difference derivatives of ``cur`` w.r.t. ``paramVar[0]`` and
        ``paramVar[1]``.
    """
    relStep = 0.001
    absStep = 0.0001  # fallback when a parameter is exactly 0 (avoids 0/0)
    lambdaNull = 0    # lambda(chi = 0), assumed rather than computed

    def _leadingEigenvalue(params, chi):
        # Eigenvalue of the counting-field generator with the largest real part.
        eigV = np.linalg.eigvals(liouvillain(paramCons, params, np.array(chi)))
        return eigV[np.argmax(np.real(eigV))]

    def _current(params):
        lam = _leadingEigenvalue(params, [dChi, 0])
        return -np.real((lam - lambdaNull) / dChi / 1j)

    # Float copy: the element-wise perturbation below would be silently
    # truncated if paramVar had an integer dtype.
    base = np.asarray(paramVar) + 0.0

    lambdaMin = _leadingEigenvalue(base + 0, [-dChi, 0])
    lambdaMax = _leadingEigenvalue(base + 0, [dChi, 0])

    cur = -np.real((lambdaMax - lambdaNull) / dChi / 1j)
    var = np.real(-(lambdaMax + lambdaMin - 2 * lambdaNull) / dChi / dChi)

    derivs = []
    for k in range(2):
        step = base[k] * relStep
        if step == 0:
            step = absStep
        shifted = base + 0
        shifted[k] += step
        derivs.append((_current(shifted) - cur) / step)
    curDer = np.array(derivs)

    return cur, var, curDer
    
    

def cumulantsField1(liouvillain, paramCons, paramVar, dChi):
    """Finite-difference cumulants of counting field 1 and d(cur)/d(paramVar).

    ``liouvillain(paramCons, params, chi_ar)`` must return a square matrix;
    its eigenvalue with the largest real part, lambda(chi), is used as the
    cumulant generating function.  Only ``chi_ar[1]`` is varied here.
    lambda(0) is taken to be 0 rather than computed (NOTE(review): this
    presumes the generator has a zero eigenvalue at chi = 0 -- confirm for
    the Liouvillian in use).

    Parameters
    ----------
    liouvillain : callable
        ``liouvillain(paramCons, params, chi_ar) -> square ndarray``.
    paramCons : object
        Fixed parameters, forwarded unchanged.
    paramVar : array_like
        Variable parameters.  Entries 0 and 1 are perturbed (relative step
        1e-3, or an absolute step 1e-4 when the entry is exactly 0) to obtain
        the derivatives of the current.
    dChi : float
        Counting-field step.

    Returns
    -------
    cur : float
        ``-Re(lambda(dChi) / dChi / 1j)`` (forward difference in chi_ar[1]).
    var : float
        ``-Re(lambda(dChi) + lambda(-dChi)) / dChi**2`` (central difference).
    curDer : ndarray, shape (2,)
        Forward-difference derivatives of ``cur`` w.r.t. ``paramVar[0]`` and
        ``paramVar[1]``.
    """
    relStep = 0.001
    absStep = 0.0001  # fallback when a parameter is exactly 0 (avoids 0/0)
    lambdaNull = 0    # lambda(chi = 0), assumed rather than computed

    def _leadingEigenvalue(params, chi):
        # Eigenvalue of the counting-field generator with the largest real part.
        eigV = np.linalg.eigvals(liouvillain(paramCons, params, np.array(chi)))
        return eigV[np.argmax(np.real(eigV))]

    def _current(params):
        lam = _leadingEigenvalue(params, [0, dChi])
        return -np.real((lam - lambdaNull) / dChi / 1j)

    # Float copy: the element-wise perturbation below would be silently
    # truncated if paramVar had an integer dtype.
    base = np.asarray(paramVar) + 0.0

    lambdaMin = _leadingEigenvalue(base + 0, [0, -dChi])
    lambdaMax = _leadingEigenvalue(base + 0, [0, dChi])

    cur = -np.real((lambdaMax - lambdaNull) / dChi / 1j)
    var = np.real(-(lambdaMax + lambdaMin - 2 * lambdaNull) / dChi / dChi)

    derivs = []
    for k in range(2):
        step = base[k] * relStep
        if step == 0:
            step = absStep
        shifted = base + 0
        shifted[k] += step
        derivs.append((_current(shifted) - cur) / step)
    curDer = np.array(derivs)

    return cur, var, curDer




def cumulantsField01(liouvillain, paramCons, paramVar, dChi):
    """Mixed second cumulant of counting fields 0 and 1 by finite differences.

    lambda(chi) is the eigenvalue of ``liouvillain(paramCons, paramVar, chi)``
    with the largest real part; lambda(0, 0) is taken as 0.  The result is
    ``-Re(lambda(d, d) - lambda(0, d) - lambda(d, 0) + lambda(0, 0)) / d**2``
    with ``d = dChi`` (a one-sided mixed difference).
    """
    def leadingEig(chiPair):
        # Each call receives a fresh copy of paramVar.
        spectrum, _ = np.linalg.eig(
            liouvillain(paramCons, paramVar + 0, np.array(chiPair)))
        return spectrum[np.argmax(np.real(spectrum))] + 0

    lamOnlyFirst = leadingEig([dChi, 0])
    lamOnlySecond = leadingEig([0, dChi])
    lamBoth = leadingEig([dChi, dChi])
    lamNone = 0

    mixed = lamBoth - lamOnlySecond - lamOnlyFirst + lamNone
    return -np.real(mixed) / (dChi * dChi)
    
    
    
def transportCoefficients(liouvillian, paramCons, paramVar, dChi):
    """Return both currents and the symmetric 2x2 noise matrix in one tuple.

    The result is ``(cur0, cur1, var0, cor, cor, var1)``: the currents of
    fields 0 and 1, followed by the entries of ``[[var0, cor], [cor, var1]]``
    in row-major order.  The current derivatives computed by the single-field
    helpers are discarded.
    """
    callArgs = (liouvillian, paramCons, paramVar, dChi)
    cur0, var0, _der0 = cumulantsField0(*callArgs)
    cur1, var1, _der1 = cumulantsField1(*callArgs)
    cor = cumulantsField01(*callArgs)

    noiseRowMajor = (var0, cor, cor, var1)
    return (cur0, cur1) + noiseRowMajor
    
    
    
###############################################################################
###############################################################################
###############################################################################
#Flow equations



class FlowEquations():
    """Step-wise evolution of a two-mode field state through a medium.

    The state is ``mean_ar`` (entry 0 starts at ``sum(n_ar)``, entry 1 at 0)
    and a 2x2 matrix ``sigmaQ_ar``.  Both are advanced by :meth:`updateState`
    using ``cumulantsField0``, ``cumulantsField1`` and ``cumulantsField01``.
    After each step the Rabi frequencies are reset to
    ``g_k * sqrt(mean_ar[0] / 2)``, with ``g_k = rabiFreqs[k] / sqrt(n_ar[k])``
    fixed at construction.
    """

    def __init__(self, liouvillain, paramCons, rabiFreqs, n_ar, sigmaQ_ar, rhoAtom, dChi):
        """Store copies of the initial state and parameters.

        Parameters
        ----------
        liouvillain : callable
            ``liouvillain(paramCons, rabiFreqs, chi_ar) -> square ndarray``.
        paramCons : ndarray
            Fixed parameters (copied).
        rabiFreqs : array_like, length 2
            Initial Rabi frequencies (copied, promoted to float).
        n_ar : array_like, length 2
            Initial photon numbers of the two modes (copied, promoted to float).
        sigmaQ_ar : array_like, shape (2, 2)
            Initial 2x2 matrix propagated by :meth:`updateState`.
        rhoAtom : float
            Multiplies every step ``dzEff`` in :meth:`updateState`.
        dChi : float
            Counting-field step forwarded to the cumulant functions.
        """
        self.dChi = dChi

        # Float copies: mean_ar and rabiFreqs are updated element-wise in
        # place, which would silently truncate on integer arrays.
        self.n_ar = np.asarray(n_ar) + 0.0
        self.mean_ar = np.array([np.sum(self.n_ar), 0.0])
        self.sigmaQ_ar = np.asarray(sigmaQ_ar) + 0.0

        self.liouvillain = liouvillain
        self.paramCons = paramCons + 0
        self.rabiFreqs = np.asarray(rabiFreqs) + 0.0

        # Coupling constants; kept fixed while rabiFreqs follows mean_ar[0].
        self.g0 = self.rabiFreqs[0] / np.sqrt(self.n_ar[0])
        self.g1 = self.rabiFreqs[1] / np.sqrt(self.n_ar[1])
        self.rhoAtom = rhoAtom

    def updateState(self, dzEff):
        """Advance mean_ar, sigmaQ_ar and rabiFreqs by one step.

        Every increment is scaled by ``dzEff * rhoAtom``.
        """
        cur0, var0, cur0Der = cumulantsField0(self.liouvillain, self.paramCons, self.rabiFreqs, self.dChi)
        cur1, var1, cur1Der = cumulantsField1(self.liouvillain, self.paramCons, self.rabiFreqs, self.dChi)
        cor = cumulantsField01(self.liouvillain, self.paramCons, self.rabiFreqs, self.dChi)
        dSigmaQ = np.array([[var0, cor], [cor, var1]])

        # Derivatives w.r.t. the Rabi frequencies, rescaled by
        # g_k / sqrt(mean_ar[0]/2) using the pre-step mean_ar[0].
        scale = np.sqrt(self.mean_ar[0] / 2)
        cur0_d0 = cur0Der[0] * self.g0 / scale
        cur0_d1 = cur0Der[1] * self.g1 / scale
        cur1_d0 = cur1Der[0] * self.g0 / scale
        cur1_d1 = cur1Der[1] * self.g1 / scale
        dCurD = np.array([[cur0_d0, cur0_d1], [cur1_d0, cur1_d1]]) / 2
        dCurR = np.array([[1, 1], [-1, -1]]) * (cur1 - cur0) / 2 / self.mean_ar[0]
        dCur = dCurD + dCurR

        # mean_ar[1] is updated first because it uses the pre-step mean_ar[0].
        self.mean_ar[1] += (cur0 - cur1) / 2 / self.mean_ar[0] * dzEff * self.rhoAtom
        self.mean_ar[0] += (cur0 + cur1) * dzEff * self.rhoAtom

        # mean_ar[0] is divided by and square-rooted; keep it positive.
        if self.mean_ar[0] < 0:
            self.mean_ar[0] = 1e-200

        expdCur = expm(dCur * dzEff * self.rhoAtom)
        self.sigmaQ_ar = np.dot(expdCur, np.dot(self.sigmaQ_ar, expdCur.T))
        # Remove round-off asymmetry before adding the (symmetric) noise term.
        self.sigmaQ_ar = 0.5 * (self.sigmaQ_ar + self.sigmaQ_ar.T)
        self.sigmaQ_ar += dzEff * self.rhoAtom * dSigmaQ

        self.rabiFreqs[0] = self.g0 * np.sqrt(self.mean_ar[0] / 2)
        self.rabiFreqs[1] = self.g1 * np.sqrt(self.mean_ar[0] / 2)

    def get_state(self):
        """Return ``(mean_ar, sigmaQ_ar)``.

        These are the live internal arrays, not copies.
        """
        return self.mean_ar, self.sigmaQ_ar

    def probability_distribution(self, n1Min, n1Max, n2Min, n2Max, nNmb):
        """Normalised Gaussian weights on an ``nNmb x nNmb`` photon-number grid.

        Weight at grid point (j, k) is ``exp(-d^T sigmaQ_ar^{-1} d / 2)`` with
        ``d = (n1[j] - n_ar[0], n2[k] - n_ar[1])``; axis 0 runs over n1 and
        axis 1 over n2.  The weights are normalised to sum to 1.

        NOTE(review): the grid is centred on the constructor's ``n_ar``, not
        on the evolved ``mean_ar`` -- confirm this is intended.
        """
        n1 = np.linspace(n1Min, n1Max, nNmb) - self.n_ar[0]
        n2 = np.linspace(n2Min, n2Max, nNmb) - self.n_ar[1]
        dev = np.array(np.meshgrid(n1, n2, indexing='ij'))  # shape (2, nNmb, nNmb)

        sigInv = np.linalg.inv(self.sigmaQ_ar)
        quad = np.einsum('ajk,ab,bjk->jk', dev, sigInv, dev)

        # Shift by the minimum before exponentiating to avoid underflow to
        # all-zeros; the shift cancels in the normalisation.
        prob = np.exp(-(quad - np.min(quad)) / 2)
        return prob / np.sum(prob)




def fisherInformation(n0, sigmaQ0, n1, sigmaQ1, dLambda):
    """Fisher information of a Gaussian mean shift, from a finite difference.

    With ``dN = (n1 - n0) / dLambda`` and ``S = sigmaQ0``, computes

    * ``fI  = dN^T S^{-1} dN``
    * ``fPl``: the same form for the symmetric part ``(dN0 + dN1)/2 * [1, 1]``
    * ``fMi``: the same form for the antisymmetric part ``(dN0 - dN1)/2 * [1, -1]``

    ``sigmaQ1`` is only used for the positive-definiteness warning.  The
    check uses the determinant, which for a 2x2 matrix does not detect the
    case of two negative eigenvalues.

    Returns
    -------
    tuple
        ``(fI, fPl, fMi)``, or ``(-1, -1, -1)`` if ``det(sigmaQ0) < 0`` or
        ``sigmaQ0`` is singular.  A 3-tuple is returned in every case, so
        callers can always unpack it.
    """
    dN = (np.asarray(n1) - np.asarray(n0)) / dLambda

    det0 = np.linalg.det(sigmaQ0)
    det1 = np.linalg.det(sigmaQ1)
    if not (det0 > 0 and det1 > 0):
        print('Warning: Sigma not postive definite')

    invalid = (-1, -1, -1)
    if det0 < 0:
        return invalid
    try:
        sigInv = np.linalg.inv(sigmaQ0)
    except np.linalg.LinAlgError:
        return invalid

    dNPl = (dN[0] + dN[1]) / 2 * np.array([1, 1])
    dNMi = (dN[0] - dN[1]) / 2 * np.array([1, -1])

    fI = np.dot(dN, np.dot(sigInv, dN))
    fPl = np.dot(dNPl, np.dot(sigInv, dNPl))
    fMi = np.dot(dNMi, np.dot(sigInv, dNMi))

    return fI, fPl, fMi



def flowEquations_signalNoise(liouvillian, paramConsIn, rabiFreqs,
                              n_ar, sigmaQ_ar, pulseTau,
                              rhoAtom, areaLaser,
                              zMax, zNmb, dz,
                              idxFish, dChi, showProgress=True):
    """Integrate two FlowEquations (reference and perturbed) and record Fisher information.

    ``liouvillian`` is a function of three arguments:
    (1) ``paramCons`` (all in MHz), fixed parameters of the atoms/molecules,
        e.g. level energies;
    (2) ``rabiFreqs``, the 2 Rabi frequencies of the measured (probe) laser;
        for isotropic quantum emitters both frequencies are equal;
    (3) a vector of counting fields of length 2.

    The perturbed flow differs from the reference flow by ``dLambda``:

    * if ``idxFish >= 0``, ``paramCons[idxFish]`` is shifted by 1% of its
      value, or by 1e-4 when it is 0;
    * otherwise ``rhoAtom`` is shifted by ``rhoAtom * 1e-4``.

    Both flows are stepped with ``dzEff = dz * areaLaser * pulseTau`` while
    ``z`` advances by ``dz`` until ``z >= zMax + 2*dz``.  A row is recorded
    the first time ``z`` reaches each point of ``linspace(0, zMax, zNmb)``.

    Returns
    -------
    out_ar : ndarray, shape (zNmb, 11)
        Columns: 0 grid z; 1-2 reference ``mean_ar``; 3-6 reference
        ``sigmaQ_ar`` entries [0,0], [0,1], [1,0], [1,1]; 7-9
        ``(fI, fPl, fMi)`` from :func:`fisherInformation` with the flow
        covariance; 10 ``fI`` with the diagonal matrix ``diag(n0_ar)``
        instead.
    """
    # Float copy: the in-place Fisher perturbation below would be silently
    # truncated (possibly to no change at all) on an integer array, while
    # the Fisher information would still divide by dLambda.
    paramCons = np.asarray(paramConsIn) + 0.0

    # FlowEquations copies paramCons, so modifying it afterwards does not
    # affect flowEq0.
    flowEq0 = FlowEquations(liouvillian, paramCons, rabiFreqs,
                            n_ar, sigmaQ_ar, rhoAtom, dChi)

    if idxFish >= 0:
        dLambda = paramCons[idxFish] * 0.01
        if dLambda == 0:
            dLambda = 0.0001
        paramCons[idxFish] = paramCons[idxFish] + dLambda

        flowEq1 = FlowEquations(liouvillian, paramCons, rabiFreqs,
                                n_ar, sigmaQ_ar, rhoAtom, dChi)
    else:
        dLambda = rhoAtom * 0.0001
        flowEq1 = FlowEquations(liouvillian, paramCons, rabiFreqs,
                                n_ar, sigmaQ_ar, rhoAtom + dLambda, dChi)

    z_ar = np.linspace(0, zMax, zNmb)
    out_ar = np.zeros((zNmb, 11))
    z = 0
    zNext = z_ar[0]
    i = 0
    dzEff = dz * areaLaser * pulseTau
    while z < zMax + 2 * dz:

        if z >= zNext and i < zNmb:
            mean0_ar, sigmaQ0_ar = flowEq0.get_state()
            mean1_ar, sigmaQ1_ar = flowEq1.get_state()

            out_ar[i, 0] = z_ar[i]
            out_ar[i, 1] = mean0_ar[0]
            out_ar[i, 2] = mean0_ar[1]
            out_ar[i, 3] = sigmaQ0_ar[0, 0]
            out_ar[i, 4] = sigmaQ0_ar[0, 1]
            out_ar[i, 5] = sigmaQ0_ar[1, 0]
            out_ar[i, 6] = sigmaQ0_ar[1, 1]

            # Reference: mean0_ar[0] split equally between the two modes.
            # Perturbed: (cos d +/- sin d)^2 * mean1_ar[0]/2, where d is the
            # difference of the mean_ar[1] entries of the two flows.
            n00 = mean0_ar[0] / 2
            n01 = mean0_ar[0] / 2
            nTotH = mean1_ar[0] / 2
            delta_theta = mean1_ar[1] - mean0_ar[1]
            n10 = ((np.cos(delta_theta) + np.sin(delta_theta)) * np.sqrt(nTotH)) ** 2
            n11 = ((np.cos(delta_theta) - np.sin(delta_theta)) * np.sqrt(nTotH)) ** 2
            n0_ar = np.array([n00, n01])
            n1_ar = np.array([n10, n11])
            out_ar[i, 7:10] = fisherInformation(n0_ar, sigmaQ0_ar, n1_ar, sigmaQ1_ar, dLambda)

            sigmaQEst0 = np.array([[n00, 0], [0, n01]])
            sigmaQEst1 = np.array([[n10, 0], [0, n11]])
            out_ar[i, 10], _, _ = fisherInformation(n0_ar, sigmaQEst0, n1_ar, sigmaQEst1, dLambda)

            if showProgress:
                print('Progress:', i)
                print('exact: fisheInof: ', out_ar[i, 7:10])

            i += 1
            if i <= zNmb - 1:
                zNext = z_ar[i]

        flowEq1.updateState(dzEff)
        flowEq0.updateState(dzEff)
        z += dz

    return out_ar
    
            
   


if __name__ == '__main__':
    # This module only defines functions and a class; running it as a
    # script does nothing.
    ...
