#!/usr/bin/env python3

from unittest import TestCase, main

import numpy as np
from numpy.testing import assert_array_equal

from vbz import compress, compression_options, decompress


class Tests:
    """Round-trip check shared by all encoding/decoding test cases.

    This is a mixin: concrete test classes combine it with
    ``unittest.TestCase``. Subclasses are expected to provide ``self.dtype``
    and to create ``self.data`` (a NumPy array) in ``setUp``. They may set
    ``vbz_version`` to run the check with explicitly built options.
    """

    # Version handed to compression_options(); None means "pass no options".
    vbz_version = None

    def test_decode(self):
        """S1: simple decode

        Compress ``self.data``, decompress the result and require that the
        output equals the input element for element.
        """
        options = None
        # Compare against None instead of relying on truthiness, so that a
        # version number of 0 is not silently treated as "no options".
        if self.vbz_version is not None:
            # The first option flag is enabled only for signed integer data.
            zigzag = np.issubdtype(self.data.dtype, np.signedinteger)
            # Positional arguments: (zigzag, bytes per element, 1, version).
            # NOTE(review): the literal 1 is presumably a compression level --
            # confirm against vbz.compression_options.
            options = compression_options(
                zigzag, self.data.dtype.itemsize, 1, self.vbz_version
            )

        res = compress(self.data, options)
        rec = decompress(res, self.dtype, options)
        assert_array_equal(self.data, rec)


class BasicTest(Tests):
    """Round-trip fixture built from a short, fixed sequence of integers."""

    def setUp(self):
        """Store the values 1 through 10 as an array of ``self.dtype``."""
        first, last = 1, 10
        # The stop value of arange is exclusive, hence ``last + 1``.
        self.data = np.arange(first, last + 1, dtype=self.dtype)


class RandTest(Tests):
    """Base test encoding and decoding with large random array

    Subclasses supply ``dtype``, ``size`` and the value range ``rmin`` /
    ``rmax``. Both bounds are treated as inclusive.
    """

    def setUp(self):
        """Draw ``size`` random integers of ``dtype`` from [rmin, rmax]."""
        # np.random.randint excludes its upper bound, so add 1 to make
        # ``rmax`` itself (e.g. the dtype maximum) a possible value.
        self.data = np.random.randint(
            self.rmin, self.rmax + 1, size=self.size, dtype=self.dtype
        )


# Basic tests using the default configuration (no explicit options)
class BasicTestInt8(BasicTest, TestCase):
    """Basic round-trip test with int8 data and no explicit options."""

    dtype = np.int8


class BasicTestUInt8(BasicTest, TestCase):
    """Basic round-trip test with uint8 data and no explicit options."""

    dtype = np.uint8


class BasicTestInt16(BasicTest, TestCase):
    """Basic round-trip test with int16 data and no explicit options."""

    dtype = np.int16


class BasicTestUInt16(BasicTest, TestCase):
    """Basic round-trip test with uint16 data and no explicit options."""

    dtype = np.uint16


class BasicTestInt32(BasicTest, TestCase):
    """Basic round-trip test with int32 data and no explicit options."""

    dtype = np.int32


class BasicTestUInt32(BasicTest, TestCase):
    """Basic round-trip test with uint32 data and no explicit options."""

    dtype = np.uint32


# Basic tests run with explicit vbz v1 options (half support test; 8- and 16-bit dtypes only)
class BasicV1TestInt8(BasicTestInt8):
    """Same as BasicTestInt8, but with options built for vbz version 1."""

    vbz_version = 1


class BasicV1TestUInt8(BasicTestUInt8):
    """Same as BasicTestUInt8, but with options built for vbz version 1."""

    vbz_version = 1


class BasicV1TestInt16(BasicTestInt16):
    """Same as BasicTestInt16, but with options built for vbz version 1."""

    vbz_version = 1


class BasicV1TestUInt16(BasicTestUInt16):
    """Same as BasicTestUInt16, but with options built for vbz version 1."""

    vbz_version = 1


# Random tests using the default configuration (no explicit options)
class RandTestInt8(RandTest, TestCase):
    """Random int8 round-trip test with no explicit options."""

    dtype = np.int8
    # rmin/rmax/size are read by RandTest.setUp when drawing the random data.
    rmin = -(2**7)
    rmax = 2**7 - 1
    size = 200000


class RandTestUInt8(RandTest, TestCase):
    """Random uint8 round-trip test with no explicit options."""

    dtype = np.uint8
    # rmin/rmax/size are read by RandTest.setUp when drawing the random data.
    rmin = 0
    # Use the uint8 maximum (not the int8 one) so that values with the top
    # bit set are exercised as well.
    rmax = 2**8 - 1
    size = 200000


class RandTestInt16(RandTest, TestCase):
    """Random int16 round-trip test with no explicit options."""

    dtype = np.int16
    # rmin/rmax/size are read by RandTest.setUp when drawing the random data.
    rmin = -(2**15)
    rmax = 2**15 - 1
    size = 200000


class RandTestUInt16(RandTest, TestCase):
    """Random uint16 round-trip test with no explicit options."""

    dtype = np.uint16
    # rmin/rmax/size are read by RandTest.setUp when drawing the random data.
    rmin = 0
    # Use the uint16 maximum (not the int16 one) so that values with the top
    # bit set are exercised as well.
    rmax = 2**16 - 1
    size = 200000


class RandTestInt32(RandTest, TestCase):
    """Random int32 round-trip test with no explicit options."""

    dtype = np.int32
    # rmin/rmax/size are read by RandTest.setUp when drawing the random data.
    rmin = -(2**31)
    rmax = 2**31 - 1
    size = 200000


# NOTE(review): the class name looks like a typo for "RandTestUInt32"; it is
# kept unchanged so that existing test selectors keep working.
class RandTestIntU32(RandTest, TestCase):
    """Random uint32 round-trip test with no explicit options."""

    dtype = np.uint32
    # rmin/rmax/size are read by RandTest.setUp when drawing the random data.
    rmin = 0
    # Use the uint32 maximum (not the int32 one) so that values with the top
    # bit set are exercised as well.
    rmax = 2**32 - 1
    size = 200000


# Random tests run with explicit vbz v1 options (half support test; 8- and 16-bit dtypes only)
class RandV1TestInt8(RandTestInt8):
    """Same as RandTestInt8, but with options built for vbz version 1."""

    vbz_version = 1


class RandV1TestUInt8(RandTestUInt8):
    """Same as RandTestUInt8, but with options built for vbz version 1."""

    vbz_version = 1


class RandV1TestInt16(RandTestInt16):
    """Same as RandTestInt16, but with options built for vbz version 1."""

    # Derives from the signed 16-bit case so that signed data is exercised
    # with version 1 options; the unsigned case has its own class.
    vbz_version = 1


class RandV1TestUInt16(RandTestUInt16):
    """Same as RandTestUInt16, but with options built for vbz version 1."""

    vbz_version = 1


# When executed as a script, hand control to unittest's command-line runner.
if __name__ == "__main__":
    main()
