#!/usr/bin/env python3

import os
import sys
import numpy as np
import phoebe

from astropy import units as u

# Radians per degree: multiply an angle in degrees by this to get radians.
deg = np.pi / 180.0
# Absolute tolerance for the np.allclose(..., rtol=0.0) checks in test_2body.
# NOTE(review): with rtol=0 this is below the double-precision spacing of any
# value of magnitude >~ 2, so it effectively demands bit-for-bit agreement
# between the Keplerian and REBOUND RVs — confirm this is intended.
atol = 4.0e-16

def _configure_binary(b):
    """Apply the orbit/system setup shared by both bundles in test_2body.

    Modifies bundle *b* in place. Using one helper for both the Keplerian and
    the REBOUND bundle guarantees the two models describe the same system.
    """
    b.set_value('q@binary', value=0.1)
    b.set_value('period@binary', value=1.0*u.d)
    b.set_value('ecc@binary', value=0.0)
    b.set_value('incl@binary', value=90.0*u.deg)
    b.set_value('per0@binary', value=0.0)
    b.set_value('long_an@binary', value=0.0)
    b.set_value('t0_supconj@binary@component', value=0.0*u.d)

    b.set_value('t0@system', value=0.0)
    b.set_value('vgamma@system', value=-15.0*u.km/u.s)
    b.set_value('distance@system', value=1.0*u.pc)


def _write_rvs(path, times, rv1, rv2):
    """Write times and primary/secondary RVs to *path* as three columns."""
    with open(path, "w") as f:
        f.write("# times rv1 rv2\n")
        for t, r1, r2 in zip(times, rv1, rv2):
            f.write("%23.16e %23.16e %23.16e\n" % (t, r1, r2))


def test_2body(verbose=False):
    """
    Test of b.default_binary().
    Non-equal masses (q = 0.1).
    Integration spanning 2 d.

    Builds two identically configured default binaries, adds an RV dataset
    whose times/rvs/sigmas are read from Rv.dat (located next to this
    script), and computes RVs once with dynamics_method='keplerian' and once
    with a REBOUND (IAS15) integration. The model RVs are written to
    keplerian.dat and rebound.dat in the current directory, and both
    components must agree to within the module-level ``atol`` (rtol = 0).

    Side effects: writes twigs.txt, keplerian.dat and rebound.dat to the
    current working directory.
    """

    # to re-compute default_binary()
    phoebe.conf.devel_on()

    # Directory of this script; input data is read from here so the test
    # does not depend on the caller's working directory.
    dir_ = os.path.dirname(os.path.realpath(__file__))

    b = phoebe.default_binary()
    _configure_binary(b)

    if verbose:
        print("period = ", b['period@binary'].value)
        print("ecc = ", b['ecc@binary'].value)
        print("incl = ", b['incl@binary'].value)
        print("per0 = ", b['per0@binary'].value)
        print("long_an = ", b['long_an@binary'].value)
        print("t0_supconj = ", b['t0_supconj@binary'].value)

        print("t0 = ", b['t0@system'].value)
        print("vgamma = ", b['vgamma@system'].value)
        print("distance = ", b['distance@system'].value)
        print("ltte = ", b['ltte@compute'].value)
        print("m1 = ", b['mass@primary@component'].value)
        print("m2 = ", b['mass@secondary@component'].value)

#    times = np.arange(0.0, 2.0+atol, 0.01)
#    b.add_dataset('orb', compute_times=times)
#    b.add_dataset('mesh', compute_times=times)
    times, rvs, sigmas = np.loadtxt(os.path.join(dir_, 'Rv.dat'),
                                    usecols=(0, 1, 2), unpack=True)
    b.add_dataset(kind='rv', times=times, rvs=rvs, sigmas=sigmas)

    # Dump all twigs of the configured bundle (debugging aid).
    with open('twigs.txt', 'w') as f:
        for twig in b.twigs:
            f.write("%s\n" % (twig))

    b.run_compute(dynamics_method='keplerian')

    times_kep = b['times@primary@rv01@phoebe01@latest@rv@model'].value
    rv1 = b['rvs@primary@rv01@phoebe01@latest@rv@model'].value
    rv2 = b['rvs@secondary@rv01@phoebe01@latest@rv@model'].value

    _write_rvs("keplerian.dat", times_kep, rv1, rv2)

#    fig, plt = b.plot(show=True)
#    plt.savefig("keplerian.png")

########################################################################

    # Same system, same dataset, but dynamics from a REBOUND integration.
    b2 = phoebe.default_binary()
    _configure_binary(b2)
    b2.add_compute(dynamics_method='rebound', integrator='ias15', stepsize=1.0e-4, epsilon=1.0e-12)
    b2.add_dataset(kind='rv', times=times, rvs=rvs, sigmas=sigmas)
    b2.run_compute(compute='phoebe02')

    times_reb = b2['times@primary@rv01@phoebe02@latest@rv@model'].value
    rv1_ = b2['rvs@primary@rv01@phoebe02@latest@rv@model'].value
    rv2_ = b2['rvs@secondary@rv01@phoebe02@latest@rv@model'].value

    _write_rvs("rebound.dat", times_reb, rv1_, rv2_)

    # The failure messages are only evaluated when an assertion fails.
    assert np.allclose(rv1, rv1_, atol=atol, rtol=0.0), \
        "primary RV mismatch: max |diff| = %g" % np.max(np.abs(rv1 - rv1_))
    assert np.allclose(rv2, rv2_, atol=atol, rtol=0.0), \
        "secondary RV mismatch: max |diff| = %g" % np.max(np.abs(rv2 - rv2_))

if __name__ == "__main__":
    # Echo PHOEBE log messages to the console; pass clevel='DEBUG' instead
    # for a more detailed trace.
    logger = phoebe.logger(clevel='INFO')

    test_2body(verbose=True)


