tensorflow创建变量以及根据名称查找变量
|
环境:Ubuntu14.04,tensorflow=1.4(bazel源码安装),Anaconda python=3.6 声明变量主要有两种方法:tf.Variable和 tf.get_variable,二者的最大区别是: (1) tf.Variable是一个类,自带很多属性函数;而 tf.get_variable是一个函数; 以变量共享时,可以重复使用该变量(例如RNN中的参数共享)。
import tensorflow as tf
with tf.variable_scope('scope1',reuse=tf.AUTO_REUSE) as scope1:
x1 = tf.Variable(tf.ones([1]),name='x1')
x2 = tf.Variable(tf.zeros([1]),name='x1')
y1 = tf.get_variable('y1',initializer=1.0)
y2 = tf.get_variable('y1',initializer=0.0)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(x1.name,x1.eval())
print(x2.name,x2.eval())
print(y1.name,y1.eval())
print(y2.name,y2.eval())
输出结果为: scope1/x1:0 [ 1.] scope1/x1_1:0 [ 0.] scope1/y1:0 1.0 scope1/y1:0 1.0 1. tf.Variable(…) tf.Variable(…)使用给定初始值来创建一个新变量,该变量会默认添加到 graph collections listed in collections,which defaults to [GraphKeys.GLOBAL_VARIABLES]。 如果trainable属性被设置为True,该变量同时也会被添加到graph collection GraphKeys.TRAINABLE_VARIABLES. # tf.Variable __init__( initial_value=None,trainable=True,collections=None,validate_shape=True,caching_device=None,name=None,variable_def=None,dtype=None,expected_shape=None,import_scope=None,constraint=None ) 2. tf.get_variable(…) tf.get_variable(…)的返回值有两种情形: 使用指定的initializer来创建一个新变量; get_variable( name,shape=None,initializer=None,regularizer=None,partitioner=None,use_resource=None,custom_getter=None,constraint=None ) 3. 根据名称查找变量 在创建变量时,即使我们不指定变量名称,程序也会自动进行命名。于是,我们可以很方便的根据名称来查找变量,这在抓取参数、finetune模型等很多时候都很有用。 示例1: 通过在tf.global_variables()变量列表中,根据变量名进行匹配搜索查找。 该种搜索方式,可以同时找到由tf.Variable或者tf.get_variable创建的变量。
import tensorflow as tf
x = tf.Variable(1,name='x')
y = tf.get_variable(name='y',shape=[1,2])
for var in tf.global_variables():
if var.name == 'x:0':
print(var)
示例2: 利用get_tensor_by_name()同样可以获得由tf.Variable或者tf.get_variable创建的变量。
import tensorflow as tf
x = tf.Variable(1,2])
graph = tf.get_default_graph()
x1 = graph.get_tensor_by_name("x:0")
y1 = graph.get_tensor_by_name("y:0")
示例3: 针对tf.get_variable创建的变量,可以利用变量重用来直接获取已经存在的变量。
with tf.variable_scope("foo"):
bar1 = tf.get_variable("bar",(2,3)) # create
with tf.variable_scope("foo",reuse=True):
bar2 = tf.get_variable("bar") # reuse
with tf.variable_scope("",reuse=True): # root variable scope
bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
print((bar1 is bar2) and (bar2 is bar3))
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持编程小技巧。 您可能感兴趣的文章:
(编辑:安卓应用网) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |
