shrink_tf_graph_weights.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # This file is part of OpenCV project.
  2. # It is subject to the license terms in the LICENSE file found in the top-level directory
  3. # of this distribution and at http://opencv.org/license.html.
  4. #
  5. # Copyright (C) 2017, Intel Corporation, all rights reserved.
  6. # Third party copyrights are property of their respective owners.
  7. import tensorflow as tf
  8. import struct
  9. import argparse
  10. import numpy as np
  11. parser = argparse.ArgumentParser(description='Convert weights of a frozen TensorFlow graph to fp16.')
  12. parser.add_argument('--input', required=True, help='Path to frozen graph.')
  13. parser.add_argument('--output', required=True, help='Path to output graph.')
  14. parser.add_argument('--ops', default=['Conv2D', 'MatMul'], nargs='+',
  15. help='List of ops which weights are converted.')
  16. args = parser.parse_args()
  17. DT_FLOAT = 1
  18. DT_HALF = 19
# In frozen graphs, every node that uses weights is connected to its Const node
# through an Identity node. These Identity nodes are usually named after the
# Const node with a '/read' suffix. We'll replace all of them with Cast nodes.
  22. # Load the model
  23. with tf.gfile.FastGFile(args.input) as f:
  24. graph_def = tf.GraphDef()
  25. graph_def.ParseFromString(f.read())
  26. # Set of all inputs from desired nodes.
  27. inputs = []
  28. for node in graph_def.node:
  29. if node.op in args.ops:
  30. inputs += node.input
  31. weightsNodes = []
  32. for node in graph_def.node:
  33. # From the whole inputs we need to keep only an Identity nodes.
  34. if node.name in inputs and node.op == 'Identity' and node.attr['T'].type == DT_FLOAT:
  35. weightsNodes.append(node.input[0])
  36. # Replace Identity to Cast.
  37. node.op = 'Cast'
  38. node.attr['DstT'].type = DT_FLOAT
  39. node.attr['SrcT'].type = DT_HALF
  40. del node.attr['T']
  41. del node.attr['_class']
  42. # Convert weights to halfs.
  43. for node in graph_def.node:
  44. if node.name in weightsNodes:
  45. node.attr['dtype'].type = DT_HALF
  46. node.attr['value'].tensor.dtype = DT_HALF
  47. floats = node.attr['value'].tensor.tensor_content
  48. floats = struct.unpack('f' * (len(floats) / 4), floats)
  49. halfs = np.array(floats).astype(np.float16).view(np.uint16)
  50. node.attr['value'].tensor.tensor_content = struct.pack('H' * len(halfs), *halfs)
  51. tf.train.write_graph(graph_def, "", args.output, as_text=False)