#! /usr/bin/env python3
"""
# fitoffset is part of GMTSAR. 
# It is migrated from fitoffset.csh, originally written by David Sandwell, Dec 10, 2007.
# Dunyu Liu, 20230426.

# Purpose: to compute 2-6 alignment parameters from xcorr.dat
#   based on equations 
#   dr = c0 + c1*rm + c2*am (1)
#   da = c3 + c4*rm + c5*am (2)
#   where c0 = rshift + sub_int_r 
#       c1 = r_stretch
#       c2 = r_stretch_a 
#       c3 = ashift + sub_int_a
#       c4 = a_stretch_r
#       c5 = a_stretch.
# The data structure in xcorr.dat is
# r, dr, a, da, SNR for each row. 
# There should be at least 8 rows for sufficient inversion of the coefficients.
    
# Syntax: fitoffset npar_rng npar_azi xcorr.dat [SNR]
"""

import sys, os, re, configparser
import subprocess, glob, shutil, shlex
import numpy as np
from gmtsar_lib import * 

def _print_usage():
    """Print the fitoffset command-line usage text to stdout."""
    usage_lines = (
        " ",
        "Usage: fitoffset npar_rng npar_azi xcorr.dat [SNR]",
        " ",
        "        npar_rng    - number of parameters to fit in range ",
        "        npar_azi    - number of parameters to fit in azimuth ",
        "        xcorr.dat   - files of range and azimuth offset estimates ",
        "        SNR         - optional SNR cutoff (default 20)",
        "  ",
        "Example: fitoffset.csh 3 3 freq_xcorr.dat ",
        "  ",
    )
    for text in usage_lines:
        print(text)


def _usage_exit(reason=None):
    """Print an optional reason plus the usage text, then exit with status 1."""
    if reason is not None:
        print(reason)
    _print_usage()
    sys.exit(1)


def _safe_ratio(numerator, span):
    """Return numerator/span, treating a zero span with a zero numerator as 0.

    A zero span means all selected points share one coordinate value; a
    non-zero numerator in that case cannot be converted and is reported
    instead of silently producing NaN/inf.
    """
    if span == 0:
        if numerator != 0:
            raise ValueError('FITOFFSET: selected points have zero coordinate '
                             'extent; cannot convert a non-zero slope term')
        return 0.0
    return numerator / span


def _fit_trend2d(r, a, offset, npar):
    """Fit ``offset`` as a function of (r, a) with ``gmt trend2d``.

    Writes the points to the scratch file ``xyz`` in the current directory,
    runs ``gmt trend2d xyz -Fxyz -N<npar>r -V`` and parses the coefficients
    from the "Coefficients" line that trend2d writes to stderr in verbose
    mode.  (Per the GMT documentation the trailing ``r`` selects the robust
    fit and the coefficients refer to coordinates normalised to [-1, 1].)

    Parameters:
        r, a   - coordinates of the selected points
        offset - offset measured at each point (dr or da)
        npar   - number of model parameters to fit (1, 2 or 3)

    Returns:
        numpy array of length 3; entries beyond ``npar`` are zero.

    Raises:
        RuntimeError if no usable coefficient line is found in the output.
    """
    scratch = 'xyz'
    with open(scratch, 'w') as f:
        for ri, ai, oi in zip(r, a, offset):
            f.write(str(ri) + ' ' + str(ai) + ' ' + str(oi) + '\n')

    try:
        # Use this fit's own parameter count (not the range one) for -N.
        result = subprocess.run(
            ['gmt', 'trend2d', scratch, '-Fxyz', '-N' + str(npar) + 'r', '-V'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    finally:
        # Always remove the scratch file, even if gmt could not be started.
        delete(scratch)

    # The line looks like
    #   trend2d [INFORMATION]: Model Coefficients: 8.05879597667 0.0560472001281
    # Keep the last such line, and read the numbers that follow the keyword
    # rather than relying on a fixed token position.
    values = None
    for line in result.stderr.split('\n'):
        if 'Coefficients' in line:
            values = line.partition('Coefficients')[2].lstrip(':').split()

    if values is None or len(values) < npar:
        raise RuntimeError(
            'FITOFFSET: could not read ' + str(npar) + ' coefficient(s) from '
            'gmt trend2d (exit status ' + str(result.returncode) + ').\n'
            + result.stderr)

    c = np.zeros(3)
    c[:npar] = [float(v) for v in values[:npar]]
    return c


def _alignment_from_coeff(c, range_r, range_a):
    """Convert trend2d coefficients into shift / sub-pixel shift / stretches.

    Parameters:
        c        - [c0, c1, c2] as returned by ``_fit_trend2d``
        range_r  - [min, max] of the selected range coordinates
        range_a  - [min, max] of the selected azimuth coordinates

    Returns:
        (shift, sub_int, stretch_r_term, stretch_a_term) where
        ``shift`` is the integer floor of the constant term expressed in
        un-normalised coordinates, ``sub_int`` its fractional part in [0, 1),
        and the two stretch terms are c1 and c2 rescaled by 2/(max-min) of
        the range and azimuth coordinates respectively.
    """
    r_span = range_r[1] - range_r[0]
    a_span = range_a[1] - range_a[0]

    shift0 = (c[0]
              - _safe_ratio(c[1] * (range_r[0] + range_r[1]), r_span)
              - _safe_ratio(c[2] * (range_a[0] + range_a[1]), a_span))

    # floor() keeps shift + sub_int == shift0 for every sign, including
    # exactly integral negative values.
    shift = int(np.floor(shift0))
    sub_int = shift0 % 1

    stretch_r_term = _safe_ratio(c[1] * 2., r_span)
    stretch_a_term = _safe_ratio(c[2] * 2., a_span)
    return shift, sub_int, stretch_r_term, stretch_a_term


def fitoffset():
    """Compute alignment parameters from an xcorr.dat file and print them.

    Command line (read from ``sys.argv``):
        fitoffset npar_rng npar_azi xcorr.dat [SNR]

    Rows of the input file are ``r, dr, a, da, SNR``.  Rows whose SNR is not
    strictly greater than the cutoff (default 20) are discarded; the range
    offsets ``dr`` and azimuth offsets ``da`` of the remaining rows are each
    fitted with ``gmt trend2d`` using ``npar_rng`` / ``npar_azi`` parameters,
    and the resulting shifts and stretches are printed to stdout.

    Exits with status 1 (after printing the usage text or an error message)
    when the arguments are invalid, the file does not have five columns, or
    no row passes the SNR cutoff.
    """
    argc = len(sys.argv)
    print('FITOFFSET: ', sys.argv, 'is being called ... ...')
    print('FITOFFSET - START ... ...')

    if argc == 1:
        print('FITOFFSET: zero # of input arguments ... ...')
    # Counting the program name, there must be 4 or 5 entries in sys.argv.
    if argc not in (4, 5):
        _usage_exit()

    print(' ')
    try:
        SNR = float(sys.argv[4]) if argc == 5 else 20.
        npar_rng = int(sys.argv[1])
        npar_azi = int(sys.argv[2])
    except ValueError:
        _usage_exit('FITOFFSET: npar_rng and npar_azi must be integers and SNR a number ... ...')

    # The model has at most three terms: constant, range slope, azimuth slope.
    if not (1 <= npar_rng <= 3 and 1 <= npar_azi <= 3):
        _usage_exit('FITOFFSET: npar_rng and npar_azi must each be 1, 2 or 3 ... ...')

    print('FITOFFSET: predefined SNR threshold SNR_thres is', SNR)
    print('FITOFFSET: if SNR in the dataset <= SNR_thres, then discard the data point ... ...')
    print('FITOFFSET: first extract the range and azimuth data ... ...')
    print('FITOFFSET: operating on file ', sys.argv[3], ' - Should be xcorr.dat ... ...')

    print('FITOFFSET: load rows of r, dr, a, da, SNR from xcorr.dat ... ...')
    # ndmin=2 keeps a single-row file two-dimensional so column indexing works.
    xcorr = np.loadtxt(sys.argv[3], ndmin=2)
    print('Loaded xcorr is', xcorr)

    if xcorr.shape[1] < 5:
        sys.exit('FITOFFSET: ' + sys.argv[3] + ' must have 5 columns (r, dr, a, da, SNR)')

    print('FITOFFSET: all data point # is ', xcorr.shape[0])

    # Keep only rows whose SNR (5th column) exceeds the cutoff.
    selected = xcorr[xcorr[:, 4] > SNR, :5]
    xcorr_r = selected[:, 0]
    xcorr_dr = selected[:, 1]
    xcorr_a = selected[:, 2]
    xcorr_da = selected[:, 3]

    n1 = selected.shape[0]
    print('FITOFFSET: selected data point # is ', n1)
    if n1 == 0:
        sys.exit('FITOFFSET: no data point has SNR above the threshold; nothing to fit')

    range_r = [np.amin(xcorr_r), np.amax(xcorr_r)]
    range_a = [np.amin(xcorr_a), np.amax(xcorr_a)]

    print('FITOFFSET: range of selected xcorr_r', range_r)
    print('FITOFFSET: range of selected xcorr_a', range_a)

    # dr = c0 + c1*rm + c2*am  and  da = c3 + c4*rm + c5*am
    c012 = _fit_trend2d(xcorr_r, xcorr_a, xcorr_dr, npar_rng)
    c345 = _fit_trend2d(xcorr_r, xcorr_a, xcorr_da, npar_azi)
    print(c012, c345)

    rshift, sub_int_r, stretch_r, a_stretch_r = _alignment_from_coeff(c012, range_r, range_a)
    ashift, sub_int_a, stretch_a, a_stretch_a = _alignment_from_coeff(c345, range_r, range_a)

    # The label text and the order of these lines are kept exactly as before.
    print('rshift      = ', rshift)        # from c0
    print('sub_int_r   = ', sub_int_r)
    print('stretch_r   = ', stretch_r)     # from c1
    print('a_stretch_a = ', a_stretch_a)   # from c5
    print('ashift      = ', ashift)        # from c3
    print('sub_int_a   = ', sub_int_a)
    print('a_stretch_r = ', a_stretch_r)   # from c2
    print('stretch_a   = ', stretch_a)     # from c4

    print("FITOFFSET - END ... ...")

def _main_func(description):
    """Script entry point: run fitoffset().

    ``description`` is accepted for the caller's convenience but is not used.
    """
    fitoffset()

# Run the fit only when this file is executed as a script; the module docstring
# is handed to _main_func, which does not use it.
if __name__ == "__main__":
    _main_func(__doc__)

