"""
Tests for problems with nested calls between compiled functions.
Such problems usually stem from invalid type conversions across function
boundaries.
"""


from numba import int32, int64
from numba import jit, generated_jit
from numba.core import types
from numba.tests.support import TestCase, tag
import unittest


@jit(nopython=True)
def f_inner(a, b, c):
    """Return the three arguments, in declaration order, as a tuple."""
    packed = (a, b, c)
    return packed

def f(x, y, z):
    """
    Call ``f_inner`` with *x* positionally and *y*, *z* as keyword
    arguments given in a different order than the callee declares them.
    """
    result = f_inner(x, c=y, b=z)
    return result

@jit(nopython=True)
def g_inner(a, b=2, c=3):
    """Return ``(a, b, c)``; *b* and *c* have default values."""
    packed = (a, b, c)
    return packed

def g(x, y, z):
    """
    Call ``g_inner`` twice, each time leaving one of its defaulted
    parameters unspecified, and return both results as a pair.
    """
    # First call omits `c`; second call omits `b` and passes `a` by keyword.
    first = g_inner(x, b=y)
    second = g_inner(a=z, c=x)
    return first, second

@jit(nopython=True)
def star_inner(a=5, *b):
    """Return *a* together with the tuple of extra positional arguments."""
    pair = (a, b)
    return pair

def star(x, y, z):
    """
    Call ``star_inner`` once with only ``a`` given by keyword and once with
    three positional arguments, returning both results as a pair.
    """
    keyword_only = star_inner(a=x)
    with_extras = star_inner(x, y, z)
    return keyword_only, with_extras

def star_call(x, y, z):
    """
    Call ``star_inner`` using ``*`` unpacking at the call site: once after
    a leading positional argument and once as the only argument.
    """
    mixed = star_inner(x, *y)
    unpacked_only = star_inner(*z)
    return mixed, unpacked_only

@jit(nopython=True)
def argcast_inner(a, b):
    """
    Return *a*, or ``int64(0)`` when *b* is truthy.

    The argument variable ``a`` itself is rebound inside the branch; that
    rebinding is the pattern this fixture exists to exercise, so it must
    not be rewritten into an expression that avoids the assignment.
    """
    if b:
        # Here `a` is unified to int64 (from int32 originally)
        a = int64(0)
    return a

def argcast(a, b):
    """Convert *a* to ``int32`` and forward it with *b* to ``argcast_inner``."""
    narrowed = int32(a)
    return argcast_inner(narrowed, b)

@generated_jit(nopython=True)
def generated_inner(x, y=5, z=6):
    """
    Select an implementation based on the type of *x*.

    When *x* is a ``types.Complex`` instance the implementation returns
    ``(x + y, z)``; otherwise it returns ``(x - y, z)``.
    """
    def add_impl(x, y, z):
        return x + y, z

    def sub_impl(x, y, z):
        return x - y, z

    return add_impl if isinstance(x, types.Complex) else sub_impl

def call_generated(a, b):
    """
    Call ``generated_inner`` with *a* positionally and *b* as ``z``,
    leaving ``y`` at its default.
    """
    result = generated_inner(a, z=b)
    return result


class TestNestedCall(TestCase):
    """
    Tests for calls made from one compiled function into another.
    """

    def compile_func(self, pyfunc, objmode=False):
        """
        Compile *pyfunc* and return a ``(compiled, check)`` pair.

        *pyfunc* is compiled with ``forceobj=True`` when *objmode* is true
        and with ``nopython=True`` otherwise.  ``check(*args, **kwargs)``
        calls the pure-Python function and then the compiled one with the
        same arguments and asserts that both results are precisely equal.
        """
        if objmode:
            jit_options = {'forceobj': True}
        else:
            jit_options = {'nopython': True}
        compiled = jit(**jit_options)(pyfunc)

        def check(*args, **kwargs):
            expected = pyfunc(*args, **kwargs)
            got = compiled(*args, **kwargs)
            self.assertPreciseEqual(got, expected)

        return compiled, check

    def test_boolean_return(self):
        """
        A boolean returned by a nested call can be branched on by the caller.
        """
        @jit(nopython=True)
        def negate(x):
            return not x

        @jit(nopython=True)
        def branch_on_negate(x):
            if negate(x):
                return True
            return False

        self.assertFalse(branch_on_negate(True))
        self.assertTrue(branch_on_negate(False))

    def test_named_args(self, objmode=False):
        """
        Nested call passing keyword (named) arguments to the callee.
        """
        _, check = self.compile_func(f, objmode)
        check(1, 2, 3)
        check(1, y=2, z=3)

    def test_named_args_objmode(self):
        """
        Same as ``test_named_args`` with the caller compiled in object mode.
        """
        self.test_named_args(objmode=True)

    def test_default_args(self, objmode=False):
        """
        Nested call relying on the callee's default argument values.
        """
        _, check = self.compile_func(g, objmode)
        check(1, 2, 3)
        check(1, y=2, z=3)

    def test_default_args_objmode(self):
        """
        Same as ``test_default_args`` with the caller compiled in object mode.
        """
        self.test_default_args(objmode=True)

    def test_star_args(self):
        """
        Nested call to a callee that declares *args in its signature.
        """
        _, check = self.compile_func(star)
        check(1, 2, 3)

    def test_star_call(self, objmode=False):
        """
        Nested call that unpacks *args at the call site.
        """
        _, check = self.compile_func(star_call, objmode)
        check(1, (2,), (3,))

    def test_star_call_objmode(self):
        """
        Same as ``test_star_call`` with the caller compiled in object mode.
        """
        self.test_star_call(objmode=True)

    def test_argcast(self):
        """
        Issue #1488: an implicit cast of an argument variable inside the
        callee should not break nested calls.
        """
        _, check = self.compile_func(argcast)
        for flag in (0, 1):
            check(1, flag)

    def test_call_generated(self):
        """
        Nested call to a function defined with ``generated_jit``.
        """
        cfunc = jit(nopython=True)(call_generated)
        cases = (
            # (first argument, expected result)
            (1, (-4, 2)),
            (1j, (1j + 5, 2)),
        )
        for first_arg, expected in cases:
            self.assertPreciseEqual(cfunc(first_arg, 2), expected)


# Run the tests with unittest when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
