三十六-tensorflow的session和graph
tensorflow作为一个基于图结构的深度学习框架,内部通过session实现图和计算内核的交互,那么这个图是什么样的结构,session的工作原理又是什么样的呢?我们通过几段代码来深入理解一下
tensorflow中的基本数学运算用法¶
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | import tensorflow as tf sess = tf.Session() a = tf.placeholder("float") b = tf.placeholder("float") c = tf.constant(6.0) d = tf.mul(a, b) y = tf.mul(d, c) print sess.run(y, feed_dict={a: 3, b: 3}) A = [[1.1,2.3],[3.4,4.1]] Y = tf.matrix_inverse(A) print sess.run(Y) sess.close() | 
主要数字运算还包括:
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | tf.add tf.sub tf.mul tf.div tf.mod tf.abs tf.neg tf.sign tf.inv tf.square tf.round tf.sqrt tf.pow tf.exp tf.log tf.maximum tf.minimum tf.cos tf.sin | 
主要矩阵运算还包括:
| 1 2 3 4 5 | tf.diag生成对角阵 tf.transpose tf.matmul tf.matrix_determinant计算行列式的值 tf.matrix_inverse计算矩阵的逆 | 
插播小甜点:tensorboard使用 tensorflow因为代码执行过程是先构建图,然后在执行,所以对中间过程的调试不太方便,所以提供了一个tensorboard工具来便于调试,用法如下:
在训练时会提示写入事件文件到哪个目录(比如:/tmp/tflearn_logs/11U8M4/)
执行如下命令并打开http://192.168.1.101:6006就能看到tensorboard的界面
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 | tensorboard --logdir=/tmp/tflearn_logs/11U8M4/ ``` ### 什么是Graph和Session 为了步入正题,我们通过一段代码来展示Graph和Session的使用 ```python import tensorflow as tf with tf.Graph().as_default() as g: with g.name_scope("myscope") as scope: # 有了这个scope,下面的op的name都是类似myscope/Placeholder这样的前缀 sess = tf.Session(target='', graph = g, config=None) # target表示要连接的tf执行引擎 print "graph version:", g.version # 0 a = tf.placeholder("float") print a.op # 输出整个operation信息,跟下面g.get_operations返回结果一样 print "graph version:", g.version # 1 b = tf.placeholder("float") print "graph version:", g.version # 2 c = tf.placeholder("float") print "graph version:", g.version # 3 y1 = tf.mul(a, b) # 也可以写成a * b print "graph version:", g.version # 4 y2 = tf.mul(y1, c) # 也可以写成y1 * c print "graph version:", g.version # 5 operations = g.get_operations() for (i, op) in enumerate(operations): print "============ operation", i+1, "===========" print op # 一个结构,包括:name、op、attr、input等,不同op不一样 assert y1.graph is g assert sess.graph is g print "================ graph object address ================" print sess.graph print "================ graph define ================" print sess.graph_def print "================ sess str ================" print sess.sess_str print sess.run(y1, feed_dict={a: 3, b: 3}) # 9.0 feed_dictgraph中的元素和值的映射 print sess.run(fetches=[b,y1], feed_dict={a: 3, b: 3}, options=None, run_metadata=None) # 传入的feches和返回值的shape相同 print sess.run({'ret_name':y1}, feed_dict={a: 3, b: 3}) # {'ret_name': 9.0} 传入的feches和返回值的shape相同 assert tf.get_default_session() is not sess with sess.as_default(): # 把sess作为默认的session,那么tf.get_default_session就是sess, 否则不是 assert tf.get_default_session() is sess h = sess.partial_run_setup([y1, y2], [a, b, c]) # 分阶段运行,参数指明了feches和feed_dict列表 res = sess.partial_run(h, y1, feed_dict={a: 3, b: 4}) # 12 运行第一阶段 res = sess.partial_run(h, y2, feed_dict={c: res}) # 144.0 运行第二阶段,其中使用了第一阶段的执行结果 print "partial_run res:", res sess.close() | 
输出如下:
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 | graph version: 0 name: "myscope/Placeholder" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { } } } graph version: 1 graph version: 2 graph version: 3 graph version: 4 graph version: 5 ============ operation 1 =========== name: "myscope/Placeholder" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { } } } ============ operation 2 =========== name: "myscope/Placeholder_1" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { } } } ============ operation 3 =========== name: "myscope/Placeholder_2" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { } } } ============ operation 4 =========== name: "myscope/Mul" op: "Mul" input: "myscope/Placeholder" input: "myscope/Placeholder_1" attr { key: "T" value { type: DT_FLOAT } } ============ operation 5 =========== name: "myscope/Mul_1" op: "Mul" input: "myscope/Mul" input: "myscope/Placeholder_2" attr { key: "T" value { type: DT_FLOAT } } ================ graph object address ================ <tensorflow.python.framework.ops.Graph object at 0x1138702d0> ================ graph define ================ node { name: "myscope/Placeholder" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { } } } } node { name: "myscope/Placeholder_1" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { } } } } node { name: "myscope/Placeholder_2" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { } } } } node { name: "myscope/Mul" op: "Mul" input: "myscope/Placeholder" input: "myscope/Placeholder_1" attr { key: "T" value { type: DT_FLOAT } } } node { name: "myscope/Mul_1" op: "Mul" input: "myscope/Mul" input: "myscope/Placeholder_2" attr { key: "T" value { type: DT_FLOAT } } } versions { producer: 15 } ================ sess str ================ 9.0 [array(3.0, dtype=float32), 9.0] {'ret_name': 9.0} partial_run res: 144.0 | 
tensorflow的Session是如何工作的¶
Session是Graph和执行者之间的媒介,Session.run()实际上将graph、fetches、feed_dict序列化到字节数组中,并调用tf_session.TF_Run(参见/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py)
而这里的tf_session.TF_Run实际上调用了动态链接库_pywrap_tensorflow.so中实现的_pywrap_tensorflow.TF_Run接口(参见/usr/local/lib/python2.7/site-packages/tensorflow/python/pywrap_tensorflow.py),这个动态链接库是tensorflow提供的诸多语言接口中python语言的接口
事实上这里的_pywrap_tensorflow.so和pywrap_tensorflow.py是通过SWIG工具自动生成,大家都知道tensorflow核心语言是c语言,这里是通过SWIG生成了各种脚本语言的接口