代码展示

import os
import time
import numpy as np
import tensorflow as tf
import core.utils as utils
from tqdm import tqdm
from core.dataset import Dataset
from core.yolov3 import YOLOV3
from core.config import cfg


class YoloTrain(object):
    """Build the TensorFlow 1.x training graph for YOLOv3 and run training.

    Training is split into two stages that share one learning-rate schedule:

    * first stage  -- only the variables under the ``conv_sbbox``,
      ``conv_mbbox`` and ``conv_lbbox`` scopes are optimised;
    * second stage -- every trainable variable is optimised.

    After each epoch the loss over the test set is computed and a checkpoint
    named ``loss=<test loss>.ckpt-<epoch>`` is written to ``train_logdir``.
    """

    # Top-level variable scopes of the three output heads. They are the only
    # variables trained in the first stage and the only ones *not* restored
    # from the initial checkpoint.
    _HEAD_SCOPES = ('conv_sbbox', 'conv_mbbox', 'conv_lbbox')

    def __init__(self):
        """Read the configuration, create the session and build the graph."""
        # YOLOv3 uses three anchors per grid cell at each of three scales,
        # nine anchors in total.
        self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
        self.classes = utils.read_class_names(cfg.YOLO.CLASSES)
        self.num_classes = len(self.classes)

        self.learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT
        self.learn_rate_end = cfg.TRAIN.LEARN_RATE_END
        # The key is spelled "FISRT" here; presumably it matches core.config
        # -- verify before renaming.
        self.first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS
        self.second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS
        self.warmup_epoch = cfg.TRAIN.WARMUP_EPOCHS

        self.initial_weight = cfg.TRAIN.INITIAL_WEIGHT
        self.time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        self.moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY
        self.max_bbox_per_scale = 150

        self.train_logdir = "./train_log/mydata_test1/"
        self.trainset = Dataset('train')
        self.testset = Dataset('test')
        # len(Dataset) is used as the number of optimisation steps per epoch.
        self.steps_per_epoch = len(self.trainset)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)

        with tf.name_scope('define_input'):
            self.input_data = tf.placeholder(dtype=tf.float32, name='input_data')
            self.label_sbbox = tf.placeholder(dtype=tf.float32, name='label_sbbox')
            self.label_mbbox = tf.placeholder(dtype=tf.float32, name='label_mbbox')
            self.label_lbbox = tf.placeholder(dtype=tf.float32, name='label_lbbox')
            self.true_sbboxes = tf.placeholder(dtype=tf.float32, name='sbboxes')
            self.true_mbboxes = tf.placeholder(dtype=tf.float32, name='mbboxes')
            self.true_lbboxes = tf.placeholder(dtype=tf.float32, name='lbboxes')
            self.trainable = tf.placeholder(dtype=tf.bool, name='training')

        with tf.name_scope("define_loss"):
            self.model = YOLOV3(self.input_data, self.trainable)

            # Snapshot taken right after the network is built, i.e. before the
            # global step, optimizer slots and moving averages are created.
            self.net_var = tf.global_variables()
            self.giou_loss, self.conf_loss, self.prob_loss = self.model.compute_loss(
                self.label_sbbox, self.label_mbbox, self.label_lbbox,
                self.true_sbboxes, self.true_mbboxes, self.true_lbboxes)

            self.loss = self.giou_loss + self.conf_loss + self.prob_loss

        with tf.name_scope('learn_rate'):
            self.global_step = tf.Variable(1.0, dtype=tf.float64, trainable=False, name='global_step')
            warmup_steps = tf.constant(self.warmup_epoch * self.steps_per_epoch,
                                       dtype=tf.float64,
                                       name='warmup_steps')

            # Total number of training steps over both stages.
            train_steps = tf.constant(
                (self.first_stage_epochs + self.second_stage_epochs) * self.steps_per_epoch,
                dtype=tf.float64,
                name='train_steps')

            def _warmup_rate():
                # Linear ramp from 0 up to learn_rate_init.
                return self.global_step / warmup_steps * self.learn_rate_init

            def _cosine_rate():
                # Cosine decay from learn_rate_init down to learn_rate_end.
                progress = (self.global_step - warmup_steps) / (train_steps - warmup_steps)
                return self.learn_rate_end + 0.5 * (
                    self.learn_rate_init - self.learn_rate_end) * (1 + tf.cos(progress * np.pi))

            # Learning rate: warmup followed by cosine decay.
            self.learn_rate = tf.cond(pred=self.global_step < warmup_steps,
                                      true_fn=_warmup_rate,
                                      false_fn=_cosine_rate)
            # Op that advances global_step by one.
            global_step_update = tf.assign_add(self.global_step, 1.0)

        # Exponential moving average of the trainable variables, intended to
        # stabilise training and later evaluation.
        with tf.name_scope("define_weight_decay"):
            moving_ave = tf.train.ExponentialMovingAverage(
                self.moving_ave_decay).apply(tf.trainable_variables())

        # First stage: optimise the output heads only, everything else frozen.
        with tf.name_scope("define_first_stage_train"):
            self.first_stage_trainable_var_list = [
                var for var in tf.trainable_variables()
                if var.op.name.split('/')[0] in self._HEAD_SCOPES
            ]

            print('warm learning rate: ', self.learn_rate)
            first_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(
                self.loss, var_list=self.first_stage_trainable_var_list)

            # Control dependencies fix the execution order: UPDATE_OPS, then
            # the optimizer step and the global-step increment, then the
            # moving-average update.
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                with tf.control_dependencies([first_stage_optimizer, global_step_update]):
                    with tf.control_dependencies([moving_ave]):
                        self.train_op_with_frozen_variables = tf.no_op()

        # Second stage: optimise every trainable variable.
        with tf.name_scope("define_second_stage_train"):
            second_stage_trainable_var_list = tf.trainable_variables()
            second_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(
                self.loss, var_list=second_stage_trainable_var_list)

            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                with tf.control_dependencies([second_stage_optimizer, global_step_update]):
                    with tf.control_dependencies([moving_ave]):
                        self.train_op_with_all_variables = tf.no_op()

        with tf.name_scope('loader_and_saver'):
            # Restore every network variable except the output heads; this is
            # the setting used to load COCO-pretrained weights.
            variables_to_restore = [
                v for v in self.net_var
                if v.name.split('/')[0] not in self._HEAD_SCOPES
            ]

            # To restore all network variables instead (e.g. when continuing
            # second-stage training), build the loader from self.net_var.
            self.loader = tf.train.Saver(variables_to_restore)
            # Saver used for the per-epoch checkpoints.
            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=300)

        with tf.name_scope('summary'):
            tf.summary.scalar("learn_rate", self.learn_rate)
            tf.summary.scalar("giou_loss", self.giou_loss)
            tf.summary.scalar("conf_loss", self.conf_loss)
            tf.summary.scalar("prob_loss", self.prob_loss)
            tf.summary.scalar("total_loss", self.loss)

            # makedirs (not mkdir): the default log dir is nested, and its
            # parent may not exist yet. Existing contents are kept.
            if not os.path.isdir(self.train_logdir):
                os.makedirs(self.train_logdir)

            self.write_op = tf.summary.merge_all()
            self.summary_writer = tf.summary.FileWriter(self.train_logdir, graph=self.sess.graph)

    @staticmethod
    def _epoch_from_checkpoint(checkpoint_path):
        """Return the epoch encoded as a trailing ``-<int>`` suffix, else 0.

        Checkpoints written by ``train`` end in ``.ckpt-<epoch>``. Only the
        text after the *last* hyphen is inspected, so hyphens in directory
        names do not interfere.
        """
        _, sep, suffix = str(checkpoint_path).rpartition('-')
        if not sep:
            return 0
        try:
            return int(suffix)
        except ValueError:
            return 0

    def _feed_dict(self, batch, is_training):
        """Map one dataset batch (a 7-item sequence) onto the placeholders."""
        return {
            self.input_data: batch[0],
            self.label_sbbox: batch[1],
            self.label_mbbox: batch[2],
            self.label_lbbox: batch[3],
            self.true_sbboxes: batch[4],
            self.true_mbboxes: batch[5],
            self.true_lbboxes: batch[6],
            self.trainable: is_training,
        }

    def train(self):
        """Run the two-stage training loop, checkpointing after every epoch.

        Weights are first restored from ``self.initial_weight``. If that
        fails, training falls back to training all variables from scratch
        (the frozen first stage is skipped). If the checkpoint path ends in
        ``-<epoch>``, training resumes at the following epoch.
        """
        self.sess.run(tf.global_variables_initializer())

        # Number of epochs already finished; 0 means a fresh run that starts
        # at epoch 1.
        completed_epochs = 0
        print('=> Restoring weights from: %s ... ' % self.initial_weight)
        try:
            # Keep only the restore inside the try so that unrelated errors
            # are not misreported as a missing checkpoint.
            self.loader.restore(self.sess, self.initial_weight)
        except (tf.errors.OpError, ValueError) as err:
            print('Load Pretrained Weight Failed !!!\n')
            print('=> %s does not exist !!!' % self.initial_weight)
            print('=> Reason: %s' % err)
            print('=> Now it starts to train YOLOV3 from scratch ...')
            # Without pretrained weights there is nothing worth freezing.
            self.first_stage_epochs = 0
        else:
            completed_epochs = self._epoch_from_checkpoint(self.initial_weight)
            if completed_epochs:
                print('new start epoch: ', completed_epochs + 1)
            print('成功加载模型')
            print('\nLoad Model Success !!!\n')

        total_epochs = self.first_stage_epochs + self.second_stage_epochs
        for epoch in range(completed_epochs + 1, total_epochs + 1):
            # Epochs 1..first_stage_epochs (inclusive) train the heads only.
            if epoch <= self.first_stage_epochs:
                train_op = self.train_op_with_frozen_variables
            else:
                print('==================================')
                train_op = self.train_op_with_all_variables

            # tqdm wraps the dataset iterable; each iteration yields one batch.
            pbar = tqdm(self.trainset)
            train_epoch_loss, test_epoch_loss = [], []

            for train_data in pbar:
                _, summary, train_step_loss, global_step_val, learn_rate_ = self.sess.run(
                    [train_op, self.write_op, self.loss, self.global_step, self.learn_rate],
                    feed_dict=self._feed_dict(train_data, True))

                train_epoch_loss.append(train_step_loss)
                self.summary_writer.add_summary(summary, global_step_val)
                pbar.set_description("train loss: %.2f learning rate: %.5f global_step: %3d" % (
                    train_step_loss, learn_rate_, global_step_val))

            # Evaluate the loss on the test set (no training op is run).
            for test_data in self.testset:
                test_step_loss = self.sess.run(
                    self.loss, feed_dict=self._feed_dict(test_data, False))
                test_epoch_loss.append(test_step_loss)

            train_epoch_loss, test_epoch_loss = np.mean(train_epoch_loss), np.mean(test_epoch_loss)
            ckpt_file = self.train_logdir + "loss=%.4f.ckpt" % test_epoch_loss
            log_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
            print("=> Epoch: %2d Time: %s Train loss: %.2f Test loss: %.2f Saving %s ..."
                  % (epoch, log_time, train_epoch_loss, test_epoch_loss, ckpt_file))
            # global_step=epoch appends "-<epoch>" to the checkpoint name,
            # which _epoch_from_checkpoint reads back when resuming.
            self.saver.save(self.sess, ckpt_file, global_step=epoch)


if __name__ == "__main__":
    # Script entry point: constructing YoloTrain builds the whole graph and
    # session; train() then runs the epoch loop and writes checkpoints.
    yolo = YoloTrain()
    yolo.train()

代码总结

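train.py 控制整个项目的执行流程,是程序运行的入口。
代码中用到了一些训练技巧(tricks),具体如下:
- 带 warmup 的余弦衰减学习率;
- 对可训练变量做指数滑动平均(EMA);
- 两阶段训练:第一阶段只训练 conv_sbbox / conv_mbbox / conv_lbbox 三个输出层,第二阶段训练全部参数;
- 加载预训练权重时跳过上述三个输出层。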

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐