#!/usr/bin/env python3

import os
import sys
import numpy as np
import phoebe

from astropy import units as u

# Degrees-to-radians conversion factor (not used by test_2body_dpdt below).
deg = np.pi/180.0
# Absolute tolerance. It pads the np.arange end point so that t = 2.0 is
# sampled, and it is the atol for the keplerian-vs-rebound np.allclose checks.
atol = 4.0e-16

def _orbit_columns(bundle, compute):
    """Fetch the ``orb01`` model arrays produced by *compute* from *bundle*.

    Returns ``(times, primary, secondary)``. ``times`` is read from the
    primary's model. ``primary`` and ``secondary`` are ``(us, vs, ws)``
    tuples of the corresponding model arrays.
    """
    def value(qualifier, component):
        twig = '%s@%s@orb01@%s@latest@orb@model' % (qualifier, component, compute)
        return bundle[twig].value

    times = value('times', 'primary')
    primary = tuple(value(q, 'primary') for q in ('us', 'vs', 'ws'))
    secondary = tuple(value(q, 'secondary') for q in ('us', 'vs', 'ws'))
    return times, primary, secondary


def _write_orbits(path, times, primary, secondary):
    """Write orbit samples to the text file *path*.

    One header line is followed by two rows per time sample: star 1 (from
    *primary*) and star 2 (from *secondary*). Each row holds the time, the
    star index and that star's us, vs, ws values.
    """
    row = "%23.16e %2d %23.16e %23.16e %23.16e\n"
    # `with` guarantees the file is closed even if a write raises.
    with open(path, "w") as f:
        f.write("# times star us vs ws\n")
        for t, u1, v1, w1, u2, v2, w2 in zip(times, *primary, *secondary):
            f.write(row % (t, 1, u1, v1, w1))
            f.write(row % (t, 2, u2, v2, w2))


def test_2body_dpdt(verbose=False):
    """
    Test of b.default_binary().
    Non-equal masses (q = 0.1).
    Integration spanning 2 d.
    Variable period.

    The orbit with dpdt = 0.1 d/d is computed twice:
      * analytically, with dynamics_method='keplerian';
      * numerically, with rebound (ias15 integrator).

    Both results are written to the current working directory as
    keplerian.dat and rebound.dat. The first bundle's twigs are written to
    twigs.txt. The test then asserts that each star's us, vs, ws arrays
    agree within the module-level ``atol`` (rtol = 0).
    """

    # to re-compute default_binary()
    phoebe.conf.devel_on()

    b = phoebe.default_binary()

    b.set_value('q@binary', value=0.1)
    b.set_value('period@binary', value=1.0*u.d)
    b.set_value('ecc@binary', value=0.0)
    b.set_value('incl@binary', value=90.0*u.deg)
    b.set_value('per0@binary', value=0.0)
    b.set_value('long_an@binary', value=0.0)
    b.set_value('t0_supconj@binary@component', value=0.0*u.d)

    b.set_value('t0@system', value=0.0)
    b.set_value('vgamma@system', value=0.0)
    b.set_value('distance@system', value=1.0*u.pc)
    b.set_value('dpdt@binary@component', value=0.1*u.d/u.d)

    if verbose:
        print("period = ", b['period@binary'].value)
        print("ecc = ", b['ecc@binary'].value)
        print("incl = ", b['incl@binary'].value)
        print("per0 = ", b['per0@binary'].value)
        print("long_an = ", b['long_an@binary'].value)
        print("t0_supconj = ", b['t0_supconj@binary'].value)

        print("t0 = ", b['t0@system'].value)
        print("vgamma = ", b['vgamma@system'].value)
        print("distance = ", b['distance@system'].value)
        print("ltte = ", b['ltte@compute'].value)
        print("m1 = ", b['mass@primary@component'].value)
        print("m2 = ", b['mass@secondary@component'].value)
        print("dpdt = ", b['dpdt@binary@component'].value)

    # Pad the end by atol so that t = 2.0 is included despite float rounding.
    times = np.arange(0.0, 2.0+atol, 0.01)

    b.add_dataset('orb', compute_times=times)

    with open('twigs.txt', 'w') as f:
        f.writelines("%s\n" % twig for twig in b.twigs)

    # Analytic reference solution (default compute options, 'phoebe01').
    b.run_compute(dynamics_method='keplerian')
    times, kep_primary, kep_secondary = _orbit_columns(b, 'phoebe01')
    _write_orbits("keplerian.dat", times, kep_primary, kep_secondary)

    # Numerical solution with rebound, sampled at the keplerian model times.
    # NOTE(review): b2 only sets q and dpdt explicitly. The other parameters
    # set on b above come from default_binary() defaults here; confirm that
    # they match.
    b2 = phoebe.default_binary()
    b2.set_value('q@binary', value=0.1)
    b2.set_value('dpdt@binary@component', value=0.1*u.d/u.d)
    b2.add_compute(dynamics_method='rebound', integrator='ias15', stepsize=1.0e-4, epsilon=1.0e-12)
    b2.add_dataset('orb', compute_times=times)
    b2.run_compute(compute='phoebe02')
    times, reb_primary, reb_secondary = _orbit_columns(b2, 'phoebe02')
    _write_orbits("rebound.dat", times, reb_primary, reb_secondary)

    # Compare coordinate by coordinate so a failure names the offending array.
    labels = ('us1', 'vs1', 'ws1', 'us2', 'vs2', 'ws2')
    for label, kep, reb in zip(labels, kep_primary + kep_secondary,
                               reb_primary + reb_secondary):
        assert np.allclose(kep, reb, atol=atol, rtol=0.0), label

if __name__ == "__main__":
    # Console logging at INFO level; pass clevel='DEBUG' for more detail.
    logger = phoebe.logger(clevel='INFO')
    test_2body_dpdt(verbose=True)


