#!/usr/bin/env python3

import os
import sys
import numpy as np
import phoebe

from astropy import units as u

# Radians per degree.
deg = np.pi / 180
# Absolute tolerance: pads the np.arange end point so it is included, and is
# the element-wise tolerance when comparing the two computed orbits.
atol = 4e-16

def _set_common_params(bundle):
    """Apply the orbit/system/star configuration shared by both bundles.

    Both the Keplerian reference and the xyz integration must start from the
    same parameters, otherwise the comparison does not isolate the integrator.
    """
    bundle.set_value('q@binary', value=0.1)
    bundle.set_value('period@binary', value=1.0*u.d)
    bundle.set_value('ecc@binary', value=0.3)
    bundle.set_value('per0@binary', value=0.0*u.deg)
    bundle.set_value('incl@binary', value=10.0*u.deg)
    bundle.set_value('long_an@binary', value=-0.001*u.deg)
    bundle.set_value('t0_supconj@binary@component', value=0.0*u.d)

    bundle.set_value('t0@system', value=0.0)
    bundle.set_value('vgamma@system', value=0.0)
    bundle.set_value('distance@system', value=1.0*u.pc)
    bundle.set_value('requiv@primary@component', value=1.5*u.solRad)
    bundle.set_value('requiv@secondary@component', value=0.5*u.solRad)
    bundle.set_value('pitch@primary@component', value=-10.0*u.deg)


def _print_parameters(bundle):
    """Print the parameters of *bundle* that define the test configuration."""
    labelled_twigs = (
        ("period", 'period@binary'),
        ("ecc", 'ecc@binary'),
        ("incl", 'incl@binary'),
        ("per0", 'per0@binary'),
        ("long_an", 'long_an@binary'),
        ("t0_supconj", 't0_supconj@binary'),
        ("mean_anom", 'mean_anom@binary@component'),
        ("t0", 't0@system'),
        ("vgamma", 'vgamma@system'),
        ("distance", 'distance@system'),
        ("ltte", 'ltte@compute'),
        ("m1", 'mass@primary@component'),
        ("m2", 'mass@secondary@component'),
        ("pitch1", 'pitch@primary@star@component'),
        ("pitch2", 'pitch@secondary@star@component'),
        ("yaw1", 'yaw@primary@star@component'),
        ("yaw2", 'yaw@secondary@star@component'),
        ("incl1", 'incl@primary@star@component'),
        ("incl2", 'incl@secondary@star@component'),
        ("long_an1", 'long_an@primary@star@component'),
        ("long_an2", 'long_an@secondary@star@component'),
    )
    for label, twig in labelled_twigs:
        print(label + " = ", bundle[twig].value)

    hier = bundle.get_hierarchy()
    print("hier = ", hier)
    print("hier.get_orbits() = ", hier.get_orbits())


def _get_orbits(bundle, compute):
    """Return ``(times, primary, secondary)`` from the latest orb01 model.

    ``primary`` and ``secondary`` are ``(us, vs, ws)`` tuples of the values
    stored under the given *compute* label.
    """
    def value(qualifier, component):
        twig = '%s@%s@orb01@%s@latest@orb@model' % (qualifier, component, compute)
        return bundle[twig].value

    times = value('times', 'primary')
    primary, secondary = (
        tuple(value(qualifier, component) for qualifier in ('us', 'vs', 'ws'))
        for component in ('primary', 'secondary')
    )
    return times, primary, secondary


def _write_orbits(path, times, primary, secondary):
    """Write both stars' (us, vs, ws) to *path*: one row per (time, star)."""
    row = "%23.16e %2d %23.16e %23.16e %23.16e\n"
    with open(path, "w") as f:
        f.write("# times star us vs ws\n")
        for i, t in enumerate(times):
            # star index 0 = primary, 1 = secondary
            for star, (us, vs, ws) in enumerate((primary, secondary)):
                f.write(row % (t, star, us[i], vs[i], ws[i]))


def test_xyz_j2(verbose=False, t_end=2.0, dt=0.01):
    """
    Test of b.default_binary().
    Non-equal masses (q = 0.1).
    Integration spanning 2 d (``t_end``), sampled every 0.01 d (``dt``).
    Using xyz.py integrator.
    J2 = -C20 (oblateness).

    A Keplerian orbit (compute 'phoebe01') is compared with an xyz/ias15
    integration (compute 'phoebe02') of the same system with j2 = 0.01 on the
    primary. Writes twigs.txt, keplerian.dat and rebound.dat to the current
    working directory.

    verbose -- print the bundle configuration before computing.
    t_end, dt -- end time and step (days) of the requested compute_times.
    """

    # to re-compute default_binary()
    phoebe.conf.devel_on()

    b = phoebe.default_binary()
    _set_common_params(b)

    if verbose:
        _print_parameters(b)

    # atol pads the end point so that t_end itself is included by np.arange
    times = np.arange(0.0, t_end+atol, dt)
    b.add_dataset('orb', compute_times=times)

    with open('twigs.txt', 'w') as f:
        f.writelines("%s\n" % twig for twig in b.twigs)

    b.run_compute(dynamics_method='keplerian')
    times, kep_primary, kep_secondary = _get_orbits(b, 'phoebe01')
    _write_orbits("keplerian.dat", times, kep_primary, kep_secondary)

########################################################################

    b2 = phoebe.default_binary()
    _set_common_params(b2)
    b2.set_value('j2@primary@component', value=0.01)
    b2.set_value('j2@secondary@component', value=0.00)

    b2.add_dataset('orb', compute_times=times)

    b2.add_compute(dynamics_method='xyz', integrator='ias15', stepsize=1.0e-4, epsilon=1.0e-12, geometry='hierarchical', j2=True)
    b2.run_compute(compute='phoebe02')
    times, xyz_primary, xyz_secondary = _get_orbits(b2, 'phoebe02')
    _write_orbits("rebound.dat", times, xyz_primary, xyz_secondary)

    # NOTE(review): the Keplerian reference is computed without any j2 while
    # b2 uses j2=0.01 on the primary, so agreement to atol=4e-16 seems
    # unlikely -- confirm the intended tolerance.
    for reference, computed in zip(kep_primary + kep_secondary,
                                   xyz_primary + xyz_secondary):
        assert np.allclose(reference, computed, atol=atol, rtol=0.0)

if __name__ == "__main__":
    # Console logging at INFO level; pass clevel='DEBUG' for more detail.
    logger = phoebe.logger(clevel='INFO')
    test_xyz_j2(verbose=True)


