#!/usr/bin/env python3
# coding: utf-8

import numpy as np
from scipy.integrate import solve_bvp
import par
from opt_dfuns import odefun, bcfun, init_guess  # , gradodefun
from loikkaPiirto import optPiirto, tallenna_tila, tallenna_ohjaus

"""
Ratkaistaan kahden pisteen reuna-arvotehtävä oheisella menetelmällä:
https://rjleveque.github.io/amath585w2020/notebooks/html/BVP_scipy.html


Numeerisen algoritmin konvergointi eli se, löytääkö algoritmi ratkaisun,
riippuu alkuarvauksesta ja tehtävään liittyvistä parametreista.

"""


def iniGuess(stps):
    """
    Alkuarvaus ratkaisulle. init_guess() palauttaa tilan halutun arvon alussa
    ja lopussa. Alkuarvauksena systeemi siirtyy suoraviivaisesti alkutilasta
    lopputilaan.

    Liittotiloille annetaan alkuarvaukseksi 0
    """
    guess = np.zeros((4*par.sysdim, stps))
    ga, gb = init_guess(0)
    for i in range(stps):
        # paikat
        for j in range(par.sysdim, 2*par.sysdim):
            guess[j, i] = (ga[j] + (gb[j] - ga[j])*float(i)/float(stps))
        # nopeudet
        for j in range(0, par.sysdim):
            guess[j, i] = ((gb[j] - ga[j])/par.Tf)
        # liittotilat
        for j in range(2*par.sysdim, 4*par.sysdim):
            guess[j, i] = 0.0

    """
    Algoritmi konvergoi huonosti, joten lisäsin allkuarvaukseen vihjeeksi,
    että takimmaisen jalan kannattaa potkaista alkuvauhtia
    """
#    iax = 6
    iay = 7
    ibx = 8
    iby = 9
    icx = 10
    icy = 11
#    iavx = 0
#    iavy = 1
    ibvx = 2
    ibvy = 3
    icvx = 4
    icvy = 5
    T1 = int(0.2*stps)
    T2 = int(0.4*stps)
    T3 = int(0.8*stps)
    vx_ave = 2.0*par.ds/par.Tf
    vy_ave = 4.0*guess[iay, 0]/par.Tf
    dt = float(par.Tf/stps)
    for i in range(1, T1):
        guess[ibvx, i] = -vx_ave
        guess[ibx, i] = guess[ibx, i-1] - vx_ave*dt
        guess[ibvy, i] = -vy_ave
        guess[iby, i] = guess[iby, i-1] - vy_ave*dt

        guess[icvx, i] = vx_ave
        guess[icx, i] = guess[icx, i-1] + vx_ave*dt
        guess[icvy, i] = vy_ave
        guess[icy, i] = guess[icy, i-1] + vy_ave*dt

    for i in range(T1, T2):
        guess[ibvx, i] = vx_ave
        guess[ibx, i] = guess[ibx, i-1] + vx_ave*dt

        guess[ibvy, i] = 2.0*vy_ave
        guess[iby, i] = guess[iby, i-1] + 2.0*vy_ave*dt

    for i in range(T3, stps):
        guess[ibvx, i] = 0.0
        guess[ibx, i] = gb[ibx]

        guess[ibvy, i] = 0.0
        guess[iby, i] = gb[iby]

        guess[icvx, i] = 0.0
        guess[icx, i] = gb[icx]

        guess[icvy, i] = 0.0
        guess[icy, i] = gb[icy]

    return guess  # .transpose()


def scipy_ratkaisu(Stps, plotdim):
    """
    Ratkaistaan reuna-arvotehtävä
    """
    initMesh = np.linspace(0.0, par.Tf, Stps)
    guess_arr = iniGuess(Stps)
    res = solve_bvp(odefun, bcfun, initMesh, guess_arr,
                    verbose=2, tol=1.0e-5, max_nodes=4000)
    print('status ', res.status, res.message)
    print('******************************************')
    (n, ti) = res.y.shape
    print('ti ', ti)
    dim = min(ti, plotdim)
    xs = np.linspace(0, par.Tf, dim)
    ys = res.sol(xs)
    pi = 0
    # piirretään ratkaisun kuvaajat
    optPiirto('hyppy_' + str(pi), xs, ys, dim)
    # tallennetaan ratkaisu simulointia varten
    tallenna_ohjaus(ys, dim)
    tallenna_tila(xs, ys, dim)

    while res.status == 0:
        """
        Alkuarvauksella algoritmi ei aina konvergoi kuin "löysällä"
        optimointikriteerillä. Ratkaisua voi käyttää uutena parempana
        alkuarvauksena ja kiristää optimointikriteeriä.

        Tässä tapauksessa ratkaisua ei paljoa parantunut
        """
        pi = pi + 1
        par.C_l = 1.1*par.C_l
        par.C_u = 1.0*par.C_u
        print(f' Cl: {par.C_l} Cu: {par.C_u}')
        xs0 = np.linspace(0, par.Tf, int(ti/10.0))
        ys0 = res.sol(xs0)
        res = solve_bvp(odefun, bcfun, xs0, ys0,
                        verbose=2, tol=1.0e-5, max_nodes=4000)
        print('status ', res.status, res.message)
        (n, ti) = res.y.shape
        print('ti ', ti)
        dim = min(ti, plotdim)
        xs = np.linspace(0, par.Tf, dim)
        ys = res.sol(xs)
        optPiirto('hyppy_' + str(pi), xs, ys, min(len(res.x), plotdim))
        tallenna_ohjaus(ys, dim)
        tallenna_tila(xs, ys, dim)

###########
# Pääohjelma


STPS = 20  # 20
PlotDim = 500  # int(Stps)
scipy_ratkaisu(STPS, PlotDim)

print("se siitä")