#!/usr/bin/env python3

import os
import sys
import numpy as np
import phoebe
import rebound

from phoebe import u
from phoebe import c

from phoebe.parameters import hierarchy as _hierarchy

# Conversion factor from degrees to radians (radians per degree).
deg = np.pi / 180.0
# Tiny absolute padding; used e.g. on the np.arange stop value in this file,
# presumably so the end time is included in the generated grid.
atol = 4.0e-16

def default_twopairs():
    """Build a PHOEBE bundle for a 'two pairs' system of four stars.

    Stars starA/starB are bound by orbit1, starC/starD by orbit2, and the two
    inner binaries are bound to each other by the outer orbit3.  All stars get
    mass 1.0, and periods/semi-major axes are assigned explicitly.  A default
    compute set is added before returning.

    Returns
    -------
    phoebe.Bundle
        The configured bundle.
    """
    bundle = phoebe.Bundle()

    star_colours = (
        ('starA', 'red'),
        ('starB', 'green'),
        ('starC', 'blue'),
        ('starD', 'cyan'),
    )
    for star_name, colour in star_colours:
        bundle.add_star(component=star_name, color=colour)

    # The orbits are created before any hierarchy is set, so no constraints
    # tie sma to period and masses.  Every value, sma included, is therefore
    # set by hand below.
    for orbit_name in ('orbit1', 'orbit2', 'orbit3'):
        bundle.add_orbit(component=orbit_name)

    # Applied in this exact order.
    manual_values = (
        ('mass@starA@component', 1.0),
        ('mass@starB@component', 1.0),
        ('mass@starC@component', 1.0),
        ('mass@starD@component', 1.0),
        ('period@orbit1@component', 1.0),
        ('period@orbit2@component', 1.5),
        ('period@orbit3@component', 100.0),
        ('sma@orbit1@component', 5.3),
        ('sma@orbit2@component', 1.5*5.3),
        ('sma@orbit3@component', 100.0),
    )
    for twig, val in manual_values:
        bundle.set_value(twig, value=val)

    bundle.run_delayed_constraints()

    # Inner pairs first, then the outer orbit binding the two pairs together.
    pair_ab = _hierarchy.binaryorbit(bundle['orbit1'], bundle['starA'], bundle['starB'])
    pair_cd = _hierarchy.binaryorbit(bundle['orbit2'], bundle['starC'], bundle['starD'])
    bundle.set_hierarchy(_hierarchy.binaryorbit(bundle['orbit3'], pair_ab, pair_cd))

    # NOTE: keplers_third_law_hierarchical constraints (orbit2/orbit1 and
    # orbit3/orbit2) are deliberately not added here.

    bundle.add_compute()

    return bundle

def test_xyz_twopairs(verbose=False):
    """
    Test of twopairs geometry.

    Builds the four-star system from ``default_twopairs``, sets orbit2 to zero
    inclination and adds an 'orb' dataset on t = 0..2 (step 0.01).  It then
    runs a compute with ``dynamics_method='xyz'``, ``integrator='ias15'`` and
    ``geometry='twopairs'``.  The resulting us/vs/ws and vus/vvs/vws arrays of
    every star are written to ``rebound.dat`` in the current working
    directory.  No values are asserted; the file is meant for inspection.

    Parameters
    ----------
    verbose : bool
        Currently unused; kept for compatibility with existing callers.
    """
    # Developer mode is enabled only for the duration of this test and reset
    # afterwards, so it does not leak into other tests in the same session.
    # NOTE(review): presumably needed for the xyz/twopairs compute options.
    phoebe.conf.devel_on()
    try:
        times = np.arange(0.0, 2.0+atol, 0.01)

        b2 = default_twopairs()

        b2.set_value('incl@orbit2@component', value=0.0*u.deg)

        b2.add_dataset('orb', compute_times=times)

        # default_twopairs() already added one compute set, so this second
        # one is expected to be labelled 'phoebe02'.
        b2.add_compute(dynamics_method='xyz', integrator='ias15', geometry='twopairs')
        b2.run_compute(compute='phoebe02')

        stars = ['starA', 'starB', 'starC', 'starD']
        quantities = ['us', 'vs', 'ws', 'vus', 'vvs', 'vws']
        model_twig = '%s@%s@orb01@phoebe02@latest@orb@model'

        times = b2[model_twig % ('times', 'starA')].value

        # columns[j] holds the arrays of stars[j], in `quantities` order.
        columns = [[b2[model_twig % (q, star)].value for q in quantities]
                   for star in stars]

        # `with` guarantees the file is closed even if a write fails.
        with open("rebound.dat", "w") as f:
            f.write("# times star us vs ws vus vvs vws\n")
            for i, t in enumerate(times):
                for j, star_columns in enumerate(columns):
                    row = (t, j) + tuple(col[i] for col in star_columns)
                    f.write("%23.16e %2d %23.16e %23.16e %23.16e %23.16e %23.16e %23.16e\n" % row)
    finally:
        phoebe.conf.devel_off()

if __name__ == "__main__":
    # Console log level for PHOEBE output; switch to "DEBUG" for more detail.
    logger = phoebe.logger(clevel="INFO")

    test_xyz_twopairs(True)


