#!/usr/bin/env python
"""
Multivalued current-phase relation for SSS junction
"""
from __future__ import division, absolute_import, print_function


import sys
import os
import numpy as np
from numpy import pi, arccosh, linspace, exp
from scipy import io, optimize
import matplotlib.pyplot as plt
import usadel1 as u

from usadel1.selfconsistentiteration import _set_self_consistent_delta_m, _get_self_consistent_energies_m, dump_delta


def main():
    """
    Build a one-wire geometry between two clean S terminals, compute its
    current-phase relation and hand the result to ``save_points``.
    """
    # -- Parameters

    gap_terminal = 1
    debye_cutoff = 600
    gap_wire = 1

    # -- Geometry: one wire, two nodes, 201 grid points on [0, 1]

    grid = linspace(0, 1, 201)
    geom = u.Geometry(nwire=1, nnode=2, x=grid)

    # Terminals
    geom.t_type = [u.NODE_CLEAN_S_TERMINAL] * 2
    geom.t_delta = [gap_terminal] * 2
    geom.t_phase = [0, pi]
    geom.t_inelastic = 0
    geom.t_t = 0.05
    geom.t_mu = 0

    # Wire
    geom.w_type = u.WIRE_TYPE_S
    geom.w_length = 4.0
    geom.w_conductance = 1
    geom.w_inelastic = 0

    # presumably: wire 0 runs from node 0 to node 1 -- confirm against usadel1
    geom.w_ends[0, :] = [0, 1]

    # Coupling constant is derived from the cutoff as stored on the geometry
    geom.omega_D = debye_cutoff
    geom.coupling_lambda = 1 / arccosh(geom.omega_D / gap_wire)

    # Initial order parameter in the wire: zero phase, and a parabolic
    # interpolation in the amplitude (equal to gap_wire at both grid ends).
    geom.w_phase[0, :] = 0
    parabola = geom.x * (1 - geom.x) * 4
    geom.w_delta[0, :] = gap_wire + parabola * (gap_terminal - gap_wire)

    # -- Solve and store

    cpr_points = self_consistent_matsubara_cpr(geom, max_ne=50)
    save_points(geom, cpr_points)


def save_points(g, points):
    """
    Save the visited current-phase points to disk and plot them.

    Writes ``sss-cpr.npz`` (containing one array, ``points``) to the
    current directory, reads it back, and saves a plot of column 2
    versus column 1 as ``sss-cpr.pdf`` and ``sss-cpr.png``.

    Parameters
    ----------
    g : geometry object
        Only ``g.t_t[0]`` and ``g.w_length[0]`` are read; they enter the
        slope of the reference line.
    points : sequence of 3-item rows
        Rows of ``(delta, phase, I)``.  ``abs(delta).min()`` of the first
        row sets the gap used for the reference line.
    """
    # The items of a row have different shapes, so the table has to be an
    # object array.  Build it explicitly: NumPy >= 1.24 raises ValueError
    # when asked to create such a ragged array implicitly.
    table = np.empty((len(points), 3), dtype=object)
    for row, point in enumerate(points):
        for col, item in enumerate(point):
            table[row, col] = item

    # Save to file
    if os.path.isfile('sss-cpr.npz'):
        os.unlink('sss-cpr.npz')
    np.savez('sss-cpr.npz', points=table)

    # Read the data back; the context manager closes the archive so that
    # no file handle is left open for the next call.
    with np.load('sss-cpr.npz', allow_pickle=True) as archive:
        z = archive['points'].copy()

    # Create a plot
    plt.clf()
    plt.plot(z[:,1].tolist(), z[:,2].tolist(), '-', label='__nolabel__')
    lims = plt.axis()

    # Plot Ambegaokar-Baratoff reference line
    p = np.linspace(0, 2*pi, 200)
    delta_0 = abs(z[0,0]).min()
    L_AB = (pi*delta_0/2)*np.tanh(delta_0/(2*g.t_t[0]))/g.w_length[0]
    plt.plot(p, L_AB*p, 'k--', label='Ambegaokar-Baratoff')
    # Restore the axis limits chosen for the data alone
    plt.axis(lims)

    plt.xlim(0, 4*pi)
    plt.xticks([0, pi, 2*pi, 3*pi, 4*pi],
               [r'$0$', r'$\pi$', r'$2\pi$', r'$3\pi$', r'$4\pi$'])
    plt.grid(True)
    plt.legend(loc='best')
    plt.savefig('sss-cpr.pdf')
    plt.savefig('sss-cpr.png')


def self_consistent_matsubara_cpr(geometry,
                                  phi_start=0,
                                  phi_end=4*pi,
                                  max_iterations=100,
                                  max_points=10000,
                                  max_ne=300,
                                  output_func=sys.stderr.write,
                                  E_max=None,
                                  force_integral=False,
                                  real_delta=False,
                                  tol=1e-3):
    """
    Compute current-phase relation by pseudo-arclength continuation.

    Two starting points are obtained by solving the self-consistency
    equation at fixed terminal phases ``phi_start`` and
    ``phi_start + 0.05``.  Further points are found by a predictor step
    along the secant through the last two points, followed by a corrector
    solve in which the phase is an unknown as well.

    Parameters
    ----------
    geometry
        Geometry to solve.  All terminal temperatures must be equal and
        all ``t_mu`` zero.  The object is modified in place: ``w_delta``,
        ``w_phase`` and ``t_phase`` are overwritten.
    phi_start : float
        Terminal phase of the first point.
    phi_end : float
        The continuation stops once a computed phase exceeds this value.
    max_iterations : int
        Not used by this function; kept for interface compatibility.
    max_points : int
        Maximum number of continuation steps after the two starting points.
    max_ne, E_max, force_integral
        Passed on to ``_get_self_consistent_energies_m``.
    output_func : callable
        Passed on to ``_set_self_consistent_delta_m`` and ``dump_delta``.
    real_delta : bool
        Not used by this function; kept for interface compatibility.
    tol : float
        Residual tolerance (``fatol``) of the nonlinear solver; the
        spectral solver is run with ``0.01*tol``.

    Returns
    -------
    points : list of (delta, phase, I)
        Visited points, in order.

    Raises
    ------
    ValueError
        If the geometry is not an equilibrium one.
    RuntimeError
        If a starting point does not converge, or a continuation step does
        not converge after five successively halved step sizes.

    Notes
    -----
    Progress is printed to standard output, and ``save_points`` is called
    with the points found so far after every continuation step.
    """

    if not ((np.diff(geometry.t_t) == 0).all() and (geometry.t_mu == 0).all()):
        raise ValueError("The geometry does not describe an equilibrium situation.")
    T = geometry.t_t[0]

    # One-element lists serve as mutable cells shared with the closures
    # below: the energy grid of the previous call and the last current.
    E = [[]]
    last_I = [None]

    theta = 1/ (1 + len(geometry.x)/(2*pi)**2) # arclength weight for Delta  (vs. that of phase)

    def F(delta):
        # Residual of the self-consistency map: (updated delta) - delta.
        # Side effects: overwrites geometry.w_delta / w_phase and stores
        # the current in last_I[0].
        geometry.w_delta = abs(delta)
        geometry.w_phase = np.angle(delta)

        E[0], weights = _get_self_consistent_energies_m(geometry, T, max_ne=max_ne,
                                                        E_max=E_max, ne=len(E[0]),
                                                        force_integral=force_integral)

        solver = u.Solver()
        solver.set_geometry(geometry)
        solver.set_solvers(sp_tol=0.01*tol, sp_solver=u.SP_SOLVER_COLNEW)
        sol = solver.sp_solve(E[0], geometry.x)

        _set_self_consistent_delta_m(geometry, sol, T, weights, output_func=output_func)
        delta_2 = geometry.w_delta[:,:] * np.exp(1j*geometry.w_phase[:,:])

        # Compute current from Matsubara
        # FIXME: check factors of 2 in the prefactor!
        j = ((sol.b * sol.da - sol.a * sol.db) / (1 - sol.a * sol.b)**2)[:,0,:]
        I = np.real(4j * pi * T * (j * weights[:,None]).sum(axis=0) * geometry.w_conductance[0])
        last_I[0] = I

        print("resid_1:", abs(delta_2 - delta).max(), ", I:", last_I[0].mean(), ", rel. I error:", np.ptp(I)/abs(I.mean()))

        return delta_2 - delta

    def get_g_array(delta, phase):
        # Pack complex delta (interleaved re, im) and the phase, scaled by
        # sqrt(n), into a single real 1-D vector.
        n = len(delta.ravel())
        z = np.r_[np.array([delta.real.ravel(), delta.imag.ravel()]).T.ravel(),
                  phase * np.sqrt(n)]
        assert z.ndim == 1
        return z

    def unget_g_array(z):
        # Inverse of get_g_array.
        assert z.ndim == 1
        n = (len(z) - 1)//2
        re_delta = z[:2*n:2]
        im_delta = z[1:2*n:2]
        delta = re_delta + 1j*im_delta
        phase = z[-1] / np.sqrt(n)
        return delta, phase

    def G(z, ddelta, dphase, delta_0, phase_0, ds):
        # Extended residual for the corrector: self-consistency residual
        # plus the pseudo-arclength constraint.
        delta, phase = unget_g_array(z)
        geometry.t_phase[:] = 0, phase

        resid_1 = F(delta)

        # pseudo-arclength condition
        resid_2 = theta * np.real(np.vdot(ddelta, delta - delta_0)) + (2 - theta) * dphase * (phase - phase_0) - ds

        return get_g_array(resid_1, resid_2)

    # Step size tried first in every continuation step, and the number of
    # successively halved step sizes to try before giving up.
    ds_max = 0.3
    max_step_attempts = 5

    # Nonlinear solver fine-tuning
    method = 'broyden2'
    options = dict(
        fatol=tol,
        xatol=np.inf,
        ftol=np.inf,
        xtol=1e-4,
        line_search=None,

        # we use smaller ds if it doesn't converge fast enough:
        maxiter=80,

        # the automatic scale estimation fails if starting point is
        # close to solution; alpha (jacobian) ~ 1.0 is appropriate
        # condition for fixed-point iteration:
        jac_options=dict(alpha=1.0),
        )

    def solve_at_fixed_phase(delta_guess, phase):
        # Solve F(delta) = 0 for the terminal phase currently set on the
        # geometry; `phase` is only used in the error message.
        sol = optimize.root(F, delta_guess, method=method, options=options)
        if not sol.success:
            # Explicit error instead of `assert`, which is removed under -O.
            raise RuntimeError("Did not converge at phi = %r: %s"
                               % (phase, getattr(sol, 'message', '')))
        return sol.x

    def record_point(delta, phase):
        # Re-evaluate at the converged point so that the geometry and
        # last_I[0] correspond to it, report it, and return the current.
        F(delta)
        current = last_I[0]
        print("<----")
        dump_delta(geometry, output_func)
        print("phi =", phase, ", I =", current.mean())
        print("---->")
        return current

    # Compute two initial pseudo-arclength points
    delta_0 = geometry.w_delta * np.exp(1j*geometry.w_phase)
    phase_0 = phi_start
    geometry.t_phase = 0, phase_0
    geometry.w_phase = phase_0*geometry.x
    delta_0 = solve_at_fixed_phase(delta_0, phase_0)
    I_0 = record_point(delta_0, phase_0)

    phase_1 = phi_start + 0.05
    geometry.t_phase = 0, phase_1
    delta_1 = solve_at_fixed_phase(delta_0, phase_1)
    I_1 = record_point(delta_1, phase_1)

    # Record visited points
    points = [(delta_0, phase_0, I_0),
              (delta_1, phase_1, I_1)]

    for _ in range(max_points):
        # Estimate derivatives (unit secant through the last two points)
        last_ds = np.sqrt(theta * np.linalg.norm(delta_1 - delta_0)**2 + (2 - theta) * abs(phase_1 - phase_0)**2)
        ddelta = (delta_1 - delta_0) / last_ds
        dphase = (phase_1 - phase_0) / last_ds

        ds = ds_max
        for _attempt in range(max_step_attempts):
            # Predictor
            print("STEP:", last_ds, ds)
            delta_next = delta_1 + ds*ddelta
            phase_next = phase_1 + ds*dphase
            print("PREDICTOR: phi =", phase_next)
            z_next = get_g_array(delta_next, phase_next)

            # Corrector
            try:
                sol = optimize.root(G, z_next, method=method, options=options,
                                    args=(ddelta, dphase, delta_1, phase_1, ds))
                if sol.success:
                    break
            except u.CoreError:
                # A solver failure is treated like non-convergence:
                # retry with a smaller step.
                pass
            ds /= 2.0
        else:
            raise RuntimeError("Did not converge.")

        delta_2, phase_2 = unget_g_array(sol.x)

        # Record
        geometry.t_phase = 0, phase_2
        I_2 = record_point(delta_2, phase_2)
        points.append((delta_2, phase_2, I_2))

        # Step forward
        delta_0, phase_0 = delta_1, phase_1
        delta_1, phase_1 = delta_2, phase_2

        # Save
        save_points(geometry, points)

        if phase_2 > phi_end:
            break

    return points


# Run the calculation only when this file is executed as a script.
if __name__ == "__main__":
    main()
