#!/usr/bin/env python3

import os
import sys
import numpy as np
import phoebe
import rebound

import matplotlib.pyplot as plt

from phoebe import u
from phoebe import c

from phoebe.parameters import hierarchy as _hierarchy
from phoebe.parameters.constraint import _get_system_ps

# Degrees-to-radians conversion factor (pi/180); not referenced by the code
# visible in this file.
deg = np.pi/180.0
# Tiny epsilon added to the np.arange stop value in test_xyz_twopairs so the
# final time (0.50) is included even though np.arange excludes its stop.
atol = 4.0e-16

# Keyword arguments forwarded to get_parameter() in _twopairs; presumably they
# disable phoebe's default/visibility filtering so every matching parameter is
# found — confirm against phoebe's ParameterSet.filter documentation.
_skip_filter_checks = {'check_default': False, 'check_visible': False}

def _twopairs(b, orbit3, solve_for=None, **kwargs):
    """
    Constraint function tying the outer mass ratio to the masses of two pairs.

    The constraint built is

        q3 = (masses[2] + masses[3]) / (masses[0] + masses[1])

    where ``masses`` are the ``mass`` parameters of the stars returned by
    ``b.hierarchy.get_stars()``. Assuming that order follows the hierarchy
    ``orbit3(orbit1(A, B), orbit2(C, D))`` (TODO confirm), this is
    ``(mC + mD) / (mA + mB)``.

    Parameters
    ----------
    b : phoebe.Bundle
        Bundle whose hierarchy must contain exactly four stars.
    orbit3 : str
        Label of the outer orbit whose ``q`` parameter is constrained.
    solve_for : None
        Only ``None`` (i.e. solve for ``q`` of ``orbit3``) is supported.
    **kwargs
        Accepted and ignored (presumably required by phoebe's constraint
        calling convention).

    Returns
    -------
    tuple
        ``(lhs, rhs, params, {'orbit3': orbit3})`` where ``lhs`` is the ``q``
        parameter of ``orbit3``, ``rhs`` the mass-ratio expression and
        ``params`` the list ``[q3] + masses``.

    Raises
    ------
    ValueError
        If the hierarchy does not contain exactly four stars.
    NotImplementedError
        If ``solve_for`` is not None.
    """
    hier = b.hierarchy

    q3 = _get_system_ps(b, orbit3).get_parameter(qualifier='q', **_skip_filter_checks)

    masses = [_get_system_ps(b, star).get_parameter(qualifier='mass', **_skip_filter_checks)
              for star in hier.get_stars()]

    # The expression below indexes exactly four masses; any other count would
    # either fail with an opaque IndexError or silently ignore extra stars.
    if len(masses) != 4:
        raise ValueError("twopairs constraint requires exactly 4 stars in the "
                         "hierarchy, got {}".format(len(masses)))

    # Use an identity test: `solve_for in [None]` would fall back to `==` on a
    # phoebe Parameter object.
    if solve_for is not None:
        raise NotImplementedError("twopairs constraint can only be solved for "
                                  "q@{}".format(orbit3))

    lhs = q3
    rhs = (masses[2]+masses[3])/(masses[0]+masses[1])

    return lhs, rhs, [q3] + masses, {'orbit3': orbit3}

def default_twopairs():
    """
    Build the default bundle for the two-pairs geometry.

    Four stars (starA..starD) and three orbits are created. The hierarchy is
    orbit3(orbit1(starA, starB), orbit2(starC, starD)). Mass ratios, periods
    and semi-major axes are initialised and delayed constraints are run. Then
    the hierarchical Kepler's-third-law constraint (orbit3 vs orbit1) and the
    ``_twopairs`` mass-ratio constraint are attached, and a default compute
    is added.

    Returns
    -------
    phoebe.Bundle
        The configured bundle.
    """
    bundle = phoebe.Bundle()

    for star_label, colour in (('starA', 'red'),
                               ('starB', 'green'),
                               ('starC', 'blue'),
                               ('starD', 'cyan')):
        bundle.add_star(component=star_label, color=colour)

    for orbit_label in ('orbit1', 'orbit2', 'orbit3'):
        bundle.add_orbit(component=orbit_label)

    # (twig, value) pairs; applied strictly in this order.
    initial_values = (
        ('q@orbit1@component', 1.0),
        ('q@orbit2@component', 1.0),
        ('q@orbit3@component', 100.0),
        ('period@orbit1@component', 1.0),
        ('period@orbit2@component', 1.5),
        ('period@orbit3@component', 100.0),
        ('sma@orbit1@component', 5.3),
        ('sma@orbit2@component', 1.5*5.3),
        ('sma@orbit3@component', 100.0),
    )
    for twig, new_value in initial_values:
        bundle.set_value(twig, value=new_value)

    bundle.run_delayed_constraints()

    inner_pair_1 = _hierarchy.binaryorbit(bundle['orbit1'], bundle['starA'], bundle['starB'])
    inner_pair_2 = _hierarchy.binaryorbit(bundle['orbit2'], bundle['starC'], bundle['starD'])
    bundle.set_hierarchy(_hierarchy.binaryorbit(bundle['orbit3'], inner_pair_1, inner_pair_2))

    bundle.add_constraint(phoebe.parameters.constraint.keplers_third_law_hierarchical, 'orbit3', 'orbit1')
    bundle.add_constraint(_twopairs, 'orbit3')

    bundle.add_compute()

    return bundle

def test_xyz_twopairs(verbose=False):
    """
    Smoke test of the 'twopairs' geometry with the xyz dynamics method.

    Builds the bundle from default_twopairs() and sets all four stars'
    ``requiv`` to 2.0. Adds a mesh dataset (visibilities) at times
    0.00..0.50 in steps of 0.01 and runs an ias15 ``xyz`` compute with
    ``geometry='twopairs'``. Finally writes an animated ``rebound.gif`` of
    the mesh model to the current working directory.

    Parameters
    ----------
    verbose : bool
        If True, print mass ratios, masses, periods, semi-major axes and
        the hierarchy before running the model.
    """
    # Developer mode is needed to (re)build the bundle in default_twopairs().
    phoebe.conf.devel_on()
    try:
        b2 = default_twopairs()

        if verbose:
            print("q1 = ", b2['q@orbit1@component'].value)
            print("q2 = ", b2['q@orbit2@component'].value)
            print("q3 = ", b2['q@orbit3@component'].value)
            print("m1 = ", b2['mass@starA@component'].value)
            print("m2 = ", b2['mass@starB@component'].value)
            print("m3 = ", b2['mass@starC@component'].value)
            print("m4 = ", b2['mass@starD@component'].value)
            print("period1 = ", b2['period@orbit1@component'].value)
            print("period2 = ", b2['period@orbit2@component'].value)
            print("period3 = ", b2['period@orbit3@component'].value)
            print("sma1 = ", b2['sma@orbit1@component'].value)
            print("sma2 = ", b2['sma@orbit2@component'].value)
            print("sma3 = ", b2['sma@orbit3@component'].value)

            hier = b2.get_hierarchy()
            print("hier = ", hier)
            print("hier.get_orbits() = ", hier.get_orbits())

        for star in ('starA', 'starB', 'starC', 'starD'):
            b2.set_value('requiv@{}@component'.format(star), value=2.0)

        # np.arange excludes its stop value; the tiny atol keeps 0.50 in range.
        times = np.arange(0.00, 0.50+atol, 0.01)

        b2.add_dataset('mesh', compute_times=times, columns=['visibilities'], dataset='mesh01')

        # Label the compute explicitly instead of relying on phoebe
        # auto-naming the second compute 'phoebe02'.
        b2.add_compute(compute='phoebe02', dynamics_method='xyz', integrator='ias15', geometry='twopairs')
        b2.run_compute(compute='phoebe02')

        # Queue the mesh plot; b2.savefig below renders and animates it.
        b2['mesh@model'].plot(fc='visibilities', ec='None', y='vs')
        fig = plt.figure(figsize=(11, 4))
        b2.savefig('rebound.gif', fig=fig, tight_layout=True, draw_sidebars=False,
                   animate=True, save_kwargs={'writer': 'imagemagick'})
        plt.close(fig)
    finally:
        # Reset so developer mode does not leak into subsequent tests.
        phoebe.conf.devel_off()


if __name__ == "__main__":
    # Console log level for phoebe; switch to 'DEBUG' for more detailed output.
    console_level = 'INFO'
    logger = phoebe.logger(clevel=console_level)

    test_xyz_twopairs(verbose=True)


