# Owner(s): ["oncall: quantization"]

from .common import AOMigrationTestCase


class TestAOMigrationQuantization(AOMigrationTestCase):
    r"""Import checks for the `torch/quantization` to
    `torch/ao/quantization` migration.

    Every migrated module gets one package-level check and one check over
    an explicit list of names; both are delegated to the inherited
    `_test_package_import` / `_test_function_import` /
    `_test_dict_import` helpers.
    """

    # NOTE: the test methods deliberately use `#` comments instead of
    # docstrings so that unittest's reported test descriptions stay as-is.

    # ---- quantize ----------------------------------------------------
    def test_package_import_quantize(self):
        self._test_package_import("quantize")

    def test_function_import_quantize(self):
        self._test_function_import("quantize", [
            "_convert",
            "_observer_forward_hook",
            "_propagate_qconfig_helper",
            "_remove_activation_post_process",
            "_remove_qconfig",
            "add_observer_",
            "add_quant_dequant",
            "convert",
            "get_observer_dict",
            "get_unique_devices_",
            "is_activation_post_process",
            "prepare",
            "prepare_qat",
            "propagate_qconfig_",
            "quantize",
            "quantize_dynamic",
            "quantize_qat",
            "register_activation_post_process_hook",
            "swap_module",
        ])

    # ---- stubs -------------------------------------------------------
    def test_package_import_stubs(self):
        self._test_package_import("stubs")

    def test_function_import_stubs(self):
        self._test_function_import("stubs", [
            "QuantStub",
            "DeQuantStub",
            "QuantWrapper",
        ])

    # ---- quantize_jit ------------------------------------------------
    def test_package_import_quantize_jit(self):
        self._test_package_import("quantize_jit")

    def test_function_import_quantize_jit(self):
        self._test_function_import("quantize_jit", [
            "_check_is_script_module",
            "_check_forward_method",
            "script_qconfig",
            "script_qconfig_dict",
            "fuse_conv_bn_jit",
            "_prepare_jit",
            "prepare_jit",
            "prepare_dynamic_jit",
            "_convert_jit",
            "convert_jit",
            "convert_dynamic_jit",
            "_quantize_jit",
            "quantize_jit",
            "quantize_dynamic_jit",
        ])

    # ---- fake_quantize -----------------------------------------------
    def test_package_import_fake_quantize(self):
        self._test_package_import("fake_quantize")

    def test_function_import_fake_quantize(self):
        self._test_function_import("fake_quantize", [
            "_is_per_channel",
            "_is_per_tensor",
            "_is_symmetric_quant",
            "FakeQuantizeBase",
            "FakeQuantize",
            "FixedQParamsFakeQuantize",
            "FusedMovingAvgObsFakeQuantize",
            "default_fake_quant",
            "default_weight_fake_quant",
            "default_fixed_qparams_range_neg1to1_fake_quant",
            "default_fixed_qparams_range_0to1_fake_quant",
            "default_per_channel_weight_fake_quant",
            "default_histogram_fake_quant",
            "default_fused_act_fake_quant",
            "default_fused_wt_fake_quant",
            "default_fused_per_channel_wt_fake_quant",
            "_is_fake_quant_script_module",
            "disable_fake_quant",
            "enable_fake_quant",
            "disable_observer",
            "enable_observer",
        ])

    # ---- fuse_modules ------------------------------------------------
    def test_package_import_fuse_modules(self):
        self._test_package_import("fuse_modules")

    def test_function_import_fuse_modules(self):
        self._test_function_import("fuse_modules", [
            "_fuse_modules",
            "_get_module",
            "_set_module",
            "fuse_conv_bn",
            "fuse_conv_bn_relu",
            "fuse_known_modules",
            "fuse_modules",
            "get_fuser_method",
        ])

    # ---- quant_type --------------------------------------------------
    def test_package_import_quant_type(self):
        self._test_package_import("quant_type")

    def test_function_import_quant_type(self):
        self._test_function_import("quant_type", [
            "QuantType",
            "quant_type_to_str",
        ])

    # ---- observer ----------------------------------------------------
    def test_package_import_observer(self):
        self._test_package_import("observer")

    def test_function_import_observer(self):
        self._test_function_import("observer", [
            "_PartialWrapper",
            "_with_args",
            "_with_callable_args",
            "ABC",
            "ObserverBase",
            "_ObserverBase",
            "MinMaxObserver",
            "MovingAverageMinMaxObserver",
            "PerChannelMinMaxObserver",
            "MovingAveragePerChannelMinMaxObserver",
            "HistogramObserver",
            "PlaceholderObserver",
            "RecordingObserver",
            "NoopObserver",
            "_is_activation_post_process",
            "_is_per_channel_script_obs_instance",
            "get_observer_state_dict",
            "load_observer_state_dict",
            "default_observer",
            "default_placeholder_observer",
            "default_debug_observer",
            "default_weight_observer",
            "default_histogram_observer",
            "default_per_channel_weight_observer",
            "default_dynamic_quant_observer",
            "default_float_qparams_observer",
        ])

    # ---- qconfig -----------------------------------------------------
    def test_package_import_qconfig(self):
        self._test_package_import("qconfig")

    def test_function_import_qconfig(self):
        self._test_function_import("qconfig", [
            "QConfig",
            "default_qconfig",
            "default_debug_qconfig",
            "default_per_channel_qconfig",
            "QConfigDynamic",
            "default_dynamic_qconfig",
            "float16_dynamic_qconfig",
            "float16_static_qconfig",
            "per_channel_dynamic_qconfig",
            "float_qparams_weight_only_qconfig",
            "default_qat_qconfig",
            "default_weight_only_qconfig",
            "default_activation_only_qconfig",
            "default_qat_qconfig_v2",
            "get_default_qconfig",
            "get_default_qat_qconfig",
            "assert_valid_qconfig",
            "QConfigAny",
            "add_module_to_qconfig_obs_ctr",
            "qconfig_equals",
        ])

    # ---- quantization_mappings ---------------------------------------
    def test_package_import_quantization_mappings(self):
        self._test_package_import("quantization_mappings")

    def test_function_import_quantization_mappings(self):
        # Functions are checked first, then the module-level mapping dicts.
        self._test_function_import("quantization_mappings", [
            "no_observer_set",
            "get_default_static_quant_module_mappings",
            "get_static_quant_module_class",
            "get_dynamic_quant_module_class",
            "get_default_qat_module_mappings",
            "get_default_dynamic_quant_module_mappings",
            "get_default_qconfig_propagation_list",
            "get_default_compare_output_module_list",
            "get_default_float_to_quantized_operator_mappings",
            "get_quantized_operator",
            "_get_special_act_post_process",
            "_has_special_act_post_process",
        ])
        # "_INCLUDE_QCONFIG_PROPAGATE_LIST" is left out of the dict check
        # below; it was already disabled and the reason is not recorded here.
        self._test_dict_import("quantization_mappings", [
            "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_QAT_MODULE_MAPPINGS",
            "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
            "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
            "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
        ])

    # ---- fuser_method_mappings ---------------------------------------
    def test_package_import_fuser_method_mappings(self):
        self._test_package_import("fuser_method_mappings")

    def test_function_import_fuser_method_mappings(self):
        # Functions are checked first, then the module-level mapping dict.
        self._test_function_import("fuser_method_mappings", [
            "fuse_conv_bn",
            "fuse_conv_bn_relu",
            "fuse_linear_bn",
            "get_fuser_method",
        ])
        self._test_dict_import("fuser_method_mappings", [
            "DEFAULT_OP_LIST_TO_FUSER_METHOD",
        ])

    # ---- utils -------------------------------------------------------
    def test_package_import_utils(self):
        self._test_package_import("utils")

    def test_function_import_utils(self):
        self._test_function_import("utils", [
            "activation_dtype",
            "activation_is_int8_quantized",
            "activation_is_statically_quantized",
            "calculate_qmin_qmax",
            "check_min_max_valid",
            "get_combined_dict",
            "get_qconfig_dtypes",
            "get_qparam_dict",
            "get_quant_type",
            "get_swapped_custom_module_class",
            "getattr_from_fqn",
            "is_per_channel",
            "is_per_tensor",
            "weight_dtype",
            "weight_is_quantized",
            "weight_is_statically_quantized",
        ])
