File: randomize_weights.py

package info (click to toggle)
chromium 90.0.4430.212-1~deb10u1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 3,450,632 kB
  • sloc: cpp: 19,832,434; javascript: 2,948,838; ansic: 2,312,399; python: 1,464,622; xml: 584,121; java: 514,189; asm: 470,557; objc: 83,463; perl: 77,861; sh: 77,030; cs: 70,789; fortran: 24,137; tcl: 18,916; php: 18,872; makefile: 16,848; ruby: 16,721; pascal: 13,150; sql: 10,199; yacc: 7,507; lex: 1,313; lisp: 840; awk: 329; jsp: 39; sed: 19
file content (64 lines) | stat: -rw-r--r-- 2,049 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Randomize all weights in a tflite file.

Example usage:
python randomize_weights.py \
  --input_tflite_file=foo.tflite \
  --output_tflite_file=foo_randomized.tflite
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.platform import app


def main(_):
  parser = argparse.ArgumentParser(
      description='Randomize weights in a tflite file.')
  parser.add_argument(
      '--input_tflite_file',
      type=str,
      required=True,
      help='Full path name to the input tflite file.')
  parser.add_argument(
      '--output_tflite_file',
      type=str,
      required=True,
      help='Full path name to the output randomized tflite file.')
  parser.add_argument(
      '--random_seed',
      type=str,
      required=False,
      default=0,
      help='Input to the random number generator. The default value is 0.')
  args = parser.parse_args()

  # Read the model
  model = flatbuffer_utils.read_model(args.input_tflite_file)
  # Invoke the randomize weights function
  flatbuffer_utils.randomize_weights(model, args.random_seed)
  # Write the model
  flatbuffer_utils.write_model(model, args.output_tflite_file)


if __name__ == '__main__':
  # main() parses the tool's options from sys.argv with argparse. app.run is
  # therefore handed only the program name (element 0), so its own flag
  # handling never sees those options.
  _program_name_only = sys.argv[0:1]
  app.run(argv=_program_name_only, main=main)