本文共 3661 字,大约阅读时间需要 12 分钟。
本文将介绍如下内容:
TF函数中的input_signature的作用:
1,签名(确定传入的参数是符合要求的)。 2,有input_signature后,才能在tf中保存成TF图解钩-SaveModel。函数签名由函数原型组成。它告诉你的是关于函数的一般信息,它的名称,参数,它的范围以及其他杂项信息。可以确定传入的参数是符合要求的。
我们使用input_signature
对tf.function修饰的函数进行数字签名;
tf.TensorSpec ( shape, dtype=tf.dtypes.float32, name=None )
tf.TensorSpec()
:#TensorSpec: 描述一个张量。
import matplotlib as mplimport matplotlib.pyplot as pltimport numpy as npimport sklearnimport pandas as pdimport osimport sysimport timeimport tensorflow as tffrom tensorflow import keras# 1,打印使用的python库的版本信息print(tf.__version__)print(sys.version_info)for module in mpl, np, pd, sklearn, tf, keras: print(module.__name__, module.__version__)# 2,TF函数的签名机制@tf.function(input_signature=[tf.TensorSpec([None], tf.int32, name='x')])def cube(z): return tf.pow(z, 3)try: print(cube(tf.constant([1., 2., 3.])))except ValueError as ex: print(ex)print(cube(tf.constant([1, 2, 3])))#---output-------Python inputs incompatible with input_signature: inputs: ( tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name='x'))tf.Tensor([ 1 8 27], shape=(3,), dtype=int32)
对于被tf.function
修饰过的函数都有get_concrete_function
的属性,可以使用@tf.function()
对函数添加函数签名,从而获取特定追踪。通过增加函数签名之后才能够将模型保存。
tf.function
修饰过的函数的get_concrete_function
属性# @tf.function py func -> tf graph# get_concrete_function -> add input signature -> SavedModelcube_func_int32 = cube.get_concrete_function(tf.TensorSpec([None], tf.int32))print(cube_func_int32)# tf.TensorSpec()中可以携带参数,且类型属性不变print(cube_func_int32 is cube.get_concrete_function(tf.TensorSpec([5], tf.int32)))print(cube_func_int32 is cube.get_concrete_function(tf.constant([1, 2, 3])))#----output------TrueTrue
print(cube_func_int32.graph)#---output-----FuncGraph(name=cube, id=140141878427488)
print(cube_func_int32.graph.get_operations())#---output-----[, , , ]
pow_op = cube_func_int32.graph.get_operations()[2]print(pow_op)#---output------------name: "Pow"op: "Pow"input: "x"input: "Pow/y"attr { key: "T" value { type: DT_INT32 }}
# 通过节点名称获取节点对象print(cube_func_int32.graph.get_operation_by_name("x"))# 通过tensor名称获取节点对象print(cube_func_int32.graph.get_tensor_by_name("x:0"))#---output-----------name: "x"op: "Placeholder"attr { key: "_user_specified_name" value { s: "x" }}attr { key: "dtype" value { type: DT_INT32 }}attr { key: "shape" value { shape { dim { size: -1 } } }}Tensor("x:0", shape=(None,), dtype=int32)
print(cube_func_int32.graph.as_graph_def())#---output-----node { name: "x" op: "Placeholder" attr { key: "_user_specified_name" value { s: "x" } } attr { key: "dtype" value { type: DT_INT32 } } attr { key: "shape" value { shape { dim { size: -1 } } } }}node { name: "Pow/y" op: "Const" attr { key: "dtype" value { type: DT_INT32 } } attr { key: "value" value { tensor { dtype: DT_INT32 tensor_shape { } int_val: 3 } } }}node { name: "Pow" op: "Pow" input: "x" input: "Pow/y" attr { key: "T" value { type: DT_INT32 } }}node { name: "Identity" op: "Identity" input: "Pow" attr { key: "T" value { type: DT_INT32 } }}versions { producer: 175}
转载地址:http://yvili.baihongyu.com/