TensorFlow Series (3): Distributed TensorFlow

Posted by grt1stnull on 2017-07-20

How do you run a TensorFlow model across multiple machines?

0x00. Preface

Fairly complex models take a long time to train on a local machine or a single server. Many research groups and companies don't have a server stuffed with GPUs, so what can you do? Is it possible to have several servers train one model together?

This is where distributed TensorFlow comes in.

Below is how to deploy TensorFlow on a cluster.

0x01. Basic Concepts

In distributed TensorFlow, servers play one of two roles: parameter servers (ps) and compute servers (workers). As the names suggest, a ps stores and serves the model parameters, while a worker runs the model and exchanges parameters and updates with the ps.
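
As a minimal sketch of the idea (host addresses here are illustrative, not required), every process describes the same cluster, starts a server for its own role, and then either serves parameters (ps) or builds and runs the model (worker):

import tensorflow as tf

# The same cluster description is shared by every process.
cluster = tf.train.ClusterSpec({
    "ps": ["10.10.19.7:2222"],
    "worker": ["10.10.19.8:2222", "10.10.19.9:2222"]})

# Each process starts a server for its own role, e.g. worker 0:
server = tf.train.Server(cluster, job_name="worker", task_index=0)

# A ps process would instead call server.join() and just serve variables;
# a worker process goes on to build the graph and run training.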

1. Training modes

TensorFlow commonly parallelizes training in one of two modes: synchronous and asynchronous.

In synchronous mode, the workers all read the same parameters, but no worker applies its update on its own; instead they wait until every worker has finished the step and the parameters are updated once, collectively.

In asynchronous mode, each worker updates the parameters independently.
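
The difference shows up as a small code change: in asynchronous mode each worker simply applies its own optimizer updates, while in synchronous mode the optimizer is wrapped in tf.train.SyncReplicasOptimizer so gradients are aggregated first. A minimal sketch (loss, global_step, and num_workers are assumed to be defined elsewhere):

opt = tf.train.AdamOptimizer(0.01)

# Asynchronous mode: every worker applies its own updates as they are computed.
# train_step = opt.minimize(loss, global_step=global_step)

# Synchronous mode: wait for num_workers gradients, aggregate them, then update once.
opt = tf.train.SyncReplicasOptimizer(opt,
                                     replicas_to_aggregate=num_workers,
                                     total_num_replicas=num_workers)
train_step = opt.minimize(loss, global_step=global_step)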

0x02. The Official TensorFlow Example

The official example lives at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py. Below I walk through the example code with some annotations; if you have the machines available, give it a try.

1. Flag definitions

First, define flags with tf.app.flags so that the corresponding values can be supplied on the command line.

import tempfile   # used later for the temporary training directory
import time       # used later for timing the training loop

import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

Whether to enable synchronous training.

flags.DEFINE_boolean("sync_replicas", True,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")

How many replicas' updates to aggregate before a parameter update is applied (synchronous mode only; defaults to the number of workers).

flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update"
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")

Addresses of the ps and worker servers.

flags.DEFINE_string("ps_hosts", "10.10.19.7:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
                    "Comma-separated list of hostname:port pairs")

Definitions of job_name and task_index; these are normally supplied on the command line rather than hard-coded.

flags.DEFINE_string("job_name", None, "job name: worker or ps")
flags.DEFINE_integer("task_index", None,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task that performs the variable "
                     "initialization")

Check that job_name and task_index were provided.

if FLAGS.job_name is None or FLAGS.job_name == "":
  raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index == "":
  raise ValueError("Must specify an explicit `task_index`")

print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)

Parse the ps and worker host lists from the flags.

ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
num_workers = len(worker_spec)

2. Cluster setup

Create the TensorFlow cluster object and the server:

cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
# Is this task the chief (master) worker?
is_chief = (FLAGS.task_index == 0)

Configure the compute resources; only the CPU is used here. A ps server simply blocks and waits to serve the workers.

if FLAGS.job_name == "ps":
  server.join()

cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)

Device placement: tf.train.replica_device_setter automatically assigns variables to the ps job and the remaining ops to this worker's device.

with tf.device(tf.train.replica_device_setter(
    worker_device=worker_device,
    ps_device="/job:ps/cpu:0",
    cluster=cluster)):
  # the graph construction shown in the following snippets lives inside this block
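
For intuition, here is a tiny self-contained illustration of that placement (w, b, and x are hypothetical, not part of the example):

with tf.device(tf.train.replica_device_setter(
    worker_device="/job:worker/task:0/cpu:0",
    ps_device="/job:ps/cpu:0",
    cluster=cluster)):
  w = tf.Variable(tf.zeros([784, 10]))  # variable -> placed on /job:ps
  b = tf.Variable(tf.zeros([10]))       # variable -> placed on /job:ps
  logits = tf.matmul(x, w) + b          # op -> placed on this worker's device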

3. Preparing for training

Track the global step:

global_step = tf.Variable(0, name="global_step", trainable=False)

Synchronous mode wraps the optimizer, so given an optimizer such as opt = tf.train.AdamOptimizer(FLAGS.learning_rate):

if FLAGS.sync_replicas:
  # Aggregate this many replicas' gradients before each parameter update
  if FLAGS.replicas_to_aggregate is None:
    replicas_to_aggregate = num_workers
  else:
    replicas_to_aggregate = FLAGS.replicas_to_aggregate
  # Wrap the optimizer for synchronous training
  opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=replicas_to_aggregate,
                                       total_num_replicas=num_workers, name="mnist_sync_replicas")

The training op:

train_step = opt.minimize(cross_entropy, global_step=global_step)

Initialization:

if FLAGS.sync_replicas:
  local_init_op = opt.local_step_init_op
  if is_chief:
    local_init_op = opt.chief_init_op
  ready_for_local_init_op = opt.ready_for_local_init_op
  # Queue runner used by the chief in sync mode
  chief_queue_runner = opt.get_chief_queue_runner()
  # Op that seeds the initial sync tokens
  sync_init_op = opt.get_init_tokens_op()

# Global variable initializer
init_op = tf.global_variables_initializer()
# Temporary training (log) directory
train_dir = tempfile.mkdtemp()

Create the distributed training Supervisor:

if FLAGS.sync_replicas:
  sv = tf.train.Supervisor(is_chief=is_chief, logdir=train_dir,
                           init_op=init_op, local_init_op=local_init_op,
                           ready_for_local_init_op=ready_for_local_init_op,
                           recovery_wait_secs=1, global_step=global_step)
else:
  sv = tf.train.Supervisor(is_chief=is_chief, logdir=train_dir,
                           init_op=init_op, recovery_wait_secs=1,
                           global_step=global_step)

Configure the session; device_filters restricts this worker to seeing only the ps job and its own devices:

sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False,
                             device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])

Prepare to run:

if is_chief:
  print("Worker %d: Initializing session..." % FLAGS.task_index)
else:
  print("Worker %d: Waiting for session to be initialized..." % FLAGS.task_index)

# Create the session (chief) or wait for it to be ready (other workers)
sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index)

if FLAGS.sync_replicas and is_chief:
  # Seed the sync token queue
  sess.run(sync_init_op)
  # Start the chief queue runner
  sv.start_queue_runners(sess, [chief_queue_runner])

4. Start training:

time_begin = time.time()
print("Training begins @ %f" % time_begin)

local_step = 0
while True:
  batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
  train_feed = {x: batch_xs, y_: batch_ys}

  _, step = sess.run([train_step, global_step], feed_dict=train_feed)
  local_step += 1

  now = time.time()
  print("%f: Worker %d: training step %d done (global step: %d)" %
        (now, FLAGS.task_index, local_step, step))

  if step >= FLAGS.train_steps:
    break

time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)

# Validation set
val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print("After %d training step(s), validation cross entropy = %g" %
      (FLAGS.train_steps, val_xent))

0x03. Running It on the Servers

Change the ps and worker IP addresses in the official example; the relevant part of the file then looks like this:

#coding:utf-8

# CPU only in this setup
flags.DEFINE_integer("num_gpus", 0,
                     "Total number of gpus for each machine."
                     "If you don't use GPU, please set it to '0'")
# ps and worker server addresses
flags.DEFINE_string("ps_hosts", "10.10.19.7:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
                    "Comma-separated list of hostname:port pairs")

Run the following on the three servers, one command per machine:

python distribute_test.py --job_name=ps --task_index=0 --sync_replicas=True

python distribute_test.py --job_name=worker --task_index=0 --sync_replicas=True

python distribute_test.py --job_name=worker --task_index=1 --sync_replicas=True

On the ps server you will see:

Extracting /tmp/mnist-data/train-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist-data/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/t10k-labels-idx1-ubyte.gz
job name = ps
task index = 0
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> 10.10.19.8:2222, 1 -> 10.10.19.9:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222

On the two worker machines, worker 0 outputs:

Extracting /tmp/mnist-data/train-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist-data/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/t10k-labels-idx1-ubyte.gz
job name = worker
task index = 0
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> 10.10.19.7:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> localhost:2222, 1 -> 10.10.19.9:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222
Worker 0: Initializing session...
I tensorflow/core/distributed_runtime/master_session.cc:1012] Start master session df31e159ecf5dc77 with config:
device_filters: "/job:ps"
device_filters: "/job:worker/task:0"
allow_soft_placement: true
Worker 0: Session initialization complete.
Training begins @ 1500442384.091448
1500442384.150300: Worker 0: training step 1 done (global step: 0)
1500442384.163003: Worker 0: training step 2 done (global step: 0)
1500442384.172685: Worker 0: training step 3 done (global step: 1)
1500442384.182413: Worker 0: training step 4 done (global step: 1)
......
1500442387.524158: Worker 0: training step 269 done (global step: 198)
1500442387.539484: Worker 0: training step 270 done (global step: 199)
1500442387.555133: Worker 0: training step 271 done (global step: 200)
Training ends @ 1500442387.555215
Training elapsed time: 3.463767 s
After 200 training step(s), validation cross entropy = 781.478

And worker 1:

Extracting /tmp/mnist-data/train-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist-data/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist-data/t10k-labels-idx1-ubyte.gz
job name = worker
task index = 1
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> 10.10.19.7:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> 10.10.19.8:2222, 1 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222
Worker 1: Waiting for session to be initialized...
I tensorflow/core/distributed_runtime/master_session.cc:1012] Start master session 540b3d300aac1583 with config:
device_filters: "/job:ps"
device_filters: "/job:worker/task:1"
allow_soft_placement: true
Worker 1: Session initialization complete.
Training begins @ 1500442385.490577
1500442385.520064: Worker 1: training step 1 done (global step: 68)
1500442385.534573: Worker 1: training step 2 done (global step: 69)
1500442385.549380: Worker 1: training step 3 done (global step: 70)
......
1500442387.524271: Worker 1: training step 131 done (global step: 198)
1500442387.539464: Worker 1: training step 132 done (global step: 199)
1500442387.555071: Worker 1: training step 133 done (global step: 200)
Training ends @ 1500442387.555124
Training elapsed time: 2.064547 s
After 200 training step(s), validation cross entropy = 781.478

0x04. Building Your Own Code

Here I tried distributing an LSTM model, reusing most of the code above.

1. Common code

First, the basic TensorFlow flags:

import tempfile
import time

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_boolean("sync_replicas", True,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")
flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update"
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")
flags.DEFINE_string("ps_hosts", "10.10.19.7:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "10.10.19.8:2222,10.10.19.9:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("job_name", None, "job name: worker or ps")
flags.DEFINE_integer("task_index", 0,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task that performs the variable "
                     "initialization")
flags.DEFINE_integer("train_steps", 500,
                     "Number of (global) training steps to perform")
FLAGS = flags.FLAGS

if FLAGS.job_name is None or FLAGS.job_name == "":
  raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index == "":
  raise ValueError("Must specify an explicit `task_index`")

print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)

Then the cluster configuration:

ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
num_workers = len(worker_spec)

cluster = tf.train.ClusterSpec({
    "ps": ps_spec,
    "worker": worker_spec})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

if FLAGS.job_name == "ps":
  server.join()

is_chief = (FLAGS.task_index == 0)
cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)

Place the computation on the cluster. The model itself is elided in the comments inside this block; a sketch of that part follows the block.

with tf.device(
    tf.train.replica_device_setter(
        worker_device=worker_device,
        ps_device="/job:ps/cpu:0",
        cluster=cluster)):
  global_step = tf.Variable(0, name="global_step", trainable=False)

  # tf.placeholder ...
  # weights / biases
  # the model computation goes here

  cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

  if FLAGS.sync_replicas:
    if FLAGS.replicas_to_aggregate is None:
      replicas_to_aggregate = num_workers
    else:
      replicas_to_aggregate = FLAGS.replicas_to_aggregate
    optimizer = tf.train.SyncReplicasOptimizer(
        optimizer,
        replicas_to_aggregate=replicas_to_aggregate,
        total_num_replicas=num_workers,
        name="mnist_sync_replicas")

  train_step = optimizer.minimize(cost, global_step=global_step)

  correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

  if FLAGS.sync_replicas:
    local_init_op = optimizer.local_step_init_op
    if is_chief:
      local_init_op = optimizer.chief_init_op
    ready_for_local_init_op = optimizer.ready_for_local_init_op
    chief_queue_runner = optimizer.get_chief_queue_runner()
    sync_init_op = optimizer.get_init_tokens_op()

  init_op = tf.global_variables_initializer()
  train_dir = tempfile.mkdtemp()

  if FLAGS.sync_replicas:
    sv = tf.train.Supervisor(
        is_chief=is_chief,
        logdir=train_dir,
        init_op=init_op,
        local_init_op=local_init_op,
        ready_for_local_init_op=ready_for_local_init_op,
        recovery_wait_secs=1,
        global_step=global_step)
  else:
    sv = tf.train.Supervisor(
        is_chief=is_chief,
        logdir=train_dir,
        init_op=init_op,
        recovery_wait_secs=1,
        global_step=global_step)

  sess_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=False,
      device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])

  if is_chief:
    print("Worker %d: Initializing session..." % FLAGS.task_index)
  else:
    print("Worker %d: Waiting for session to be initialized..." %
          FLAGS.task_index)

  sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
  print("Worker %d: Session initialization complete." % FLAGS.task_index)

  if FLAGS.sync_replicas and is_chief:
    sess.run(sync_init_op)
    sv.start_queue_runners(sess, [chief_queue_runner])
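
The placeholder / weights / computation comments above stand in for the actual LSTM model, which I haven't reproduced. As a rough sketch only of what goes there (inside the with block), assuming MNIST-style inputs and hyperparameters n_input, n_steps, n_hidden, n_classes defined earlier (names are illustrative):

# Hypothetical model sketch: placeholders, output weights, and an LSTM over the input sequence.
x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
weights = tf.Variable(tf.random_normal([n_hidden, n_classes]))
biases = tf.Variable(tf.random_normal([n_classes]))

# Split the input into a list of n_steps tensors and run a basic LSTM cell over them.
inputs = tf.unstack(x, n_steps, 1)
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, inputs, dtype=tf.float32)

# Logits from the last time step, consumed by the softmax cross entropy above.
pred = tf.matmul(outputs[-1], weights) + biases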

Start the computation. The organization here, a `while True` loop around `sess.run([train_step, global_step])`, is the important part.

time_begin = time.time()
print("Training begins @ %f" % time_begin)

local_step = 0
while True:
  batch_x, batch_y = mnist.train.next_batch(batch_size)
  batch_x = batch_x.reshape((batch_size, n_steps, n_input))
  _, step = sess.run([train_step, global_step], feed_dict={x: batch_x, y: batch_y})
  local_step += 1

  now = time.time()
  print("%f: Worker %d: training step %d done (global step: %d)" %
        (now, FLAGS.task_index, local_step, step))

  if step >= FLAGS.train_steps:
    break

time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)

2. Running and debugging

Running on the same servers, asynchronous mode was roughly twice as fast as synchronous mode, with little difference in accuracy.
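
Switching to asynchronous mode is just a launch-time flag change, assuming the same script and cluster as above:

python distribute_test.py --job_name=ps --task_index=0 --sync_replicas=False

python distribute_test.py --job_name=worker --task_index=0 --sync_replicas=False

python distribute_test.py --job_name=worker --task_index=1 --sync_replicas=False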

0x05. References