File: apot_utils.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (56 lines) | stat: -rw-r--r-- 1,345 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
r"""
This file contains utility functions to convert values
using APoT nonuniform quantization methods.
"""

import math

def float_to_apot(x, levels, indices, alpha):
    r"""Convert a floating point input into its APoT index.

    The input is first clipped to ``[-alpha, alpha]`` and then mapped to the
    index paired with the nearest quantization level.

    Args:
        x: value to quantize.
        levels: quantization levels; ``levels[i]`` is paired with
            ``indices[i]`` (the pairs are zipped, so iteration stops at the
            shorter of the two sequences).
        indices: APoT indices, one per level.
        alpha: clipping threshold; inputs outside ``[-alpha, alpha]`` are
            clipped to the nearest bound before the level search.

    Returns:
        The index paired with the level closest to the clipped input. On a
        tie, the earliest pair wins. If there are no pairs to search,
        returns ``0``.
    """
    # Clip the *input* and still perform the search. Returning the clipped
    # bound directly would hand back a level value instead of an index,
    # which is inconsistent with what in-range inputs return.
    if x < -alpha:
        x = -alpha
    elif x > alpha:
        x = alpha

    min_delta = math.inf
    best_idx = 0

    for level, idx in zip(levels, indices):
        cur_delta = abs(level - x)
        # strict '<' keeps the first of several equally-close levels
        if cur_delta < min_delta:
            min_delta = cur_delta
            best_idx = idx

    return best_idx

def quant_dequant_util(x, levels, indices):
    r"""Round a floating point input to the nearest quantization level.

    This is a quantize-then-dequantize round trip. The result is the level
    value itself (not its index), i.e. a reduced precision float.

    Args:
        x: value to round.
        levels: quantization levels to choose from.
        indices: APoT indices paired with ``levels``. They never influence
            which level is chosen, but because the two sequences are zipped,
            levels beyond ``len(indices)`` are not considered.

    Returns:
        The level closest to ``x`` (the earliest one on a tie), or ``0.0``
        if there are no levels to search. No clipping is applied.
    """
    min_delta = math.inf
    best_fp = 0.0

    # The index half of each pair is unused; zipping keeps the original
    # pairing/truncation semantics.
    for level, _idx in zip(levels, indices):
        cur_delta = abs(level - x)
        if cur_delta < min_delta:
            min_delta = cur_delta
            best_fp = level

    return best_fp

def apot_to_float(x_apot, levels, indices):
    r"""Convert an APoT index back into its floating point level.

    Args:
        x_apot: APoT index to decode; must occur in ``indices``.
        levels: quantization levels, positionally paired with ``indices``.
        indices: APoT indices.

    Returns:
        ``levels[i]``, where ``i`` is the position of the first occurrence of
        ``x_apot`` in ``indices``.

    Raises:
        ValueError: if ``x_apot`` does not occur in ``indices``.
    """
    try:
        pos = list(indices).index(x_apot)
    except ValueError:
        # same exception type as before, but with a message naming the value
        raise ValueError(
            f"{x_apot!r} is not a valid APoT index for the given indices"
        ) from None
    return levels[pos]