YOLO-V3代码解析系列(二) —— 主程序结构(train.py)
代码展示
#! /usr/bin/env python
# coding=utf-8
# ================================================================
#   Copyright (C) 2019 * Ltd. All rights reserved.
#
#   Editor      : VIM
#   File name   : train.py
#   Author      : YunYang1994
·
代码展示
import os
import time
import numpy as np
import tensorflow as tf
import core.utils as utils
from tqdm import tqdm
from core.dataset import Dataset
from core.yolov3 import YOLOV3
from core.config import cfg
class YoloTrain(object):
    """Build the YOLOv3 training graph and drive the two-stage training loop.

    Stage one (epochs ``1 .. first_stage_epochs``) optimizes only the variables
    under the detection-head scopes listed in ``_HEAD_SCOPES``; stage two
    optimizes every trainable variable. Both stages share one warmup +
    cosine-decay learning rate and also run an exponential-moving-average
    update of the trainable variables on every step.
    """

    # Top-level variable scopes of the three detection heads. They are the only
    # variables trained in stage one, and they are excluded when restoring
    # ``cfg.TRAIN.INITIAL_WEIGHT`` -- presumably because their output depth
    # depends on ``num_classes`` (TODO confirm against core.yolov3.YOLOV3).
    _HEAD_SCOPES = ('conv_sbbox', 'conv_mbbox', 'conv_lbbox')

    def __init__(self):
        """Read the config, then build datasets, session, graph, optimizers,
        loader/saver and TensorBoard summaries."""
        # YOLOv3: 3 anchors per grid cell on each of the 3 scales (9 in total).
        self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
        self.classes = utils.read_class_names(cfg.YOLO.CLASSES)
        self.num_classes = len(self.classes)
        self.learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT
        self.learn_rate_end = cfg.TRAIN.LEARN_RATE_END
        # 'FISRT' (sic) must match the attribute name defined in core.config.
        self.first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS
        self.second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS
        self.warmup_epoch = cfg.TRAIN.WARMUP_EPOCHS
        self.initial_weight = cfg.TRAIN.INITIAL_WEIGHT
        self.time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        self.moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY
        self.max_bbox_per_scale = 150
        self.train_logdir = "./train_log/mydata_test1/"
        self.trainset = Dataset('train')
        self.testset = Dataset('test')
        # len(Dataset) is used as the number of training steps (batches) per epoch.
        self.steps_per_epoch = len(self.trainset)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True  # allocate GPU memory on demand
        self.sess = tf.Session(config=config)

        with tf.name_scope('define_input'):
            # Shape-less placeholders, fed from the 7-element batches produced
            # by Dataset (see _feed_dict for the index -> placeholder mapping).
            self.input_data = tf.placeholder(dtype=tf.float32, name='input_data')
            self.label_sbbox = tf.placeholder(dtype=tf.float32, name='label_sbbox')
            self.label_mbbox = tf.placeholder(dtype=tf.float32, name='label_mbbox')
            self.label_lbbox = tf.placeholder(dtype=tf.float32, name='label_lbbox')
            self.true_sbboxes = tf.placeholder(dtype=tf.float32, name='sbboxes')
            self.true_mbboxes = tf.placeholder(dtype=tf.float32, name='mbboxes')
            self.true_lbboxes = tf.placeholder(dtype=tf.float32, name='lbboxes')
            self.trainable = tf.placeholder(dtype=tf.bool, name='training')

        with tf.name_scope("define_loss"):
            # Build the YOLOv3 network (the author notes the class count was
            # customised for this dataset).
            self.model = YOLOV3(self.input_data, self.trainable)
            # Snapshot of the variables that exist right after the network is
            # built -- i.e. before global_step, optimizer slots and EMA shadows
            # are created. The weight loader below is built from this list.
            self.net_var = tf.global_variables()
            self.giou_loss, self.conf_loss, self.prob_loss = self.model.compute_loss(
                self.label_sbbox, self.label_mbbox, self.label_lbbox,
                self.true_sbboxes, self.true_mbboxes, self.true_lbboxes)
            self.loss = self.giou_loss + self.conf_loss + self.prob_loss

        with tf.name_scope('learn_rate'):
            # Starts at 1.0 and is incremented by global_step_update once per step.
            self.global_step = tf.Variable(1.0, dtype=tf.float64, trainable=False, name='global_step')
            warmup_steps = tf.constant(self.warmup_epoch * self.steps_per_epoch,
                                       dtype=tf.float64, name='warmup_steps')
            # Total number of training steps over both stages.
            train_steps = tf.constant((self.first_stage_epochs + self.second_stage_epochs) * self.steps_per_epoch,
                                      dtype=tf.float64, name='train_steps')
            # Linear warmup up to learn_rate_init, then cosine decay from
            # learn_rate_init down to learn_rate_end at train_steps.
            self.learn_rate = tf.cond(
                pred=self.global_step < warmup_steps,
                true_fn=lambda: self.global_step / warmup_steps * self.learn_rate_init,
                false_fn=lambda: self.learn_rate_end + 0.5 * (self.learn_rate_init - self.learn_rate_end) *
                (1 + tf.cos((self.global_step - warmup_steps) / (train_steps - warmup_steps) * np.pi)))
            global_step_update = tf.assign_add(self.global_step, 1.0)

        # Despite the scope name, this is not weight decay: it is an exponential
        # moving average of all trainable variables, which the author uses to
        # stabilise training and later evaluation.
        with tf.name_scope("define_weight_decay"):
            moving_ave = tf.train.ExponentialMovingAverage(self.moving_ave_decay).apply(tf.trainable_variables())

        with tf.name_scope("define_first_stage_train"):
            # Stage one: only the detection heads are optimized; everything
            # else stays frozen at its restored value.
            self.first_stage_trainable_var_list = [
                var for var in tf.trainable_variables()
                if str(var.op.name).split('/')[0] in self._HEAD_SCOPES]
            print('warm learning rate: ', self.learn_rate)
            first_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(
                self.loss, var_list=self.first_stage_trainable_var_list)
            # The no_op only completes after UPDATE_OPS, the optimizer step, the
            # global_step increment and the EMA update have all run. NOTE(review):
            # those ops were created outside these contexts, so no ordering
            # *between* them is enforced -- they merely all run each step.
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                with tf.control_dependencies([first_stage_optimizer, global_step_update]):
                    with tf.control_dependencies([moving_ave]):
                        self.train_op_with_frozen_variables = tf.no_op()

        with tf.name_scope("define_second_stage_train"):
            # Stage two: optimize every trainable variable.
            second_stage_trainable_var_list = tf.trainable_variables()
            second_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(
                self.loss, var_list=second_stage_trainable_var_list)
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                with tf.control_dependencies([second_stage_optimizer, global_step_update]):
                    with tf.control_dependencies([moving_ave]):
                        self.train_op_with_all_variables = tf.no_op()

        with tf.name_scope('loader_and_saver'):
            # Restore every network variable except the detection heads (used to
            # start from COCO-pretrained weights). To resume one of this run's
            # own checkpoints with the heads included, use instead:
            #     self.loader = tf.train.Saver(self.net_var)
            variables_to_restore = [
                v for v in self.net_var
                if str(v.op.name).split('/')[0] not in self._HEAD_SCOPES]
            self.loader = tf.train.Saver(variables_to_restore)
            # Saves every global variable created so far (network weights,
            # global_step, optimizer slots, EMA shadows).
            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=300)

        with tf.name_scope('summary'):
            tf.summary.scalar("learn_rate", self.learn_rate)
            tf.summary.scalar("giou_loss", self.giou_loss)
            tf.summary.scalar("conf_loss", self.conf_loss)
            tf.summary.scalar("prob_loss", self.prob_loss)
            tf.summary.scalar("total_loss", self.loss)
            # makedirs also creates missing parents (e.g. ./train_log/); the
            # former os.mkdir raised when the parent did not exist.
            if not os.path.isdir(self.train_logdir):
                os.makedirs(self.train_logdir)
            self.write_op = tf.summary.merge_all()
            self.summary_writer = tf.summary.FileWriter(self.train_logdir, graph=self.sess.graph)

    @staticmethod
    def _parse_epoch_from_checkpoint(path):
        """Return the epoch encoded as a trailing ``-<digits>`` in *path*, else None.

        ``Saver.save(..., global_step=epoch)`` names checkpoints
        ``<prefix>-<epoch>``. Only a purely numeric suffix after the LAST '-'
        is accepted, so dashes elsewhere in the path are ignored.
        """
        _, sep, tail = str(path).rpartition('-')
        if sep and tail.isdigit():
            return int(tail)
        return None

    def _restore_initial_weight(self):
        """Restore ``self.initial_weight`` into the session via ``self.loader``.

        Returns:
            The number of epochs already completed: the ``-<epoch>`` suffix of
            the checkpoint path when present, otherwise 0.

        If the restore fails, training falls back to starting from scratch and
        the frozen first stage is disabled (``first_stage_epochs`` set to 0).
        """
        print('=> Restoring weights from: %s ... ' % self.initial_weight)
        try:
            self.loader.restore(self.sess, self.initial_weight)
        except (ValueError, tf.errors.OpError) as err:
            # Missing or incompatible checkpoint. Training from scratch needs no
            # pretrained weights, so the head-only stage is skipped.
            print('Load Pretrained Weight Failed !!!\n')
            print('=> %s does not exist !!!' % self.initial_weight)
            print('=> Error: %s' % err)
            print('=> Now it starts to train YOLOV3 from scratch ...')
            self.first_stage_epochs = 0
            return 0

        start_epoch = self._parse_epoch_from_checkpoint(self.initial_weight)
        if start_epoch is None:
            start_epoch = 0
        else:
            print('new start epoch: ', start_epoch)
            # global_step is not in the loader's var list (it is created after
            # self.net_var is captured), so it was just re-initialised to 1.0.
            # Fast-forward it so warmup/cosine decay continue where they stopped.
            # NOTE: with the default loader the detection heads are NOT restored;
            # switch the loader to self.net_var when resuming a run.
            self.sess.run(tf.assign(self.global_step, start_epoch * self.steps_per_epoch + 1.0))
        print('成功加载模型')
        print('\nLoad Model Success !!!\n')
        return start_epoch

    def _feed_dict(self, batch, training):
        """Map one Dataset batch onto the graph's input placeholders.

        ``batch`` is indexed 0..6 as (images, label_sbbox, label_mbbox,
        label_lbbox, sbboxes, mbboxes, lbboxes); ``training`` feeds the
        ``trainable`` flag passed to the model.
        """
        return {
            self.input_data: batch[0],
            self.label_sbbox: batch[1],
            self.label_mbbox: batch[2],
            self.label_lbbox: batch[3],
            self.true_sbboxes: batch[4],
            self.true_mbboxes: batch[5],
            self.true_lbboxes: batch[6],
            self.trainable: training,
        }

    def train(self):
        """Run all remaining epochs, evaluating and checkpointing after each one.

        Epochs ``1 .. first_stage_epochs`` use the frozen (head-only) train op;
        later epochs train all variables. After every epoch the mean test loss
        is computed and a checkpoint ``loss=<test_loss>.ckpt-<epoch>`` is saved
        in ``self.train_logdir``.
        """
        self.sess.run(tf.global_variables_initializer())
        start_epoch = self._restore_initial_weight()
        # Computed after the restore, which may zero first_stage_epochs.
        total_epochs = self.first_stage_epochs + self.second_stage_epochs

        # Epochs are 1-based; start right after the last completed epoch.
        for epoch in range(start_epoch + 1, total_epochs + 1):
            if epoch <= self.first_stage_epochs:
                train_op = self.train_op_with_frozen_variables
            else:
                print('==================================')
                train_op = self.train_op_with_all_variables

            # tqdm wraps the Dataset iterator; each iteration yields one batch.
            pbar = tqdm(self.trainset)
            train_epoch_loss, test_epoch_loss = [], []
            for train_data in pbar:
                _, summary, train_step_loss, global_step_val, learn_rate_ = self.sess.run(
                    [train_op, self.write_op, self.loss, self.global_step, self.learn_rate],
                    feed_dict=self._feed_dict(train_data, True))
                train_epoch_loss.append(train_step_loss)
                self.summary_writer.add_summary(summary, global_step_val)
                pbar.set_description("train loss: %.2f learning rate: %.5f global_step: %3d"
                                     % (train_step_loss, learn_rate_, global_step_val))

            # Evaluate the loss on the test set with trainable=False.
            for test_data in self.testset:
                test_step_loss = self.sess.run(self.loss, feed_dict=self._feed_dict(test_data, False))
                test_epoch_loss.append(test_step_loss)

            train_epoch_loss, test_epoch_loss = np.mean(train_epoch_loss), np.mean(test_epoch_loss)
            ckpt_file = os.path.join(self.train_logdir, "loss=%.4f.ckpt" % test_epoch_loss)
            log_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
            print("=> Epoch: %2d Time: %s Train loss: %.2f Test loss: %.2f Saving %s ..."
                  % (epoch, log_time, train_epoch_loss, test_epoch_loss, ckpt_file))
            # Written as "<ckpt_file>-<epoch>", which _parse_epoch_from_checkpoint reads back.
            self.saver.save(self.sess, ckpt_file, global_step=epoch)
if __name__ == '__main__':
    # Script entry point: building YoloTrain constructs the whole graph and
    # session; train() then runs every remaining epoch.
    yolo = YoloTrain()
    yolo.train()
代码总结
train.py 控制整个代码的整体执行逻辑,是整个程序执行的起点。代码中涉及一些训练 tricks,具体如下:
- 学习率:使用带有 warmup 的余弦衰减,训练初期先使用较小的学习率,使网络训练保持稳定。
- fine-tune:加载预训练权重(不包含 conv_sbbox/conv_mbbox/conv_lbbox 三个检测头),第一阶段只训练检测头,第二阶段训练全部参数。参考 https://blog.csdn.net/kxh123456/article/details/106452717 。
- ExponentialMovingAverage 的使用,参考 https://blog.csdn.net/kxh123456/article/details/106716320 。
更多推荐
已为社区贡献1条内容
所有评论(0)