# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Transforms a float-trained graph into an equivalent quantized version.
An example of command-line usage is:
bazel build tensorflow/tools/quantization:quantize_graph \
&& bazel-bin/tensorflow/tools/quantization/quantize_graph \
--input=tensorflow_inception_graph.pb \
--output_node_names="softmax2" --print_nodes --output=/tmp/quantized_graph.pb \
--mode=eightbit --logtostderr
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import app
from tensorflow.python.platform import flags as flags_lib
from tensorflow.python.platform import gfile
# Command-line flag definitions for this tool. `flags` is an alias for the
# imported flags module and `FLAGS` is that module's flag-values object.
flags = flags_lib
FLAGS = flags.FLAGS
# Graph input/output and reporting options.
flags.DEFINE_boolean("print_nodes", False, """Lists all nodes in the model.""")
flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""")
flags.DEFINE_string("output_node_names", "",
"""Output node names, comma separated.""")
flags.DEFINE_string("output", "", """File to save the output graph to.""")
# Options selecting which transformation is applied and at what bit depth.
flags.DEFINE_integer("bitdepth", 8,
"""How many bits to quantize the graph to.""")
flags.DEFINE_string("mode", "round",
"""What transformation to apply (round, quantize,"""
""" eightbit, weights, or weights_rounded).""")
# Comma-separated dimensions; defaults to "1,224,224,3".
flags.DEFINE_string("test_input_dims", "1,224,224,3",
"""The size of the input tensor to use when testing a"""
""" graph loaded from a file.""")
flags.DEFINE_boolean("strip_redundant_quantization", True,
"""Removes redundant dequantize/quantize pairs.""")
# Quantized-input options; per the help text these only apply with
# --mode=eightbit.
flags.DEFINE_boolean("quantized_input", False,
"If true, assume Placeholders are quantized with values "
"covering [--quantized_input_min,--quantized_input_max]. "
"Only supported when --mode=eightbit")
flags.DEFINE_float("quantized_input_min", 0,
"The minimum of the actual input range when "
"--quantized_input")
flags.DEFINE_float("quantized_input_max", 1,
"The maximum of the actual input range when "
"--quantized_input")
# Fallback range for layers without min-max information. Both default to
# None.
flags.DEFINE_float(
"quantized_fallback_min", None,
"The fallback 'min' value to use for layers which lack min-max "
"information. Note: this should be considered a coarse tool just good "
"enough for experimentation purposes, since graphs quantized in this way "
"would be very inaccurate.")
flags.DEFINE_float(
"quantized_fallback_max", None,
"The fallback 'max' value to use for layers which lack min-max "
"information. Note: this should be considered a coarse tool just good "
"enough for experimentation purposes, since graphs quantized in this way "
"would be very inaccurate.")
def print_input_nodes(current_node, nodes_map, indent, already_visited):
  """Recursively prints a node followed by its not-yet-visited inputs.

  Each node is printed as "<op>:<name>" preceded by `indent` spaces; the
  inputs of a node are printed one level deeper than the node itself.

  Args:
    current_node: Node to print. Must expose `op`, `name` and `input`.
    nodes_map: Mapping used to resolve an input to its node. Assumed to be
      keyed by bare node name (no "^" prefix, no ":<port>" suffix).
    indent: Number of spaces to print before this node.
    already_visited: Mutable mapping of node names that were already printed.
      It is updated in place.
  """
  print(" " * indent + current_node.op + ":" + current_node.name)
  already_visited[current_node.name] = True
  for input_reference in current_node.input:
    # Input references may be decorated ("^name" or "name:1"). Strip the
    # decorations so both the visited check and the lookup use the same bare
    # node name that `already_visited` is keyed by.
    input_node_name = node_name_from_input(input_reference)
    if input_node_name in already_visited:
      continue
    input_node = nodes_map[input_node_name]
    print_input_nodes(input_node, nodes_map, indent + 1, already_visited)
def create_node(op, name, inputs):
  """Builds a new NodeDef with the given op, name and inputs.

  Args:
    op: Value assigned to the node's `op` field.
    name: Value assigned to the node's `name` field.
    inputs: Iterable of input names, added to the node's `input` field in
      iteration order.

  Returns:
    The newly created NodeDef.
  """
  node_def = node_def_pb2.NodeDef()
  node_def.name = name
  node_def.op = op
  # Add inputs one at a time rather than in bulk, so that iterating `inputs`
  # behaves exactly like a plain for-loop over it.
  for single_input in inputs:
    node_def.input.append(single_input)
  return node_def
def create_constant_node(name, value, dtype, shape=None):
  """Builds a "Const" node with no inputs holding the given value.

  Args:
    name: Name for the new node.
    value: Value stored in the node's "value" tensor attribute.
    dtype: Type stored in the "dtype" attribute and used for the tensor.
    shape: Optional shape forwarded to the tensor attribute.

  Returns:
    The newly created node.
  """
  const_node = create_node("Const", name, [])
  set_attr_dtype(const_node, "dtype", dtype)
  set_attr_tensor(const_node, "value", value, dtype, shape=shape)
  return const_node
def copy_attr(node, key, attr_value):
  """Copies `attr_value` into the attribute `key` of `node`.

  A KeyError raised while doing so is silently ignored.

  Args:
    node: Object whose `attr` mapping is updated.
    key: Name of the attribute to overwrite.
    attr_value: Value passed to the attribute's `CopyFrom`.
  """
  try:
    destination = node.attr[key]
    destination.CopyFrom(attr_value)
  except KeyError:
    return
def set_attr_dtype(node, key, value):
  """Sets attribute `key` of `node` to the type enum of `value`.

  A KeyError raised while doing so is silently ignored.

  Args:
    node: Node whose `attr` mapping is updated.
    key: Name of the attribute to overwrite.
    value: Object exposing `as_datatype_enum`.
  """
  try:
    destination = node.attr[key]
    destination.CopyFrom(
        attr_value_pb2.AttrValue(type=value.as_datatype_enum))
  except KeyError:
    return
def set_attr_shape(node, key, value):
  """Sets attribute `key` of `node` to a shape built from `value`.

  `value` is converted with `tensor_shape.as_shape(...).as_proto()`. A
  KeyError raised while doing so is silently ignored.

  Args:
    node: Node whose `attr` mapping is updated.
    key: Name of the attribute to overwrite.
    value: Anything accepted by `tensor_shape.as_shape`.
  """
  try:
    destination = node.attr[key]
    destination.CopyFrom(
        attr_value_pb2.AttrValue(
            shape=tensor_shape.as_shape(value).as_proto()))
  except KeyError:
    return
def set_attr_tensor(node, key, value, dtype, shape=None):
  """Sets attribute `key` of `node` to a tensor built from `value`.

  The tensor is produced by `tensor_util.make_tensor_proto`. A KeyError
  raised while doing so is silently ignored.

  Args:
    node: Node whose `attr` mapping is updated.
    key: Name of the attribute to overwrite.
    value: Tensor contents, forwarded to `make_tensor_proto`.
    dtype: Forwarded to `make_tensor_proto` as `dtype`.
    shape: Optional; forwarded to `make_tensor_proto` as `shape`.
  """
  try:
    destination = node.attr[key]
    destination.CopyFrom(
        attr_value_pb2.AttrValue(
            tensor=tensor_util.make_tensor_proto(
                value, dtype=dtype, shape=shape)))
  except KeyError:
    return
def set_attr_string(node, key, value):
  """Sets attribute `key` of `node` to an AttrValue with `s=value`.

  A KeyError raised while doing so is silently ignored.

  Args:
    node: Node whose `attr` mapping is updated.
    key: Name of the attribute to overwrite.
    value: Value for the AttrValue's `s` field.
  """
  try:
    destination = node.attr[key]
    destination.CopyFrom(attr_value_pb2.AttrValue(s=value))
  except KeyError:
    return
def set_attr_int_list(node, key, value):
  """Sets attribute `key` of `node` to a list AttrValue with `i=value`.

  The ListValue is built before the guarded section, so only a KeyError
  raised while storing it on the node is silently ignored.

  Args:
    node: Node whose `attr` mapping is updated.
    key: Name of the attribute to overwrite.
    value: Values for the ListValue's `i` field.
  """
  int_list = attr_value_pb2.AttrValue.ListValue(i=value)
  try:
    destination = node.attr[key]
    destination.CopyFrom(attr_value_pb2.AttrValue(list=int_list))
  except KeyError:
    return
def set_attr_bool(node, key, value):
  """Sets attribute `key` of `node` to an AttrValue with `b=value`.

  A KeyError raised while doing so is silently ignored.

  Args:
    node: Node whose `attr` mapping is updated.
    key: Name of the attribute to overwrite.
    value: Value for the AttrValue's `b` field.
  """
  try:
    destination = node.attr[key]
    destination.CopyFrom(attr_value_pb2.AttrValue(b=value))
  except KeyError:
    return
def set_attr_int(node, key, value):
  """Sets attribute `key` of `node` to an AttrValue with `i=value`.

  A KeyError raised while doing so is silently ignored.

  Args:
    node: Node whose `attr` mapping is updated.
    key: Name of the attribute to overwrite.
    value: Value for the AttrValue's `i` field.
  """
  try:
    destination = node.attr[key]
    destination.CopyFrom(attr_value_pb2.AttrValue(i=value))
  except KeyError:
    return
def set_attr_float(node, key, value):
  """Sets attribute `key` of `node` to an AttrValue with `f=value`.

  A KeyError raised while doing so is silently ignored.

  Args:
    node: Node whose `attr` mapping is updated.
    key: Name of the attribute to overwrite.
    value: Value for the AttrValue's `f` field.
  """
  try:
    destination = node.attr[key]
    destination.CopyFrom(attr_value_pb2.AttrValue(f=value))
  except KeyError:
    return
def node_name_from_input(node_name):
  """Returns the bare node name behind an input reference.

  A single leading "^" is dropped first, then a trailing ":<digits>" port
  suffix is removed if present.

  Args:
    node_name: Input reference string, possibly decorated.

  Returns:
    The reference with those decorations stripped.
  """
  undecorated = node_name[1:] if node_name.startswith("^") else node_name
  port_match = re.search(r"(.*):\d+$", undecorated)
  if port_match is None:
    return undecorated
  return port_match.group(1)
def ensure_tensor_name_has_port(node_name):
  """Returns the tensor name with ":0" appended when it has no port.

  Args:
    node_name: Tensor name, with or without a trailing ":<digits>" port.

  Returns:
    `node_name` unchanged if it already ends in ":<digits>", otherwise
    `node_name` followed by ":0".
  """
  if re.search(r"(.*):\d+$", node_name):
    return node_name
  return node_name + ":0"
def unique_node_name_from_input(node_name):
  """Turns an input reference into a string usable as a node name.

  Every ":" becomes "__port__" and every "^" becomes "__hat__".

  Args:
    node_name: Input reference string.

  Returns:
    The string with those characters replaced.
  """
  without_colons = node_name.replace(":", "__port__")
  without_hats = without_colons.replace("^", "__hat__")
  return without_hats
def quantize_array(arr, num_buckets):
"""Quantizes a numpy array.
This function maps each scalar in arr to the center of one of num_buckets
buckets. For instance,
quantize_array([0, 0.3, 0.6, 1], 2) => [0.25, 0.25, 0.75, 0.75]
Args:
arr: The numpy array to quantize.
num_buckets: The number of buckets to map "var" to.
Returns:
The quantized numpy array.
Raises:
ValueError: when num_buckets < 1.
"""
if num_buckets < 1:
raise ValueError("num_buckets must be >= 1")
arr_max = arr.max()
arr_min = arr.min()
if arr_max == arr_min:
return arr
bucket_width = (arr_max - arr_min) / num_buckets
# Map scalars to bucket indices. Take special care of max(arr).
bucket_indices = np.floor((arr - arr_min) / bucket_width)
bucket_indices[bucket_indices == num_buckets] = num_buckets - 1
# Map e
吃不胖的卷卷
- 粉丝: 91
- 资源: 12
最新资源
- sensors-18-03721.pdf
- Facebook.apk
- 推荐一款JTools的call-this-method插件
- json的合法基色来自红包东i请各位
- 项目采用YOLO V4算法模型进行目标检测,使用Deep SORT目标跟踪算法 .zip
- 针对实时视频流和静态图像实现的对象检测和跟踪算法 .zip
- 部署 yolox 算法使用 deepstream.zip
- 基于webmagic、springboot和mybatis的MagicToe Java爬虫设计源码
- 通过实时流协议 (RTSP) 使用 Yolo、OpenCV 和 Python 进行深度学习的对象检测.zip
- 基于Python和HTML的tb商品列表查询分析设计源码
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈