博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Keras(十) TF函数签名与图结构
阅读量:4202 次
发布时间:2019-05-26

本文共 3661 字,大约阅读时间需要 12 分钟。

本文将介绍如下内容:

  • TF函数签名
  • 图结构

TF函数中的input_signature的作用:

1,签名(确定传入的参数是符合要求的)。
2,有input_signature后,才能在tf中保存成TF图解钩-SaveModel。

一,TF函数签名

函数签名由函数原型组成。它告诉你的是关于函数的一般信息,它的名称,参数,它的范围以及其他杂项信息。可以确定传入的参数是符合要求的。

我们使用input_signature对tf.function修饰的函数进行数字签名;

tf.TensorSpec ( shape, dtype=tf.dtypes.float32, name=None )

tf.TensorSpec() :#TensorSpec: 描述一个张量。

import matplotlib as mplimport matplotlib.pyplot as pltimport numpy as npimport sklearnimport pandas as pdimport osimport sysimport timeimport tensorflow as tffrom tensorflow import keras# 1,打印使用的python库的版本信息print(tf.__version__)print(sys.version_info)for module in mpl, np, pd, sklearn, tf, keras:    print(module.__name__, module.__version__)# 2,TF函数的签名机制@tf.function(input_signature=[tf.TensorSpec([None], tf.int32, name='x')])def cube(z):    return tf.pow(z, 3)try:    print(cube(tf.constant([1., 2., 3.])))except ValueError as ex:    print(ex)print(cube(tf.constant([1, 2, 3])))#---output-------Python inputs incompatible with input_signature:  inputs: (    tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32))  input_signature: (    TensorSpec(shape=(None,), dtype=tf.int32, name='x'))tf.Tensor([ 1  8 27], shape=(3,), dtype=int32)

二,图结构

对于被tf.function修饰过的函数都有get_concrete_function的属性,可以使用@tf.function()对函数添加函数签名,从而获取特定追踪。通过增加函数签名之后才能够将模型保存

1,获取被tf.function修饰过的函数的get_concrete_function属性
# @tf.function py func -> tf graph# get_concrete_function -> add input signature -> SavedModelcube_func_int32 = cube.get_concrete_function(tf.TensorSpec([None], tf.int32))print(cube_func_int32)# tf.TensorSpec()中可以携带参数,且类型属性不变print(cube_func_int32 is cube.get_concrete_function(tf.TensorSpec([5], tf.int32)))print(cube_func_int32 is cube.get_concrete_function(tf.constant([1, 2, 3])))#----output------
TrueTrue
2,获取被装饰函数get_concrete_function属性中的图结构对象
print(cube_func_int32.graph)#---output-----FuncGraph(name=cube, id=140141878427488)
3,获取被装饰函数get_concrete_function属性中的图结构所有的op节点
1)所有的op节点如下:
print(cube_func_int32.graph.get_operations())#---output-----[
,
,
,
]
2)其中具体op节点信息如下:
pow_op = cube_func_int32.graph.get_operations()[2]print(pow_op)#---output------------name: "Pow"op: "Pow"input: "x"input: "Pow/y"attr {
key: "T" value {
type: DT_INT32 }}
4,通过节点名称、tensor名称获取节点对象
# 通过节点名称获取节点对象print(cube_func_int32.graph.get_operation_by_name("x"))# 通过tensor名称获取节点对象print(cube_func_int32.graph.get_tensor_by_name("x:0"))#---output-----------name: "x"op: "Placeholder"attr {
key: "_user_specified_name" value {
s: "x" }}attr {
key: "dtype" value {
type: DT_INT32 }}attr {
key: "shape" value {
shape {
dim {
size: -1 } } }}Tensor("x:0", shape=(None,), dtype=int32)
5,获取所有的图结构信息
print(cube_func_int32.graph.as_graph_def())#---output-----node {
name: "x" op: "Placeholder" attr {
key: "_user_specified_name" value {
s: "x" } } attr {
key: "dtype" value {
type: DT_INT32 } } attr {
key: "shape" value {
shape {
dim {
size: -1 } } } }}node {
name: "Pow/y" op: "Const" attr {
key: "dtype" value {
type: DT_INT32 } } attr {
key: "value" value {
tensor {
dtype: DT_INT32 tensor_shape {
} int_val: 3 } } }}node {
name: "Pow" op: "Pow" input: "x" input: "Pow/y" attr {
key: "T" value {
type: DT_INT32 } }}node {
name: "Identity" op: "Identity" input: "Pow" attr {
key: "T" value {
type: DT_INT32 } }}versions {
producer: 175}

转载地址:http://yvili.baihongyu.com/

你可能感兴趣的文章
高并发
查看>>
MySQL常用设置
查看>>
Linux 运维常用网络命令
查看>>
JavaEE常用框架汇总
查看>>
分布式数据库汇总
查看>>
Vim 命令
查看>>
Flink
查看>>
NTP-网络时间协议
查看>>
C/C++学习方法
查看>>
Borland编译器,在windows7的命令行中运行C++
查看>>
Apache Derby 网络服务器 - 10.9.1.0 - (1344872) 已启动并准备接受端口 1527 上的连接
查看>>
Java日常常用小算法
查看>>
JavaSE经典编程示例
查看>>
Eclipse软件相关知识
查看>>
人工智能资料汇总--AI传送门
查看>>
百度地图SDKv4.1.1 错误码230
查看>>
Android百度地图SDK -- 环境搭建
查看>>
Android学习路线
查看>>
导航栏实现
查看>>
图文混排实现
查看>>