from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import input_data
import time
import tensorflow as tf
import model
# Define the command-line flags used to build the data-flow graph
flags = tf.app.flags
# task_index starts at 0; task 0 is the first task and performs the variable initialization
flags.DEFINE_integer("task_index", None,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task that performs the variable "
                     "initialization ")
# Number of GPUs on each machine; 0 if the machine has no GPU
flags.DEFINE_integer("num_gpus", 0,
                     "Total number of gpus for each machine. "
                     "If you don't use GPU, please set it to '0'")
# In synchronous training, the number of worker gradients to aggregate before each update; defaults to the total number of workers
flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update "
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")
# Learning rate
flags.DEFINE_float("learning_rate", 0.0001, "Learning rate")
# Whether to train synchronously or asynchronously
flags.DEFINE_boolean("sync_replicas", False,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before being applied, to avoid stale gradients")
# If the servers already exist, communicate with them over gRPC; otherwise create an in-process server
flags.DEFINE_boolean(
    "existing_servers", False, "Whether servers already exist. If True, "
    "will use the worker hosts via their GRPC URLs (one client process "
    "per worker host). Otherwise, will create an in-process TensorFlow "
    "server.")
# Parameter server hosts
flags.DEFINE_string("ps_hosts", "localhost:2222",
                    "Comma-separated list of hostname:port pairs")
# Worker hosts
flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                    "Comma-separated list of hostname:port pairs")
# Whether this job is a worker or a parameter server
flags.DEFINE_string("job_name", None, "job name: worker or ps")
tf.app.flags.DEFINE_string("train_dir", "", "Directory containing the training images")
tf.app.flags.DEFINE_string("logs_train_dir", "", "Directory for training logs and checkpoints")
tf.app.flags.DEFINE_integer("IMG_W", 208, "Width the input images are resized to")
tf.app.flags.DEFINE_integer("IMG_H", 208, "Height the input images are resized to")
tf.app.flags.DEFINE_integer("CAPACITY", 256, "Capacity of the input queue")
tf.app.flags.DEFINE_integer("MAX_STEP", 150, "Maximum number of training steps")
tf.app.flags.DEFINE_integer("N_CLASSES", 8, "Number of classes")
tf.app.flags.DEFINE_integer("BATCH_SIZE", 32, "Batch size")
FLAGS = flags.FLAGS
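# Example launch commands (hypothetical paths, default hosts/ports above; one process per
# cluster member; the file name follows this package's distribute-trainingV2.py):
#   python distribute-trainingV2.py --job_name=ps --task_index=0
#   python distribute-trainingV2.py --job_name=worker --task_index=0 --train_dir=./data --logs_train_dir=./logs
#   python distribute-trainingV2.py --job_name=worker --task_index=1 --train_dir=./data --logs_train_dir=./logs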
def main(unused_argv):
  if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
  if FLAGS.task_index is None or FLAGS.task_index == "":
    raise ValueError("Must specify an explicit `task_index`")
  print("job name = %s" % FLAGS.job_name)
  print("task index = %d" % FLAGS.task_index)
  # Construct the cluster and start the server
  # Parse the cluster description
  ps_spec = FLAGS.ps_hosts.split(",")
  worker_spec = FLAGS.worker_hosts.split(",")
  # Get the number of workers.
  num_workers = len(worker_spec)
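  # For illustration: with the default flags the cluster description built below is
  #   {"ps": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}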
  # Create the TensorFlow ClusterSpec object describing the cluster
  cluster = tf.train.ClusterSpec({
      "ps": ps_spec,
      "worker": worker_spec})
  # Create a TensorFlow Server object for the local task.
  if not FLAGS.existing_servers:
    # Not using existing servers. Create an in-process server.
    # From tf.train.Server on, the nodes start to behave differently:
    # the job name passed on the command line decides which role this task plays.
    # If the job name is ps, the process joins here as a parameter-update service and
    # waits for the worker nodes to push parameter updates to it.
    # If the job name is worker, it runs the training computation below.
    server = tf.train.Server(
        cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    # A parameter server only needs to be started; the process then blocks here.
    # The tf.train.replica_device_setter below pins the variables to the ps servers.
    if FLAGS.job_name == "ps":
      server.join()
  # Handle the worker nodes
  # The chief worker is the one with task_index == 0
  is_chief = (FLAGS.task_index == 0)
  # If GPUs are used
  if FLAGS.num_gpus > 0:
    # Avoid gpu allocation conflict: now allocate task_num -> #gpu
    # for each worker in the corresponding machine
    gpu = (FLAGS.task_index % FLAGS.num_gpus)
    # Pin this worker to the assigned GPU
    worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
  # If running on CPU only
  elif FLAGS.num_gpus == 0:
    # Just allocate the CPU to the worker
    cpu = 0
    worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
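  # For example, with --task_index=1 and --num_gpus=0 (hypothetical values) the device
  # string above becomes "/job:worker/task:1/cpu:0".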
  # The device setter will automatically place Variables ops on separate
  # parameter servers (ps). The non-Variable ops will be placed on the workers.
  # The ps use CPU and workers use corresponding GPU
  # tf.train.replica_device_setter assigns variable ops to the parameter servers (on CPU)
  # and non-variable ops to the worker_device chosen above.
  # Variables defined under this with block are placed on the parameter servers
  # automatically; with several parameter servers they are assigned round-robin.
  with tf.device(
      tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)):
    with tf.variable_scope('inputdata') as scope:
      # Load the list of images and labels
      train, train_label = input_data.read_img(FLAGS.train_dir)
      # Build the input batches
      train_batch, train_label_batch = input_data.get_batch(train,
                                                            train_label,
                                                            FLAGS.IMG_W,
                                                            FLAGS.IMG_H,
                                                            FLAGS.BATCH_SIZE,
                                                            FLAGS.CAPACITY)
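      # Assumed shapes (they depend on input_data.get_batch in this package): train_batch is
      # roughly [BATCH_SIZE, IMG_W, IMG_H, 3] image data, train_label_batch is [BATCH_SIZE] labels.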
    # Global step counter, initialized to 0
    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_logits = model.inference(train_batch, FLAGS.BATCH_SIZE, FLAGS.N_CLASSES)
    train_loss = model.losses(train_logits, train_label_batch)
    accuracy = model.evaluation(train_logits, train_label_batch)
    # Merge all summaries into a single op that can be run in a session
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()
    print("Variables initialized ...")
    # Asynchronous mode: each replica applies its gradients as soon as they are
    # computed, without coordinating its progress with the other replicas
    opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    # Synchronous mode
    if FLAGS.sync_replicas:
      if FLAGS.replicas_to_aggregate is None:
        replicas_to_aggregate = num_workers
      else:
        replicas_to_aggregate = FLAGS.replicas_to_aggregate
      # Wrap the optimizer in SyncReplicasOptimizer (between-graph replication);
      # the gradients from the replicas are averaged before being applied
      opt = tf.train.SyncReplicasOptimizer(
          opt,
          replicas_to_aggregate=replicas_to_aggregate,
          total_num_replicas=num_workers,
          name="mnist_sync_replicas")
    train_step = opt.minimize(train_loss, global_step=global_step)
    if FLAGS.sync_replicas:
      local_init_op = opt.local_step_init_op
      if is_chief:
        # Among all the computing workers one is the chief;
        # the chief initializes the variables, saves the model and writes the summaries
        local_init_op = opt.chief_init_op
      ready_for_local_init_op = opt.ready_for_local_init_op
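    # ----------------------------------------------------------------------------------
    # The listing is truncated above. What follows is only a minimal sketch of the usual
    # continuation (sync bookkeeping, tf.train.Supervisor, session creation and training
    # loop), modelled on TensorFlow's standard mnist_replica / SyncReplicasOptimizer
    # pattern. It is an assumption, not the author's original code.
    # ----------------------------------------------------------------------------------
    if FLAGS.sync_replicas:
      # Token queue and chief queue runner required by sync_replicas mode
      sync_init_op = opt.get_init_tokens_op()
      chief_queue_runner = opt.get_chief_queue_runner()
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=FLAGS.logs_train_dir or None,
          init_op=init_op,
          local_init_op=local_init_op,
          ready_for_local_init_op=ready_for_local_init_op,
          summary_op=summary_op,
          global_step=global_step)
    else:
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=FLAGS.logs_train_dir or None,
          init_op=init_op,
          summary_op=summary_op,
          global_step=global_step)
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
    # The chief prepares the session; the other workers wait for it to be ready.
    if is_chief:
      print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
      print("Worker %d: Waiting for session to be initialized..." % FLAGS.task_index)
    if FLAGS.existing_servers:
      server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
      sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config)
    else:
      sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
    print("Worker %d: Session initialization complete." % FLAGS.task_index)
    # Start the input-pipeline queue runners (input_data.get_batch is assumed to use
    # the queue-based tf.train input pipeline).
    sv.start_queue_runners(sess)
    if FLAGS.sync_replicas and is_chief:
      # Only the chief runs the token init op and the chief queue runner.
      sess.run(sync_init_op)
      sv.start_queue_runners(sess, [chief_queue_runner])
    # Simple training loop driven by the global step.
    time_begin = time.time()
    step = 0
    while step < FLAGS.MAX_STEP:
      _, loss_value, acc, step = sess.run([train_step, train_loss, accuracy, global_step])
      if step % 50 == 0:
        print("Step %d, loss = %.4f, accuracy = %.4f" % (step, loss_value, acc))
    print("Training elapsed time: %f s" % (time.time() - time_begin))


if __name__ == "__main__":
  tf.app.run()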