36 - TensorFlow's Session and Graph

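TensorFlow is a graph-based deep learning framework. Internally, a session connects the graph to the computation kernels that execute it. What does this graph look like, and how does a session work? The code snippets below explore both questions.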
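Before looking at TensorFlow itself, a toy model in plain Python makes the graph/session split concrete. This is an illustration under simplifying assumptions, not TensorFlow's implementation: `Node`, `placeholder`, `constant`, `mul` and `Session` are hypothetical names. Building nodes computes nothing; only `Session.run` walks the graph and substitutes values from `feed_dict`. The toy also mimics two behaviors demonstrated later in this post: the result has the same structure as `fetches`, and a partial-run handle keeps intermediate results between stages.

```python
class Node:
    """One operation in the toy graph: an op kind, its inputs, and an optional constant value."""
    def __init__(self, op, inputs=(), value=None):
        self.op = op
        self.inputs = list(inputs)
        self.value = value

def placeholder():
    return Node("Placeholder")

def constant(v):
    return Node("Const", value=v)

def mul(x, y):
    return Node("Mul", [x, y])

class Session:
    """Toy session: evaluates nodes on demand, substituting fed values."""

    def _eval(self, node, feed, cache):
        if node in feed:
            return feed[node]
        if node in cache:
            return cache[node]
        if node.op == "Placeholder":
            raise ValueError("placeholder was not fed")
        if node.op == "Const":
            result = node.value
        elif node.op == "Mul":
            x, y = (self._eval(i, feed, cache) for i in node.inputs)
            result = x * y
        else:
            raise NotImplementedError(node.op)
        cache[node] = result  # remember intermediate results
        return result

    def run(self, fetches, feed_dict=None):
        feed = dict(feed_dict or {})
        cache = {}
        # Return a value with the same structure as `fetches` (single node, list, or dict).
        if isinstance(fetches, dict):
            return {k: self._eval(v, feed, cache) for k, v in fetches.items()}
        if isinstance(fetches, (list, tuple)):
            return [self._eval(f, feed, cache) for f in fetches]
        return self._eval(fetches, feed, cache)

    def partial_run_setup(self, fetches, feeds):
        # The handle carries declared feeds, fed values and cached results across stages.
        # (This toy does not validate `fetches`.)
        return {"feeds": set(feeds), "fed": {}, "cache": {}}

    def partial_run(self, handle, fetch, feed_dict):
        for k in feed_dict:
            if k not in handle["feeds"]:
                raise ValueError("feed was not declared in partial_run_setup")
        handle["fed"].update(feed_dict)
        return self._eval(fetch, handle["fed"], handle["cache"])

# Build the graph first: nothing is computed here.
a, b, c = placeholder(), placeholder(), placeholder()
y1 = mul(a, b)
y2 = mul(y1, c)
y = mul(y1, constant(6.0))

# Then execute it through the session.
sess = Session()
out_scalar = sess.run(y1, feed_dict={a: 3, b: 3})
out_list = sess.run([b, y1], feed_dict={a: 3, b: 3})
out_dict = sess.run({"ret_name": y1}, feed_dict={a: 3, b: 3})
out_const = sess.run(y, feed_dict={a: 3, b: 3})

h = sess.partial_run_setup([y1, y2], [a, b, c])
stage1 = sess.partial_run(h, y1, feed_dict={a: 3, b: 4})
stage2 = sess.partial_run(h, y2, feed_dict={c: stage1})  # reuses the cached stage-1 result
```

The design choice worth noticing is the separation of construction from execution: the same graph can be run many times with different feeds.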

### Basic math operations in TensorFlow

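```python
import tensorflow as tf  # TF 1.x graph-mode API (tf.compat.v1 in TF 2.x)

sess = tf.Session()

a = tf.placeholder("float")
b = tf.placeholder("float")
c = tf.constant(6.0)
d = tf.multiply(a, b)  # tf.mul before TF 1.0
y = tf.multiply(d, c)
print(sess.run(y, feed_dict={a: 3, b: 3}))  # placeholders must be fed at run time

A = [[1.1, 2.3], [3.4, 4.1]]
Y = tf.matrix_inverse(A)
print(sess.run(Y))
sess.close()
```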
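The snippet above prints two results without showing them. Working them out by hand gives the values to expect (TensorFlow computes in float32, so the printed digits of the inverse may differ slightly in the last places):

$$
y = (a \cdot b) \cdot c = (3 \cdot 3) \cdot 6 = 54
$$

$$
\det A = 1.1 \cdot 4.1 - 2.3 \cdot 3.4 = 4.51 - 7.82 = -3.31
$$

$$
A^{-1} = \frac{1}{\det A}\begin{pmatrix} 4.1 & -2.3 \\ -3.4 & 1.1 \end{pmatrix} \approx \begin{pmatrix} -1.2387 & 0.6949 \\ 1.0272 & -0.3323 \end{pmatrix}
$$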

Other common math operations include:

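```python
tf.add
tf.sub      # renamed tf.subtract in TF 1.0
tf.mul      # renamed tf.multiply in TF 1.0
tf.div
tf.mod
tf.abs
tf.neg      # renamed tf.negative in TF 1.0
tf.sign
tf.inv      # renamed tf.reciprocal in TF 1.0
tf.square
tf.round
tf.sqrt
tf.pow
tf.exp
tf.log
tf.maximum
tf.minimum
tf.cos
tf.sin
```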
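TensorFlow 1.0 renamed several of these ops: tf.sub became tf.subtract, tf.mul became tf.multiply, tf.neg became tf.negative, and tf.inv became tf.reciprocal. To run older snippets on TF 1.x, a naive textual rename is often enough. The sketch below assumes exactly that; `upgrade_source` is a hypothetical helper, not part of TensorFlow, and it only rewrites these four names when they are called directly.

```python
import re

# Pre-1.0 name -> name used since TensorFlow 1.0 (only the four renames listed above).
RENAMES = {
    "tf.sub": "tf.subtract",
    "tf.mul": "tf.multiply",
    "tf.neg": "tf.negative",
    "tf.inv": "tf.reciprocal",
}

# \b after the name stops matches inside longer names (tf.multiply, tf.invert_permutation);
# the lookahead only rewrites call sites such as "tf.mul(".
_PATTERN = re.compile(
    r"\b(" + "|".join(re.escape(k) for k in RENAMES) + r")\b(?=\s*\()"
)

def upgrade_source(src):
    """Return `src` with the old op names replaced by their TF 1.0 names."""
    return _PATTERN.sub(lambda m: RENAMES[m.group(1)], src)

old_line = "d = tf.mul(a, b)"
new_line = upgrade_source(old_line)  # "d = tf.multiply(a, b)"
```

A regex is deliberately simple here; it does not understand aliases such as `import tensorflow as t`.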

Common matrix operations include:

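```python
tf.diag                # build a diagonal matrix
tf.transpose           # transpose
tf.matmul              # matrix multiplication
tf.matrix_determinant  # compute the determinant
tf.matrix_inverse      # compute the matrix inverse
```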
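To pin down what each of these matrix ops computes, here is a plain-Python sketch on nested lists. It is a stand-in for illustration, not how TensorFlow implements them: `diag`, `transpose`, `matmul`, `determinant_2x2` and `inverse_2x2` are hypothetical helpers, and the determinant and inverse are restricted to 2x2 matrices. Multiplying A by its inverse should give approximately the identity.

```python
def diag(values):
    """Square matrix with `values` on the diagonal (like tf.diag on a 1-D input)."""
    n = len(values)
    return [[values[i] if i == j else 0.0 for j in range(n)] for i in range(n)]

def transpose(m):
    """Swap rows and columns."""
    return [list(row) for row in zip(*m)]

def matmul(x, y):
    """Row-by-column products: result[i][j] = sum_k x[i][k] * y[k][j]."""
    yt = transpose(y)
    return [[sum(p * q for p, q in zip(row, col)) for col in yt] for row in x]

def determinant_2x2(m):
    (a, b), (c, d) = m
    return a * d - b * c

def inverse_2x2(m):
    """Adjugate formula; only valid for 2x2 matrices with a non-zero determinant."""
    det = determinant_2x2(m)
    if det == 0:
        raise ValueError("matrix is singular")
    (a, b), (c, d) = m
    return [[d / det, -b / det], [-c / det, a / det]]

A = [[1.1, 2.3], [3.4, 4.1]]
A_inv = inverse_2x2(A)
identity = matmul(A, A_inv)  # approximately [[1, 0], [0, 1]], up to rounding
```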

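A quick aside: using TensorBoard. Because TensorFlow first builds a graph and only then executes it, debugging intermediate steps is inconvenient, so TensorFlow provides the TensorBoard tool to make debugging easier. Usage:

During training, the directory the event files are written to is reported (for example: /tmp/tflearn_logs/11U8M4/).

Run the following command, then open http://192.168.1.101:6006 (replace the address with that of the machine running TensorBoard; 6006 is TensorBoard's default port) to see the TensorBoard UI: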

```shell
tensorboard --logdir=/tmp/tflearn_logs/11U8M4/
```

### What are Graph and Session

Now to the main topic: the following code demonstrates how Graph and Session are used.

```python
import tensorflow as tf  # TF 1.x graph-mode API (tf.compat.v1 in TF 2.x)

with tf.Graph().as_default() as g:
    with g.name_scope("myscope") as scope:  # ops created below get a prefix, e.g. myscope/Placeholder
        sess = tf.Session(target='', graph=g, config=None)  # target: execution engine to connect to; '' means in-process
        print("graph version:", g.version)  # 0
        a = tf.placeholder("float")
        print(a.op)  # full operation info, same as the entries returned by g.get_operations() below
        print("graph version:", g.version)  # 1
        b = tf.placeholder("float")
        print("graph version:", g.version)  # 2
        c = tf.placeholder("float")
        print("graph version:", g.version)  # 3
        y1 = tf.multiply(a, b)  # tf.mul before TF 1.0; can also be written a * b
        print("graph version:", g.version)  # 4
        y2 = tf.multiply(y1, c)  # can also be written y1 * c
        print("graph version:", g.version)  # 5
        operations = g.get_operations()
        for (i, op) in enumerate(operations):
            print("============ operation", i + 1, "===========")
            print(op)  # a structure with name, op, attr, input, etc.; fields vary by op type
        assert y1.graph is g
        assert sess.graph is g
        print("================ graph object address ================")
        print(sess.graph)
        print("================ graph define ================")
        print(sess.graph_def)
        print("================ sess str ================")
        print(sess.sess_str)
        print(sess.run(y1, feed_dict={a: 3, b: 3}))  # 9.0; feed_dict maps graph elements to values
        print(sess.run(fetches=[b, y1], feed_dict={a: 3, b: 3}, options=None, run_metadata=None))  # result has the same structure as fetches
        print(sess.run({'ret_name': y1}, feed_dict={a: 3, b: 3}))  # {'ret_name': 9.0}; again the same structure as fetches

        assert tf.get_default_session() is not sess
        with sess.as_default():  # inside this block tf.get_default_session() returns sess; outside it does not
            assert tf.get_default_session() is sess

        h = sess.partial_run_setup([y1, y2], [a, b, c])  # staged execution; arguments list all fetches and all feeds
        res = sess.partial_run(h, y1, feed_dict={a: 3, b: 4})  # 12.0, stage 1
        res = sess.partial_run(h, y2, feed_dict={c: res})  # 144.0, stage 2 reuses the stage-1 result
        print("partial_run res:", res)
        sess.close()
```

The output is as follows:

```
graph version: 0
name: "myscope/Placeholder"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
    }
  }
}

graph version: 1
graph version: 2
graph version: 3
graph version: 4
graph version: 5
============ operation 1 ===========
name: "myscope/Placeholder"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
    }
  }
}

============ operation 2 ===========
name: "myscope/Placeholder_1"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
    }
  }
}

============ operation 3 ===========
name: "myscope/Placeholder_2"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
    }
  }
}

============ operation 4 ===========
name: "myscope/Mul"
op: "Mul"
input: "myscope/Placeholder"
input: "myscope/Placeholder_1"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}

============ operation 5 ===========
name: "myscope/Mul_1"
op: "Mul"
input: "myscope/Mul"
input: "myscope/Placeholder_2"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}

================ graph object address ================
<tensorflow.python.framework.ops.Graph object at 0x1138702d0>
================ graph define ================
node {
  name: "myscope/Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "myscope/Placeholder_1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "myscope/Placeholder_2"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "myscope/Mul"
  op: "Mul"
  input: "myscope/Placeholder"
  input: "myscope/Placeholder_1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "myscope/Mul_1"
  op: "Mul"
  input: "myscope/Mul"
  input: "myscope/Placeholder_2"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 15
}

================ sess str ================

9.0
[array(3.0, dtype=float32), 9.0]
{'ret_name': 9.0}
partial_run res: 144.0
```
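The operation names in this output (myscope/Placeholder, myscope/Placeholder_1, myscope/Placeholder_2, myscope/Mul, myscope/Mul_1) and the version numbers 0 through 5 follow a simple pattern: each new op increments the graph version, and a repeated name under the current name scope gets a numeric suffix. The toy sketch below reproduces that pattern. `ToyGraph` and its methods are hypothetical, and TensorFlow's real naming logic handles more cases than this.

```python
import contextlib

class ToyGraph:
    """Hypothetical sketch: scope-prefixed unique op names plus a per-op version counter."""

    def __init__(self):
        self.version = 0          # incremented once for every op added
        self.ops = []
        self._scope = ""
        self._name_counts = {}

    @contextlib.contextmanager
    def name_scope(self, scope):
        old = self._scope
        self._scope = old + scope + "/"
        try:
            yield self._scope
        finally:
            self._scope = old     # restore the outer scope on exit

    def _unique_name(self, base):
        # First use keeps the plain name; later uses get _1, _2, ... suffixes.
        name = self._scope + base
        count = self._name_counts.get(name, 0)
        self._name_counts[name] = count + 1
        return name if count == 0 else "%s_%d" % (name, count)

    def add_op(self, op_type, inputs=()):
        name = self._unique_name(op_type)
        self.ops.append({"name": name, "op": op_type, "input": list(inputs)})
        self.version += 1
        return name

g = ToyGraph()
with g.name_scope("myscope"):
    a = g.add_op("Placeholder")
    b = g.add_op("Placeholder")
    c = g.add_op("Placeholder")
    y1 = g.add_op("Mul", [a, b])
    y2 = g.add_op("Mul", [y1, c])

names = [op["name"] for op in g.ops]
```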

### How a TensorFlow Session works

A Session is the intermediary between a Graph and the executor. Session.run() serializes the graph, fetches, and feed_dict into byte arrays and calls tf_session.TF_Run (see /usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py).

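This tf_session.TF_Run in turn calls the _pywrap_tensorflow.TF_Run interface implemented in the shared library _pywrap_tensorflow.so (see /usr/local/lib/python2.7/site-packages/tensorflow/python/pywrap_tensorflow.py). That shared library is TensorFlow's Python binding, one of the several language interfaces TensorFlow provides.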

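In fact, _pywrap_tensorflow.so and pywrap_tensorflow.py are generated automatically by SWIG. TensorFlow's core is written in C++, and SWIG, a tool that generates bindings to C/C++ code for many scripting languages, is used here to produce the Python interface.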
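The boundary described in this section, where Python hands the native runtime nothing but serialized bytes and gets bytes back, can be sketched in plain Python. This is an illustration under stated assumptions: TensorFlow serializes a GraphDef protocol buffer, whereas this toy uses JSON, and `run_engine` and `session_run` are hypothetical stand-ins for TF_Run and Session.run.

```python
import json

def run_engine(request_bytes):
    """Stand-in for the native side: it only ever sees bytes, never Python graph objects."""
    req = json.loads(request_bytes.decode("utf-8"))
    nodes = {n["name"]: n for n in req["graph"]["node"]}
    values = dict(req["feeds"])  # fed placeholders start out with known values

    def evaluate(name):
        if name not in values:
            node = nodes[name]
            if node["op"] != "Mul":
                raise ValueError("unfed %s: %s" % (node["op"], name))
            x, y = (evaluate(i) for i in node["input"])
            values[name] = x * y
        return values[name]

    return json.dumps([evaluate(f) for f in req["fetches"]]).encode("utf-8")

def session_run(graph_def, fetches, feed_dict):
    """Python side: serialize graph, fetch names and feeds, call the engine, decode the reply."""
    request = {"graph": graph_def, "fetches": fetches, "feeds": feed_dict}
    reply = run_engine(json.dumps(request).encode("utf-8"))
    return json.loads(reply.decode("utf-8"))

graph_def = {"node": [
    {"name": "a", "op": "Placeholder", "input": []},
    {"name": "b", "op": "Placeholder", "input": []},
    {"name": "y1", "op": "Mul", "input": ["a", "b"]},
]}
result = session_run(graph_def, ["y1"], {"a": 3.0, "b": 3.0})  # [9.0]
```

Serializing at the boundary keeps the native engine independent of any one front-end language, which is what lets several language bindings share the same core.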