#!/usr/bin/env python3

import numpy as np
import ndpolator

# Minimal reproduction script: interpolate a constant-valued 3-D grid at a
# single query point that lies within the axis ranges.

# Grid axes. Note that the two vertices of the third axis are only 1e-6 apart.
a1 = np.array([5900, 6000, 6100])
a2 = np.array([4.0, 4.5])
a3 = np.array([1.0, 1.000001])

ndp = ndpolator.Ndpolator(basic_axes=(a1, a2, a3))

# One value per grid vertex, stored along a trailing dimension of length 1.
# Every vertex holds 1.0, so the grid is built in a single vectorized call
# rather than by looping over all (i, j, k) index triples.
grid = np.ones((len(a1), len(a2), len(a3), 1))

ndp.register(table='main', associated_axes=None, grid=grid)

# A single query point: 6000 and 1.0 coincide with values on a1 and a3, and
# 4.25 lies between the two a2 values, so no coordinate is outside its axis
# range.
query_pts = [[6000, 4.25, 1.0]]

# NOTE: with extrapolation_method='nearest' the call below returns 1.0, as
# expected for a constant grid -- but the query point is inside the grid, so
# this is not an extrapolation and the choice of extrapolation method should
# presumably not matter (NOTE(review): confirm against ndpolator's docs).
# Alternative call, kept for comparison:
#interps = ndp.ndpolate(table='main', query_pts=query_pts, extrapolation_method='nearest')

interps = ndp.ndpolate(table='main', query_pts=query_pts, extrapolation_method='linear')

# Dump the inputs next to the result so the output is self-contained.
print("a1 = ", a1)
print("a2 = ", a2)
print("a3 = ", a3)
print("query_pts = ", query_pts)
print("interps['interps'] = ", interps['interps'])


