import copy
import itertools
import math
import random
import sys
from typing import KeysView

import numpy as np

from numba.core.compiler import compile_isolated, Flags
from numba import jit, njit
from numba.core import types, utils, errors
import unittest
from numba import testing
from numba.tests.support import TestCase, MemoryLeakMixin, tag

from numba.misc.quicksort import make_py_quicksort, make_jit_quicksort
from numba.misc.mergesort import make_jit_mergesort
from numba.misc.timsort import make_py_timsort, make_jit_timsort, MergeRun


def make_temp_list(keys, n):
    """Return a scratch list of *n* items, each equal to the first key."""
    filler = keys[0]
    return [filler] * n

def make_temp_array(keys, n):
    """Return an uninitialized scratch array of *n* items with the dtype of *keys*."""
    item_type = keys.dtype
    return np.empty(n, item_type)


# Timsort implementations under test: each of the two timsort factories
# is combined with list-based and array-based temporary storage helpers.
py_list_timsort = make_py_timsort(make_temp_list)

py_array_timsort = make_py_timsort(make_temp_array)

jit_list_timsort = make_jit_timsort(make_temp_list)

jit_array_timsort = make_jit_timsort(make_temp_array)

# Quicksort implementations under test, built with the factories' default
# arguments.
py_quicksort = make_py_quicksort()

jit_quicksort = make_jit_quicksort()


def sort_usecase(val):
    """Sort *val* in place by calling its ``sort()`` method."""
    val.sort()

def argsort_usecase(val):
    """Return the result of calling ``argsort()`` on *val*."""
    order = val.argsort()
    return order

def argsort_kind_usecase(val, is_stable=False):
    """
    Return ``val.argsort()`` with an explicit *kind*: ``'mergesort'`` when
    *is_stable* is true, ``'quicksort'`` otherwise.
    """
    # Each branch passes a literal string for *kind*.
    if not is_stable:
        return val.argsort(kind='quicksort')
    return val.argsort(kind='mergesort')

def sorted_usecase(val):
    """Return the result of the ``sorted()`` builtin applied to *val*."""
    ordered = sorted(val)
    return ordered

def sorted_reverse_usecase(val, b):
    """Return ``sorted(val)``, passing *b* as the ``reverse`` flag."""
    ordered = sorted(val, reverse=b)
    return ordered

def np_sort_usecase(val):
    """Return the result of ``np.sort()`` applied to *val*."""
    ordered = np.sort(val)
    return ordered

def np_argsort_usecase(val):
    """Return the result of ``np.argsort()`` applied to *val*."""
    order = np.argsort(val)
    return order

def np_argsort_kind_usecase(val, is_stable=False):
    """
    Return ``np.argsort(val)`` with an explicit *kind*: ``'mergesort'`` when
    *is_stable* is true, ``'quicksort'`` otherwise.
    """
    # Each branch passes a literal string for *kind*.
    if not is_stable:
        return np.argsort(val, kind='quicksort')
    return np.argsort(val, kind='mergesort')

def list_sort_usecase(n):
    """
    Build a list of *n* random floats (seed 42) and return it together
    with a sorted copy, as ``(original, sorted_copy)``.
    """
    np.random.seed(42)
    samples = []
    for _ in range(n):
        samples.append(np.random.random())
    ordered = samples.copy()
    ordered.sort()
    return samples, ordered

def list_sort_reverse_usecase(n, b):
    """
    Build a list of *n* random floats (seed 42) and return it together
    with a copy sorted with ``reverse=b``, as ``(original, sorted_copy)``.
    """
    np.random.seed(42)
    samples = []
    for _ in range(n):
        samples.append(np.random.random())
    ordered = samples.copy()
    ordered.sort(reverse=b)
    return samples, ordered


class BaseSortingTest(object):
    """
    Helpers shared by the sorting test cases: deterministic generators of
    integer lists with various orderings, plus assertions checking that a
    result is sorted (and, for key/value sorts, stable).
    """

    def random_list(self, n, offset=10):
        """
        Return the integers ``offset .. offset + n - 1`` shuffled with a
        fixed seed.
        """
        random.seed(42)
        l = list(range(offset, offset + n))
        random.shuffle(l)
        return l

    def sorted_list(self, n, offset=10):
        """Return the integers ``offset .. offset + n - 1`` in ascending order."""
        return list(range(offset, offset + n))

    def revsorted_list(self, n, offset=10):
        """Return the integers ``offset .. offset + n - 1`` in descending order."""
        return list(range(offset, offset + n))[::-1]

    def initially_sorted_list(self, n, m=None, offset=10):
        """
        Return a list of *n* integers whose first *m* items (default
        ``n // 2``) are sorted and whose remaining items are shuffled.
        """
        if m is None:
            m = n // 2
        l = self.sorted_list(m, offset)
        l += self.random_list(n - m, offset=l[-1] + offset)
        return l

    def _dup_list(self, n, factor, offset):
        """
        Return *n* integers starting at *offset* in which values repeat;
        *factor* defaults to ``int(sqrt(n))`` and controls how many copies
        of the base range are concatenated.
        """
        if factor is None:
            # Clamp to 1 so that n == 0 does not divide by zero below.
            factor = max(1, int(math.sqrt(n)))
        l = (list(range(offset, offset + (n // factor) + 1)) * (factor + 1))[:n]
        assert len(l) == n, (len(l), n)
        return l

    def duprandom_list(self, n, factor=None, offset=10):
        """
        Return a list of *n* integers containing duplicates, shuffled with
        a fixed seed.
        """
        random.seed(42)
        l = self._dup_list(n, factor, offset)
        random.shuffle(l)
        return l

    def dupsorted_list(self, n, factor=None, offset=10):
        """Return a sorted list of *n* integers containing duplicates."""
        l = self._dup_list(n, factor, offset)
        l.sort()
        return l

    def assertSorted(self, orig, result):
        """Assert that *result* is the sorted version of *orig*."""
        self.assertEqual(len(result), len(orig))
        # sorted() returns a list, so make sure we compare to another list
        self.assertEqual(list(result), sorted(orig))

    def assertSortedValues(self, orig, orig_values, result, result_values):
        """
        Assert that *result* is *orig* sorted, that *result_values* moved
        along with their keys, and that equal keys kept their original
        relative order (stability).
        """
        self.assertEqual(len(result), len(orig))
        self.assertEqual(list(result), sorted(orig))
        # sorted() is stable, so this is the expected key/value pairing
        zip_sorted = sorted(zip(orig, orig_values), key=lambda x: x[0])
        zip_result = list(zip(result, result_values))
        self.assertEqual(zip_sorted, zip_result)
        # Check stability
        for i in range(len(zip_result) - 1):
            (k1, v1), (k2, v2) = zip_result[i], zip_result[i + 1]
            if k1 == k2:
                # Assuming values are unique, which is enforced by the tests
                self.assertLess(orig_values.index(v1), orig_values.index(v2))

    def fibo(self):
        """Yield the Fibonacci sequence 1, 1, 2, 3, 5, ... indefinitely."""
        a = 1
        b = 1
        while True:
            yield a
            a, b = b, a + b

    def make_sample_sorted_lists(self, n):
        """Return sorted sample lists of length *n*, at two different offsets."""
        lists = []
        for offset in (20, 120):
            lists.append(self.sorted_list(n, offset=offset))
            # *offset* must go by keyword: the second positional parameter
            # of dupsorted_list() is *factor*, not *offset*.
            lists.append(self.dupsorted_list(n, offset=offset))
        return lists

    def make_sample_lists(self, n):
        """
        Return sample lists of length *n* with various orderings (sorted,
        sorted with duplicates, reversed, shuffled with duplicates), at two
        different offsets.
        """
        lists = []
        for offset in (20, 120):
            lists.append(self.sorted_list(n, offset=offset))
            # See make_sample_sorted_lists(): *offset* must go by keyword
            # for the dup* helpers.
            lists.append(self.dupsorted_list(n, offset=offset))
            lists.append(self.revsorted_list(n, offset=offset))
            lists.append(self.duprandom_list(n, offset=offset))
        return lists


class BaseTimsortTest(BaseSortingTest):
    """
    Tests for the individual functions exposed by a timsort implementation.

    Concrete subclasses provide ``timsort`` (the implementation under test)
    and ``array_factory`` (builds the container to sort from a list).
    """

    def merge_init(self, keys):
        """Create a merge state for *keys* using the implementation under test."""
        return self.timsort.merge_init(keys)

    def _binarysort_cases(self, n):
        """
        Return ``(list, starts)`` pairs for the binarysort tests: the input
        list and the *start* arguments to try with it.
        """
        half = n // 2
        return [
            (self.sorted_list(n), (0, half)),
            (self.revsorted_list(n), (0,)),
            (self.initially_sorted_list(n, half), (0, half)),
            (self.revsorted_list(n), (0,)),
            (self.random_list(n), (0,)),
            (self.duprandom_list(n), (0,)),
        ]

    def test_binarysort(self):
        n = 20
        binarysort = self.timsort.binarysort

        def check(orig, start):
            # The keys double as the values array.
            keys = self.array_factory(orig)
            binarysort(keys, keys, 0, n, start)
            self.assertSorted(orig, keys)

        for orig, starts in self._binarysort_cases(n):
            for start in starts:
                check(orig, start)

    def test_binarysort_with_values(self):
        n = 20
        orig_values = list(range(100, 100 + n))
        binarysort = self.timsort.binarysort

        def check(orig, start):
            keys = self.array_factory(orig)
            values = self.array_factory(orig_values)
            binarysort(keys, values, 0, n, start)
            self.assertSortedValues(orig, orig_values, keys, values)

        for orig, starts in self._binarysort_cases(n):
            for start in starts:
                check(orig, start)

    def test_count_run(self):
        n = 16
        count_run = self.timsort.count_run

        def check(l, lo, hi):
            run_len, descending = count_run(self.array_factory(l), lo, hi)
            run_end = lo + run_len
            # Fully check invariants
            if descending:
                # Strictly descending inside the run...
                for k in range(lo, run_end - 1):
                    self.assertGreater(l[k], l[k + 1])
                # ... and the run could not have been extended.
                if run_end < hi:
                    self.assertLessEqual(l[run_end - 1], l[run_end])
            else:
                # Non-descending inside the run...
                for k in range(lo, run_end - 1):
                    self.assertLessEqual(l[k], l[k + 1])
                # ... and the run could not have been extended.
                if run_end < hi:
                    self.assertGreater(l[run_end - 1], l[run_end], l)

        for make_list in (self.sorted_list, self.revsorted_list):
            l = make_list(n, offset=100)
            for lo, hi in ((0, n), (1, n - 1), (1, 2)):
                check(l, lo, hi)
        for make_list in (self.random_list, self.duprandom_list):
            l = make_list(n, offset=100)
            for lo in range(len(l) - 1):
                check(l, lo, n)

    def _check_gallop(self, gallop, assert_before, assert_at):
        """
        Exercise *gallop* on sorted inputs with every possible hint.

        *assert_before* is applied to ``(keys[k - 1], key)`` and *assert_at*
        to ``(keys[k], key)``, where *k* is the returned position.
        """
        n = 20

        def check(keys, key, start, stop, hint):
            k = gallop(key, keys, start, stop, hint)
            # Fully check invariants
            self.assertGreaterEqual(k, start)
            self.assertLessEqual(k, stop)
            if k > start:
                assert_before(keys[k - 1], key)
            if k < stop:
                assert_at(keys[k], key)

        sorted_inputs = (self.sorted_list(n, offset=100),
                         self.dupsorted_list(n, offset=100))
        for orig in sorted_inputs:
            keys = self.array_factory(orig)
            for key in (keys[5], keys[15], keys[0], -1000, keys[-1], 1000):
                for start, stop in ((0, n), (1, n - 1), (8, n - 8)):
                    for hint in range(start, stop):
                        check(keys, key, start, stop, hint)

    def test_gallop_left(self):
        self._check_gallop(self.timsort.gallop_left,
                           self.assertLess, self.assertGreaterEqual)

    def test_gallop_right(self):
        self._check_gallop(self.timsort.gallop_right,
                           self.assertLessEqual, self.assertGreater)

    def test_merge_compute_minrun(self):
        compute_minrun = self.timsort.merge_compute_minrun

        for size in range(0, 64):
            self.assertEqual(compute_minrun(size), size)
        for exponent in range(6, 63):
            power = 2**exponent
            if power > sys.maxsize:
                break
            self.assertEqual(compute_minrun(power), 32)
        # fibo() is increasing: skip the values below 64, stop at sys.maxsize
        large_sizes = itertools.takewhile(
            lambda v: v < sys.maxsize,
            itertools.dropwhile(lambda v: v < 64, self.fibo()))
        for size in large_sizes:
            minrun = compute_minrun(size)
            self.assertGreaterEqual(minrun, 32)
            self.assertLessEqual(minrun, 64)
            if size > 500:
                # size/minrun is close to, but strictly less than, an exact
                # power of 2
                quot = size // minrun
                p = 2 ** utils.bit_length(quot)
                self.assertLess(quot, p)
                self.assertGreaterEqual(quot, 0.9 * p)

    def check_merge_lo_hi(self, func, a, b):
        """Merge the adjacent runs *a* and *b* with *func* and check the result."""
        na, nb = len(a), len(b)

        # Add sentinels at start and end, to check they weren't moved
        orig_keys = [42] + a + b + [-42]
        keys = self.array_factory(orig_keys)
        ms = self.merge_init(keys)
        ssa = 1
        ssb = ssa + na

        new_ms = func(ms, keys, keys, ssa, na, ssb, nb)
        self.assertEqual(keys[0], orig_keys[0])
        self.assertEqual(keys[-1], orig_keys[-1])
        self.assertSorted(orig_keys[1:-1], keys[1:-1])
        # Check the MergeState result
        self.assertGreaterEqual(len(new_ms.keys), len(ms.keys))
        self.assertGreaterEqual(len(new_ms.values), len(ms.values))
        self.assertIs(new_ms.pending, ms.pending)
        self.assertGreaterEqual(new_ms.min_gallop, 1)

    def test_merge_lo_hi(self):
        merge_lo = self.timsort.merge_lo
        merge_hi = self.timsort.merge_hi

        # The larger sizes exercise galloping
        size_pairs = [(12, 16), (40, 40), (100, 110), (1000, 1100)]
        for na, nb in size_pairs:
            pairs = itertools.product(self.make_sample_sorted_lists(na),
                                      self.make_sample_sorted_lists(nb))
            for a, b in pairs:
                self.check_merge_lo_hi(merge_lo, a, b)
                self.check_merge_lo_hi(merge_hi, b, a)

    def check_merge_at(self, a, b):
        """
        Check merge_at() on the adjacent runs *a* and *b*, both when they
        are the last two runs on the stack and when another run follows.
        """
        merge_at = self.timsort.merge_at
        merge_append = self.timsort.merge_append
        na, nb = len(a), len(b)
        # Add sentinels at start and end, to check they weren't moved
        orig_keys = [42] + a + b + [-42]
        ssa = 1
        ssb = ssa + na

        stack_sentinel = MergeRun(-42, -42)

        def prepare(extra_runs=()):
            # Build fresh keys and a merge state whose stack holds the
            # sentinel, the two runs to merge, then *extra_runs*.
            keys = self.array_factory(orig_keys)
            ms = self.merge_init(keys)
            # Push sentinel on stack, to check it wasn't touched
            ms = merge_append(ms, stack_sentinel)
            i = ms.n
            runs = (MergeRun(ssa, na), MergeRun(ssb, nb)) + tuple(extra_runs)
            for run in runs:
                ms = merge_append(ms, run)
            return keys, ms, i

        def run_merge_at(ms, keys, i):
            new_ms = merge_at(ms, keys, keys, i)
            self.assertEqual(keys[0], orig_keys[0])
            self.assertEqual(keys[-1], orig_keys[-1])
            self.assertSorted(orig_keys[1:-1], keys[1:-1])
            # Check stack state
            self.assertIs(new_ms.pending, ms.pending)
            self.assertEqual(ms.pending[i], (ssa, na + nb))
            self.assertEqual(ms.pending[0], stack_sentinel)
            return new_ms

        # First check with i == len(stack) - 2
        keys, ms, i = prepare()
        ms = run_merge_at(ms, keys, i)
        self.assertEqual(ms.n, i + 1)

        # Now check with i == len(stack) - 3, i.e. with a last run
        # (trivial here) following the two merged ones
        last_run = MergeRun(ssb + nb, 1)
        keys, ms, i = prepare(extra_runs=(last_run,))
        ms = run_merge_at(ms, keys, i)
        self.assertEqual(ms.n, i + 2)
        self.assertEqual(ms.pending[ms.n - 1], last_run)

    def test_merge_at(self):
        # The larger sizes exercise galloping
        size_pairs = [(12, 16), (40, 40), (100, 110), (500, 510)]
        for na, nb in size_pairs:
            pairs = itertools.product(self.make_sample_sorted_lists(na),
                                      self.make_sample_sorted_lists(nb))
            for a, b in pairs:
                self.check_merge_at(a, b)
                self.check_merge_at(b, a)

    def test_merge_force_collapse(self):
        force_collapse = self.timsort.merge_force_collapse
        merge_append = self.timsort.merge_append

        # Test with runs of ascending sizes, then descending sizes
        ascending_sizes = (8, 10, 15, 20)

        for sizes in (ascending_sizes, ascending_sizes[::-1]):
            candidates = [self.make_sample_sorted_lists(n) for n in sizes]
            for chunks in itertools.product(*candidates):
                # Create runs of the given sizes
                orig_keys = list(itertools.chain.from_iterable(chunks))
                keys = self.array_factory(orig_keys)
                ms = self.merge_init(keys)
                pos = 0
                for chunk in chunks:
                    ms = merge_append(ms, MergeRun(pos, len(chunk)))
                    pos += len(chunk)
                # Sanity check: the last run ends where the keys end
                self.assertEqual(sum(ms.pending[ms.n - 1]), len(keys))
                # Now merge the runs
                ms = force_collapse(ms, keys, keys)
                # Remaining run is the whole list
                self.assertEqual(ms.n, 1)
                self.assertEqual(ms.pending[0], MergeRun(0, len(keys)))
                # The list is now sorted
                self.assertSorted(orig_keys, keys)

    def test_run_timsort(self):
        run_timsort = self.timsort.run_timsort
        # Make lists to be sorted from three chunks of different kinds.
        sizes = (15, 30, 20)

        for size_factor in (1, 10):
            all_lists = [self.make_sample_lists(n * size_factor) for n in sizes]
            for chunks in itertools.product(*all_lists):
                orig_keys = list(itertools.chain.from_iterable(chunks))
                keys = self.array_factory(orig_keys)
                run_timsort(keys)
                # The list is now sorted
                self.assertSorted(orig_keys, keys)

    def test_run_timsort_with_values(self):
        # Run timsort, but also with a values array
        run_timsort_with_values = self.timsort.run_timsort_with_values

        for size_factor in (1, 5):
            chunk_size = 80 * size_factor
            orig_keys = (self.dupsorted_list(chunk_size)
                         + self.duprandom_list(chunk_size)
                         + self.revsorted_list(chunk_size))
            orig_values = list(range(1000, 1000 + len(orig_keys)))

            keys = self.array_factory(orig_keys)
            values = self.array_factory(orig_values)
            run_timsort_with_values(keys, values)
            # This checks sort stability
            self.assertSortedValues(orig_keys, orig_values, keys, values)


class TestTimsortPurePython(BaseTimsortTest, TestCase):
    """Run the timsort tests against ``py_list_timsort``, sorting plain lists."""

    timsort = py_list_timsort

    # Much faster than a Numpy array in pure Python
    array_factory = list


class TestTimsortArraysPurePython(BaseTimsortTest, TestCase):
    """Run the timsort tests against ``py_array_timsort``, sorting int32 arrays."""

    timsort = py_array_timsort

    def array_factory(self, lst):
        """Return a new int32 Numpy array holding the items of *lst*."""
        return np.array(lst, dtype=np.int32)


class JITTimsortMixin(object):
    """
    Mixin selecting ``jit_array_timsort`` as the implementation under test
    and disabling the inherited tests that are not run against it.
    """

    timsort = jit_array_timsort

    # Setting an inherited test to None disables it.
    test_merge_at = None
    test_merge_force_collapse = None

    def wrap_with_mergestate(self, timsort, func, _cache={}):
        """
        Wrap *func* into another compiled function inserting a runtime-created
        mergestate as the first function argument.
        """
        # The *_cache* default is shared between calls on purpose, so that
        # each (timsort, func) pair gets wrapped only once.
        key = (timsort, func)
        try:
            return _cache[key]
        except KeyError:
            pass

        merge_init = timsort.merge_init

        @timsort.compile
        def wrapper(keys, values, *args):
            ms = merge_init(keys)
            return func(ms, keys, values, *args)

        _cache[key] = wrapper
        return wrapper


class TestTimsortArrays(JITTimsortMixin, BaseTimsortTest, TestCase):
    """Run the timsort tests against the JIT timsort, sorting int32 arrays."""

    def array_factory(self, lst):
        """Return a new int32 Numpy array holding the items of *lst*."""
        return np.array(lst, dtype=np.int32)

    def check_merge_lo_hi(self, func, a, b):
        """
        Merge the adjacent runs *a* and *b* with *func*, called through a
        compiled wrapper that creates the merge state itself.
        """
        wrapped = self.wrap_with_mergestate(self.timsort, func)
        na, nb = len(a), len(b)

        # Add sentinels at start and end, to check they weren't moved
        orig_keys = [42] + a + b + [-42]
        keys = self.array_factory(orig_keys)
        ssa = 1
        ssb = ssa + na

        # Only the effect on *keys* is checked; the return value is ignored.
        wrapped(keys, keys, ssa, na, ssb, nb)
        self.assertEqual(keys[0], orig_keys[0])
        self.assertEqual(keys[-1], orig_keys[-1])
        self.assertSorted(orig_keys[1:-1], keys[1:-1])



class BaseQuicksortTest(BaseSortingTest):
    """
    Tests for the individual functions exposed by a quicksort implementation.

    Concrete subclasses provide ``quicksort`` (the implementation under
    test), ``make_quicksort`` (its factory) and ``array_factory``.
    """

    # NOTE these tests assume a non-argsort quicksort.

    def _sample_inputs(self, n):
        """Return the input lists of length *n* shared by the unit tests."""
        return [
            self.sorted_list(n),
            self.revsorted_list(n),
            self.initially_sorted_list(n, n // 2),
            self.revsorted_list(n),
            self.random_list(n),
            self.duprandom_list(n),
        ]

    def _with_sentinels(self, l):
        """Return *l* as an array with a guard value added at each end."""
        return self.array_factory([9999] + l + [-9999])

    def _assert_sentinels_untouched(self, res):
        """Assert that the guard values are still at both ends of *res*."""
        self.assertEqual(res[0], 9999)
        self.assertEqual(res[-1], -9999)

    def _chunked_keys(self, size_factors=(1, 5), sizes=(15, 20)):
        """
        Yield key lists made by concatenating one sample list per entry
        of *sizes*, for every combination and every size factor.
        """
        for size_factor in size_factors:
            all_lists = [self.make_sample_lists(n * size_factor) for n in sizes]
            for chunks in itertools.product(*all_lists):
                yield sum(chunks, [])

    def test_insertion_sort(self):
        n = 20
        insertion_sort = self.quicksort.insertion_sort

        for l in self._sample_inputs(n):
            res = self._with_sentinels(l)
            insertion_sort(res, res, 1, n)
            self._assert_sentinels_untouched(res)
            self.assertSorted(l, res[1:-1])

    def test_partition(self):
        n = 20
        partition = self.quicksort.partition

        for l in self._sample_inputs(n):
            res = self._with_sentinels(l)
            index = partition(res, res, 1, n)
            self._assert_sentinels_untouched(res)
            pivot = res[index]
            # Nothing greater than the pivot before it, nothing lesser after
            for i in range(1, index):
                self.assertLessEqual(res[i], pivot)
            for i in range(index + 1, n):
                self.assertGreaterEqual(res[i], pivot)

    def test_partition3(self):
        # Test the unused partition3() function
        n = 20
        partition3 = self.quicksort.partition3

        for l in self._sample_inputs(n):
            res = self._with_sentinels(l)
            lt, gt = partition3(res, 1, n)
            self._assert_sentinels_untouched(res)
            pivot = res[lt]
            # Three zones: <= pivot, == pivot, > pivot
            for i in range(1, lt):
                self.assertLessEqual(res[i], pivot)
            for i in range(lt, gt + 1):
                self.assertEqual(res[i], pivot)
            for i in range(gt + 1, n):
                self.assertGreater(res[i], pivot)

    def test_run_quicksort(self):
        run_quicksort = self.quicksort.run_quicksort

        # Make lists to be sorted from two chunks of different kinds.
        for orig_keys in self._chunked_keys():
            keys = self.array_factory(orig_keys)
            run_quicksort(keys)
            # The list is now sorted
            self.assertSorted(orig_keys, keys)

    def test_run_quicksort_lt(self):
        def lt(a, b):
            return a > b

        run_reversed = self.make_quicksort(lt=lt).run_quicksort

        # Make lists to be sorted from two chunks of different kinds.
        for orig_keys in self._chunked_keys():
            keys = self.array_factory(orig_keys)
            run_reversed(keys)
            # The list is now rev-sorted
            self.assertSorted(orig_keys, keys[::-1])

        # An imperfect comparison function, as LT(a, b) does not imply not LT(b, a).
        # The sort should handle it gracefully.
        def lt_floats(a, b):
            return math.isnan(b) or a < b

        run_nan_last = self.make_quicksort(lt=lt_floats).run_quicksort

        np.random.seed(42)
        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            # Replace about 10% of the values with NaNs
            orig[np.random.random(size=size) < 0.1] = float('nan')
            keys = self.array_factory(list(orig))
            run_nan_last(keys)
            non_nans = orig[~np.isnan(orig)]
            # Non-NaNs are sorted at the front
            self.assertSorted(non_nans, keys[:len(non_nans)])


class TestQuicksortPurePython(BaseQuicksortTest, TestCase):
    """Run the quicksort tests against ``py_quicksort``, sorting plain lists."""

    quicksort = py_quicksort
    # Factory used by the tests that need a custom comparison function
    make_quicksort = staticmethod(make_py_quicksort)

    # Much faster than a Numpy array in pure Python
    array_factory = list


class TestQuicksortArrays(BaseQuicksortTest, TestCase):
    """Run the quicksort tests against ``jit_quicksort``, sorting float64 arrays."""

    quicksort = jit_quicksort
    # Factory used by the tests that need a custom comparison function
    make_quicksort = staticmethod(make_jit_quicksort)

    def array_factory(self, lst):
        """Return a new float64 Numpy array holding the items of *lst*."""
        return np.array(lst, dtype=np.float64)

class TestQuicksortMultidimensionalArrays(BaseSortingTest, TestCase):
    """
    Run the JIT quicksort built with ``is_np_array=True`` on arrays of
    several shapes, comparing against Numpy's own ``ndarray.sort()``.
    """

    quicksort = make_jit_quicksort(is_np_array=True)
    make_quicksort = staticmethod(make_jit_quicksort)

    def assertSorted(self, orig, result):
        """Assert that *result* has the same shape and contents as *orig*."""
        self.assertEqual(orig.shape, result.shape)
        self.assertPreciseEqual(orig, result)

    def array_factory(self, lst, shape=None):
        """
        Return a float64 array built from *lst*, reshaped to *shape*, or
        to a single row when *shape* is None.
        """
        array = np.array(lst, dtype=np.float64)
        if shape is not None:
            return array.reshape(shape)
        return array.reshape(-1, array.shape[0])

    def get_shapes(self, n):
        """
        Return multi-dimensional shapes holding exactly *n* items, built
        recursively from the divisors of *n* (duplicates are possible).
        """
        shapes = []
        if n == 1:
            return shapes

        for divisor in range(2, int(math.sqrt(n)) + 1):
            quotient, remainder = divmod(n, divisor)
            if remainder:
                continue
            shapes.append((quotient, divisor))
            shapes.append((divisor, quotient))
            # Split the quotient further for higher-dimensional shapes
            for sub_shape in self.get_shapes(quotient):
                shapes.append((divisor,) + sub_shape)
                shapes.append(sub_shape + (divisor,))

        return shapes

    def _shapes_to_test(self, n):
        """Return the shapes from get_shapes(), plus None for the default."""
        return self.get_shapes(n) + [None]

    def _chunked_keys(self, size_factors=(1, 5), sizes=(15, 20)):
        """
        Yield key lists made by concatenating one sample list per entry
        of *sizes*, for every combination and every size factor.
        """
        for size_factor in size_factors:
            all_lists = [self.make_sample_lists(n * size_factor) for n in sizes]
            for chunks in itertools.product(*all_lists):
                yield sum(chunks, [])

    def test_run_quicksort(self):
        run_quicksort = self.quicksort.run_quicksort

        # Make lists to be sorted from two chunks of different kinds.
        for orig_keys in self._chunked_keys():
            for shape in self._shapes_to_test(len(orig_keys)):
                keys = self.array_factory(orig_keys, shape=shape)
                expected = self.array_factory(orig_keys, shape=shape)
                run_quicksort(keys)
                expected.sort()
                # The array is now sorted the same way Numpy sorts it
                self.assertSorted(expected, keys)

    def test_run_quicksort_lt(self):
        def lt(a, b):
            return a > b

        run_reversed = self.make_quicksort(lt=lt,
                                           is_np_array=True).run_quicksort

        # Make lists to be sorted from two chunks of different kinds.
        for orig_keys in self._chunked_keys():
            for shape in self._shapes_to_test(len(orig_keys)):
                keys = self.array_factory(orig_keys, shape=shape)
                expected = -self.array_factory(orig_keys, shape=shape)
                run_reversed(keys)
                # The array is now rev-sorted: the expected result is
                # obtained by sorting the negated keys and negating back
                expected.sort()
                expected = -expected
                self.assertSorted(expected, keys)

        # An imperfect comparison function, as LT(a, b) does not imply not LT(b, a).
        # The sort should handle it gracefully.
        def lt_floats(a, b):
            return math.isnan(b) or a < b

        run_nan_last = self.make_quicksort(lt=lt_floats,
                                           is_np_array=True).run_quicksort

        np.random.seed(42)
        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            # Replace about 10% of the values with NaNs
            orig[np.random.random(size=size) < 0.1] = float('nan')
            orig_keys = list(orig)
            for shape in self._shapes_to_test(len(orig_keys)):
                keys = self.array_factory(orig_keys, shape=shape)
                expected = self.array_factory(orig_keys, shape=shape)
                run_nan_last(keys)
                expected.sort()
                # The result must match what Numpy's sort() produces
                self.assertSorted(expected, keys)

class TestNumpySort(TestCase):
    """
    Compare the compiled versions of the array sorting use cases with
    their pure Python / Numpy counterparts.
    """

    def setUp(self):
        # Make the random inputs reproducible
        np.random.seed(42)

    def int_arrays(self):
        """Yield random integer arrays of various sizes."""
        for size in (5, 20, 50, 500):
            yield np.random.randint(99, size=size)

    def float_arrays(self):
        """Yield random float arrays of various sizes, with and without NaNs."""
        for size in (5, 20, 50, 500):
            yield np.random.random(size=size) * 100
        # Now with NaNs.  Numpy sorts them at the end.
        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.1] = float('nan')
            yield orig
        # 90% of values are NaNs.
        for size in (50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.9] = float('nan')
            yield orig

    def has_duplicates(self, arr):
        """
        Whether the array has duplicates.  Takes NaNs into account.
        """
        if np.count_nonzero(np.isnan(arr)) > 1:
            return True
        if np.unique(arr).size < arr.size:
            return True
        return False

    def check_sort_inplace(self, pyfunc, cfunc, val):
        """Check that *cfunc* sorts a copy of *val* in place like *pyfunc*."""
        expected = copy.copy(val)
        got = copy.copy(val)
        pyfunc(expected)
        cfunc(got)
        self.assertPreciseEqual(got, expected)

    def check_sort_copy(self, pyfunc, cfunc, val):
        """
        Check that *cfunc* returns the same sorted copy as *pyfunc* and
        leaves *val* untouched.
        """
        orig = copy.copy(val)
        expected = pyfunc(val)
        got = cfunc(val)
        self.assertPreciseEqual(got, expected)
        # The original wasn't mutated
        self.assertPreciseEqual(val, orig)

    def check_argsort(self, pyfunc, cfunc, val, kwargs=None):
        """
        Check that the indices returned by *cfunc* sort *val*, and that
        they match the ones from *pyfunc* when *val* has no duplicates.

        *kwargs*, if given, is a dict of keyword arguments passed to both
        functions.
        """
        if kwargs is None:
            kwargs = {}
        orig = copy.copy(val)
        expected = pyfunc(val, **kwargs)
        got = cfunc(val, **kwargs)
        self.assertPreciseEqual(orig[got], np.sort(orig),
                                msg="the array wasn't argsorted")
        # Numba and Numpy results may differ if there are duplicates
        # in the array
        if not self.has_duplicates(orig):
            self.assertPreciseEqual(got, expected)
        # The original wasn't mutated
        self.assertPreciseEqual(val, orig)

    def test_array_sort_int(self):
        pyfunc = sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for orig in self.int_arrays():
            self.check_sort_inplace(pyfunc, cfunc, orig)

    def test_array_sort_float(self):
        pyfunc = sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for orig in self.float_arrays():
            self.check_sort_inplace(pyfunc, cfunc, orig)

    def test_np_sort_int(self):
        pyfunc = np_sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for orig in self.int_arrays():
            self.check_sort_copy(pyfunc, cfunc, orig)

    def test_np_sort_float(self):
        pyfunc = np_sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.1] = float('nan')
            self.check_sort_copy(pyfunc, cfunc, orig)

    def test_argsort_int(self):
        def check(pyfunc):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.int_arrays():
                self.check_argsort(pyfunc, cfunc, orig)

        check(argsort_usecase)
        check(np_argsort_usecase)

    def test_argsort_kind_int(self):
        def check(pyfunc, is_stable):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.int_arrays():
                self.check_argsort(pyfunc, cfunc, orig,
                                   dict(is_stable=is_stable))

        check(argsort_kind_usecase, is_stable=True)
        check(np_argsort_kind_usecase, is_stable=True)
        check(argsort_kind_usecase, is_stable=False)
        check(np_argsort_kind_usecase, is_stable=False)

    def test_argsort_float(self):
        def check(pyfunc):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.float_arrays():
                self.check_argsort(pyfunc, cfunc, orig)

        check(argsort_usecase)
        check(np_argsort_usecase)

    # This test must not reuse the name test_argsort_float: a second
    # definition with the same name replaces the first one in the class
    # body, so the test above would never run.
    def test_argsort_kind_float(self):
        def check(pyfunc, is_stable):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.float_arrays():
                self.check_argsort(pyfunc, cfunc, orig,
                                   dict(is_stable=is_stable))

        check(argsort_kind_usecase, is_stable=True)
        check(np_argsort_kind_usecase, is_stable=True)
        check(argsort_kind_usecase, is_stable=False)
        check(np_argsort_kind_usecase, is_stable=False)

    def test_bad_array(self):
        cfunc = jit(nopython=True)(np_sort_usecase)
        msg = '.*Argument "a" must be array-like.*'
        with self.assertRaisesRegex(errors.TypingError, msg):
            cfunc(None)


class TestPythonSort(TestCase):
    """Tests for list.sort() and sorted() compiled in nopython mode."""

    def setUp(self):
        super().setUp()
        # test_sorted and test_sorted_reverse draw their inputs from NumPy's
        # global RNG; seed it so a failure can be reproduced on the same data.
        np.random.seed(42)

    def test_list_sort(self):
        """list.sort() on a list of random floats built in compiled code."""
        pyfunc = list_sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (20, 50, 500):
            # cfunc returns the unsorted list and its sorted copy
            orig, ret = cfunc(size)
            self.assertEqual(sorted(orig), ret)
            self.assertNotEqual(orig, ret)   # sanity check

    def test_list_sort_reverse(self):
        """list.sort(reverse=...) for both values of the flag."""
        pyfunc = list_sort_reverse_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (20, 50, 500):
            for b in (False, True):
                orig, ret = cfunc(size, b)
                self.assertEqual(sorted(orig, reverse=b), ret)
                self.assertNotEqual(orig, ret)   # sanity check

    def test_sorted(self):
        """sorted() applied to a 1-d array of random floats."""
        pyfunc = sorted_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (20, 50, 500):
            orig = np.random.random(size=size) * 100
            expected = sorted(orig)
            got = cfunc(orig)
            self.assertPreciseEqual(got, expected)
            self.assertNotEqual(list(orig), got)   # sanity check

    def test_sorted_reverse(self):
        """sorted(..., reverse=...) for both values of the flag."""
        pyfunc = sorted_reverse_usecase
        cfunc = jit(nopython=True)(pyfunc)
        size = 20

        orig = np.random.random(size=size) * 100
        for b in (False, True):
            expected = sorted(orig, reverse=b)
            got = cfunc(orig, b)
            self.assertPreciseEqual(got, expected)
            self.assertNotEqual(list(orig), got)   # sanity check


class TestMergeSort(TestCase):
    """Tests for the jitted mergesort used as an argsort."""

    def setUp(self):
        # Fixed seed: the random keys below are the same on every run.
        np.random.seed(321)

    def check_argsort_stable(self, sorter, low, high, count):
        """Compare *sorter* against NumPy's mergesort argsort.

        The keys are *count* random integers drawn from ``[low, high)``; a
        narrow range makes duplicated keys likely.
        """
        keys = np.random.randint(low, high, count)
        reference = np.argsort(keys, kind='mergesort')
        result = sorter(keys)
        np.testing.assert_equal(reference, result)

    def test_argsort_stable(self):
        """Run the jitted mergesort argsort over several key ranges/sizes."""
        run_mergesort = make_jit_mergesort(is_argsort=True).run_mergesort

        @njit
        def sorter(arr):
            return run_mergesort(arr)

        # (low, high, count) triples passed to check_argsort_stable
        cases = [
            (-2, 2, 5),
            (-5, 5, 10),
            (0, 10, 101),
            (0, 100, 1003),
        ]
        for low, high, count in cases:
            self.check_argsort_stable(sorter, low, high, count)


def nop_compiler(x):
    """Identity stand-in for a compiler decorator such as ``njit``.

    Returns *x* unchanged, so a function decorated with it stays a plain
    Python function.
    """
    return x


class TestSortSlashSortedWithKey(MemoryLeakMixin, TestCase):
    """Tests for list.sort() and sorted() called with ``key`` and/or
    ``reverse`` arguments from jitted code."""

    def test_01(self):
        """Default key and an externally jitted key, on a list of ints."""

        a = [3, 1, 4, 1, 5, 9]

        @njit
        def external_key(z):
            return 1. / z

        @njit
        def foo(x, key=None):
            # exercise both sorted() and in-place list.sort() on a copy
            new_x = x[:]
            new_x.sort(key=key)
            return sorted(x[:], key=key), new_x

        # compiled result must match the pure Python function (py_func)
        self.assertPreciseEqual(foo(a[:]), foo.py_func(a[:]))
        self.assertPreciseEqual(foo(a[:], external_key),
                                foo.py_func(a[:], external_key))

    def test_02(self):
        """Key given as a closure defined inside the jitted function."""

        a = [3, 1, 4, 1, 5, 9]

        @njit
        def foo(x):
            def closure_key(z):
                return 1. / z
            new_x = x[:]
            new_x.sort(key=closure_key)
            return sorted(x[:], key=closure_key), new_x

        self.assertPreciseEqual(foo(a[:]), foo.py_func(a[:]))

    def test_03(self):
        """Key closure created in one function and passed on to another."""

        a = [3, 1, 4, 1, 5, 9]

        def gen(compiler):
            # build the same foo/bar pair, decorated with ``compiler``

            @compiler
            def bar(x, func):
                new_x = x[:]
                new_x.sort(key=func)
                return sorted(x[:], key=func), new_x

            @compiler
            def foo(x):
                # the closure escapes foo: it is handed to bar as an argument
                def closure_escapee_key(z):
                    return 1. / z
                return bar(x, closure_escapee_key)

            return foo

        # njit-compiled version vs. the undecorated (nop_compiler) version
        self.assertPreciseEqual(gen(njit)(a[:]), gen(nop_compiler)(a[:]))

    def test_04(self):
        """Default key and an externally jitted key, on a list of strings."""

        a = ['a','b','B','b','C','A']

        @njit
        def external_key(z):
            return z.upper()

        @njit
        def foo(x, key=None):
            new_x = x[:]
            new_x.sort(key=key)
            return sorted(x[:], key=key), new_x

        self.assertPreciseEqual(foo(a[:]), foo.py_func(a[:]))
        self.assertPreciseEqual(foo(a[:], external_key),
                                foo.py_func(a[:], external_key))

    def test_05(self):
        """Every combination of key (None / jitted) and ``reverse`` value."""

        a = ['a','b','B','b','C','A']

        @njit
        def external_key(z):
            return z.upper()

        @njit
        def foo(x, key=None, reverse=False):
            new_x = x[:]
            new_x.sort(key=key, reverse=reverse)
            return (sorted(x[:], key=key, reverse=reverse), new_x)

        # reverse is tried with booleans as well as plain integers
        for key, rev in itertools.product((None, external_key),
                                          (True, False, 1, -12, 0)):
            self.assertPreciseEqual(foo(a[:], key, rev),
                                    foo.py_func(a[:], key, rev))

    def test_optional_on_key(self):
        """A key that is either a closure or None, chosen by a runtime
        value, must be rejected with a TypingError."""
        a = [3, 1, 4, 1, 5, 9]

        @njit
        def foo(x, predicate):
            # closure_key is a function on one branch and None on the other
            if predicate:
                def closure_key(z):
                    return 1. / z
            else:
                closure_key = None

            new_x = x[:]
            new_x.sort(key=closure_key)

            return (sorted(x[:], key=closure_key), new_x)

        with self.assertRaises(errors.TypingError) as raises:
            TF = True
            foo(a[:], TF)

        msg = "Key must concretely be None or a Numba JIT compiled function"
        self.assertIn(msg, str(raises.exception))

    def test_exceptions_sorted(self):
        """Invalid ``key`` / ``reverse`` arguments raise a TypingError, for
        both list.sort() and sorted()."""

        @njit
        def foo_sorted(x, key=None, reverse=False):
            return sorted(x[:], key=key, reverse=reverse)

        @njit
        def foo_sort(x, key=None, reverse=False):
            new_x = x[:]
            new_x.sort(key=key, reverse=reverse)
            return new_x

        @njit
        def external_key(z):
            return 1. / z

        a = [3, 1, 4, 1, 5, 9]

        # the same two error checks are run against each implementation
        for impl in (foo_sort, foo_sorted):

            # check illegal key
            with self.assertRaises(errors.TypingError) as raises:
                impl(a, key="illegal")

            expect = "Key must be None or a Numba JIT compiled function"
            self.assertIn(expect, str(raises.exception))

            # check illegal reverse
            with self.assertRaises(errors.TypingError) as raises:
                impl(a, key=external_key, reverse="go backwards")

            expect = "an integer is required for 'reverse'"
            self.assertIn(expect, str(raises.exception))


class TestArrayArgsort(MemoryLeakMixin, TestCase):
    """Checks that only concern the array.argsort() method."""

    def test_exceptions(self):
        """Invalid argsort() keyword usage is reported as a TypingError."""

        @njit
        def nonliteral_kind(kind):
            np.arange(5).argsort(kind=kind)

        @njit
        def unsupported_kwarg():
            np.arange(5).argsort(foo='')

        # 'quicksort' is a valid spelling, but it reaches argsort() as a
        # function argument rather than as a literal.
        with self.assertRaises(errors.TypingError) as caught:
            nonliteral_kind('quicksort')
        self.assertIn('"kind" must be a string literal',
                      str(caught.exception))

        # a keyword that argsort() does not accept
        with self.assertRaises(errors.TypingError) as caught:
            unsupported_kwarg()
        self.assertIn("Unsupported keywords: ['foo']",
                      str(caught.exception))


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
