#!/usr/bin/env python3

import os
import sys
import numpy as np
import phoebe
import rebound

from phoebe import u
from phoebe import c

# Conversion factor from degrees to radians (pi / 180).
deg = np.pi / 180.0

# Absolute tolerance used when comparing keplerian and rebound positions in
# test_3body; it is also used there as a tiny pad on the np.arange end point.
atol = 4e-16

def _model_positions(bundle, compute, dataset='orb01'):
    """Read model times and per-star (us, vs, ws) arrays from an orb dataset.

    Parameters
    ----------
    bundle : phoebe Bundle
        Bundle that already holds a model computed with options ``compute``.
    compute : str
        Label of the compute options whose ``latest`` model is read
        (e.g. ``'phoebe01'``).
    dataset : str
        Label of the orb dataset (default ``'orb01'``).

    Returns
    -------
    times : array
        Model times of ``starA``.
    positions : list of tuple
        One ``(us, vs, ws)`` tuple per star, in the order starA, starB, starC.
    """
    def get(qualifier, component):
        twig = '%s@%s@%s@%s@latest@orb@model' % (qualifier, component, dataset, compute)
        return bundle[twig].value

    times = get('times', 'starA')
    positions = [tuple(get(q, star) for q in ('us', 'vs', 'ws'))
                 for star in ('starA', 'starB', 'starC')]
    return times, positions


def _write_positions(filename, times, positions):
    """Write one line per (time, star) with that star's us, vs, ws values.

    Stars are numbered from 1 in the order they appear in ``positions``.
    The file is overwritten if it exists.
    """
    with open(filename, 'w') as f:
        f.write("# times star us vs ws\n")
        for i, t in enumerate(times):
            for star, (us, vs, ws) in enumerate(positions, start=1):
                f.write("%23.16e %2d %23.16e %23.16e %23.16e\n" % (t, star, us[i], vs[i], ws[i]))


def test_3body(verbose=False):
    """Compare keplerian and rebound (ias15) orbits of phoebe.default_triple().

    Both models are sampled at the same times. Every us/vs/ws coordinate of
    each of the three stars must agree within the module-level absolute
    tolerance ``atol`` (with ``rtol=0``).

    Side effects: writes ``twigs.txt``, ``keplerian.dat`` and ``rebound.dat``
    to the current working directory. Developer mode is switched on for the
    duration of the test and switched off again afterwards, so the global
    setting does not leak into other tests.

    Parameters
    ----------
    verbose : bool
        If True, print the mass ratios and masses of the default triple.
    """
    # to re-compute default_triple()
    phoebe.conf.devel_on()
    try:
        # default parameters
        b = phoebe.default_triple()

        if verbose:
            print("q = ", b['q@inner'].value)
            print("q = ", b['q@outer'].value)
            print("m1 = ", b['mass@starA@component'].value)
            print("m2 = ", b['mass@starB@component'].value)
            print("m3 = ", b['mass@starC@component'].value)

        # keplerian; the +atol pad keeps the 1.0 end point inside np.arange
        times = np.arange(0.0, 1.0+atol, 0.01)
        b.add_dataset('orb', compute_times=times)

        with open('twigs.txt', 'w') as f:
            for twig in b.twigs:
                f.write("%s\n" % (twig))

        b.run_compute(dynamics_method='keplerian')
        times, kepler = _model_positions(b, 'phoebe01')
        _write_positions("keplerian.dat", times, kepler)

        # rebound, sampled at the keplerian model times
        b2 = phoebe.default_triple()
        b2.add_compute(dynamics_method='rebound', integrator='ias15')
        b2.add_dataset('orb', compute_times=times)
        b2.run_compute(compute='phoebe02')
        times, nbody = _model_positions(b2, 'phoebe02')
        _write_positions("rebound.dat", times, nbody)

        # same check order as before: us, vs, ws of star 1, then star 2, then star 3
        for star, (kep, reb) in enumerate(zip(kepler, nbody), start=1):
            for name, x, y in zip(('us', 'vs', 'ws'), kep, reb):
                assert np.allclose(x, y, atol=atol, rtol=0.0), \
                    "%s of star %d differs between keplerian and rebound" % (name, star)
    finally:
        # reset the global developer-mode flag for subsequent tests
        phoebe.conf.devel_off()

if __name__ == "__main__":
    # Route PHOEBE log messages to the console at INFO level
    # (pass clevel='DEBUG' instead for more detailed output).
    logger = phoebe.logger(clevel='INFO')

    # Run the comparison directly, printing the triple's masses and mass ratios.
    test_3body(verbose=True)


