#!/usr/bin/env python3

import os
import sys
import numpy as np
import phoebe
import rebound

from phoebe import u
from phoebe import c

# Radians per degree.
deg = np.pi / 180.0
# Absolute tolerance for the position comparisons in test_xyz_3body.
atol = 4e-16

# Stars of phoebe.default_triple(), in the order used for the "star" column
# of the .dat dumps.
_STARS = ('starA', 'starB', 'starC')
# Orbit-model qualifiers dumped per star: positions, then velocities.
_XYZ_QUALIFIERS = ('us', 'vs', 'ws', 'vus', 'vvs', 'vws')


def _get_model_xyz(bundle, compute):
    """
    Read the 'orb01' model produced by run_compute(compute=*compute*).

    Returns (times, columns): times are taken from starA; columns maps each
    qualifier in _XYZ_QUALIFIERS to a list with one array per star, ordered
    as in _STARS.

    """
    suffix = 'orb01@%s@latest@orb@model' % compute
    times = bundle['times@starA@%s' % suffix].value
    columns = {
        qualifier: [bundle['%s@%s@%s' % (qualifier, star, suffix)].value
                    for star in _STARS]
        for qualifier in _XYZ_QUALIFIERS
    }
    return times, columns


def _write_xyz(filename, times, columns):
    """
    Write one line per (time, star): time, 0-based star index into _STARS,
    then the values of every qualifier in _XYZ_QUALIFIERS.

    """
    row_format = "%23.16e %2d" + " %23.16e" * len(_XYZ_QUALIFIERS) + "\n"
    with open(filename, 'w') as f:
        f.write("# times star us vs ws vus vvs vws\n")
        for i, t in enumerate(times):
            for j in range(len(_STARS)):
                values = tuple(columns[q][j][i] for q in _XYZ_QUALIFIERS)
                f.write(row_format % ((t, j) + values))


def test_xyz_3body(verbose=False):
    """
    Compare rebound (ias15, hierarchical xyz) dynamics against keplerian
    dynamics for b.default_triple() with q@outer = 0.5.

    Side effects: writes twigs.txt, keplerian.dat and rebound.dat to the
    current working directory. Only the positions (us, vs, ws) are asserted
    to agree within ``atol``; velocities are written out but not compared.

    """

    # to re-compute default_triple()
    phoebe.conf.devel_on()

    # default parameters
    b = phoebe.default_triple()

    b['q@outer'] = 0.5

    if verbose:
        print("q1 = ", b['q@inner'].value)
        print("q2 = ", b['q@outer'].value)
        print("m1 = ", b['mass@starA@component'].value)
        print("m2 = ", b['mass@starB@component'].value)
        print("m3 = ", b['mass@starC@component'].value)
        print("period1 = ", b['period@inner@component'].value)
        print("period2 = ", b['period@outer@component'].value)
        print("sma1 = ", b['sma@inner@component'].value)
        print("sma2 = ", b['sma@outer@component'].value)

        hier = b.get_hierarchy()
        print("hier = ", hier)
        print("hier.get_orbits() = ", hier.get_orbits())

    # keplerian
    # Sample 0.0 .. 10.0 inclusive. Pad the stop by half a step: padding by
    # atol (4e-16) is below the float64 spacing at 10.0, so 10.0 + atol
    # rounds back to 10.0 and arange would silently drop the endpoint.
    step = 0.01
    times = np.arange(0.0, 10.0 + 0.5*step, step)

    b.add_dataset('orb', compute_times=times)

    with open('twigs.txt', 'w') as f:
        f.writelines("%s\n" % twig for twig in b.twigs)

    b.run_compute(dynamics_method='keplerian')

    times, kep_cols = _get_model_xyz(b, 'phoebe01')
    _write_xyz("keplerian.dat", times, kep_cols)

    # rebound
    b2 = phoebe.default_triple()

    b2['q@outer'] = 0.5

    # Reuse the keplerian model's output times so both runs share one grid.
    b2.add_dataset('orb', compute_times=times)

    b2.add_compute(dynamics_method='xyz', integrator='ias15', geometry='hierarchical')
    b2.run_compute(compute='phoebe02')

    times, reb_cols = _get_model_xyz(b2, 'phoebe02')
    _write_xyz("rebound.dat", times, reb_cols)

    for qualifier in ('us', 'vs', 'ws'):
        assert np.allclose(kep_cols[qualifier], reb_cols[qualifier],
                           atol=atol, rtol=0.0), qualifier

if __name__ == "__main__":
    # Console logging at INFO level; pass clevel='DEBUG' for more detail.
    logger = phoebe.logger(clevel='INFO')

    test_xyz_3body(verbose=True)


