TensorFlow: Implementing Model Fine-Tuning (finetune)

A while ago I finished training a classification model, and during later testing I found that it failed to classify certain scenes and angles correctly. The usual response is to collect more data for those scenes and angles and retrain the model. With a small dataset, retraining finishes quickly; with a very large one (my training set has 6M+ images), a full retrain takes a long time and drags out the project schedule. Retraining from scratch also does not guarantee that the newly added data will be classified well. In this situation, fine-tuning can greatly cut training time while retaining the good performance of the previous model.

Fine-tuning strategy: pick a previously trained model that tested well. Set the fine-tuning learning rate to the value the pretrained model had shortly before it converged, load the pretrained model, and train for just 1-2 epochs. For example: if the pretrained model converged when its learning rate had decayed to 0.001, set the fine-tuning learning rate to 0.005.
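As a minimal sketch of that strategy (the Config fields match the fine-tuning script below; the concrete values are just the ones from my example), only two settings change before the run:

# illustrative fine-tune settings (values are assumptions from the example above)
Config.learning_rate = 0.005   # the pretrained model converged near lr = 0.001
Config.NUM_EPOCH = 2           # 1-2 epochs over the merged data are enough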

My pretrained checkpoint files look like this (directory screenshot omitted):

The fine-tuning code is as follows (where checkpoint_path = './ckpt/step_291800_loss_0.00985_acc_1.00000'):

import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.quantize import create_training_graph

# Config, Model, get_record_dataset, processing_image and gpuConfig are
# project-specific helpers defined elsewhere in my repo.

def preTrain(checkpoint_path):
    model_path_restore = checkpoint_path + '.ckpt'
    dataset = get_record_dataset(record_path=Config.record_path,
                                 num_samples=Config.num_samples,
                                 num_classes=Config.num_classes)
    # slim's DatasetDataProvider yields single decoded examples from the TFRecord
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])
    image, label = processing_image(image, label)
    # queue-based batching (TF 1.x input pipeline)
    images, labels = tf.train.batch([image, label], batch_size=Config.BATCH_SIZE,
                                    num_threads=1, capacity=5)

    # forward pass; Model returns the class logits
    logits = Model(images, is_training=True, num_classes=Config.num_classes)

    # labels must be one-hot here, both for the cross-entropy op and the argmax below
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
    loss = tf.reduce_mean(cross_entropy)

    correct_prediction = tf.equal(tf.argmax(labels, 1), tf.argmax(logits, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # quantization-aware training: insert fake-quant ops into the graph,
    # activating them after 40000 steps (tf.contrib.quantize.create_training_graph)
    graph = tf.get_default_graph()
    create_training_graph(graph, 40000)

    global_step = tf.Variable(0, trainable=False)
    # exponential decay; dividing global_step by 40 stretches out the decay schedule
    lr = tf.train.exponential_decay(learning_rate=Config.learning_rate,
                                    global_step=tf.cast(tf.div(global_step, 40), tf.int32),
                                    decay_steps=Config.decay_steps,
                                    decay_rate=Config.decay_rate,
                                    staircase=True)

    # make sure batch-norm moving-average updates (UPDATE_OPS) run with each train step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
            loss, global_step=global_step)

    global_init = tf.global_variables_initializer()
    total_step = Config.NUM_EPOCH * Config.num_samples // Config.BATCH_SIZE

    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=12)

    with tf.Session(config=gpuConfig) as sess:
        sess.run(global_init)
        # load the pretrained model weights (overwrites the fresh initialization)
        saver.restore(sess, model_path_restore)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            for step in range(1, total_step + 1):
                if coord.should_stop():
                    break

                _, loss_val, accuracy_val, lr_val, global_step_val = sess.run(
                    [train_op, loss, accuracy, lr, global_step])
                print('global_step:', global_step_val)
                print("step: %d, lr = %.5f, loss = %.5f, accuracy = %.5f" %
                      (step, lr_val, loss_val, accuracy_val))
                # checkpointing
                if step == 1:
                    # dump the graph definition once, for later freezing
                    tf.train.write_graph(sess.graph, Config.pb_path, "handcnet.pb")
                if step % 200 == 0:
                    save_path = Config.ckpt_path + "step_%d_loss_%.5f_acc_%.5f.ckpt" % (
                        step, loss_val, accuracy_val)
                    saver.save(sess, save_path)
                    print('Saved to', save_path)

        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()

        coord.join(threads)

        # final checkpoint after all epochs
        saver.save(sess, Config.ckpt_path + "completed_model.ckpt")

        tf.train.write_graph(sess.graph, Config.pb_path, "handcnet.pb")

        print("train completed!")

        sess.close()
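Calling the function is then a one-liner; note that preTrain appends the '.ckpt' suffix itself, so the argument is the checkpoint prefix:

checkpoint_path = './ckpt/step_291800_loss_0.00985_acc_1.00000'
preTrain(checkpoint_path)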

This TensorFlow finetune was done for my own classification model. Retraining from scratch every time the data is updated would take several days (the dataset is large). With this approach, whenever new data arrives I preprocess it, shuffle it together with the existing data, load the pretrained model, and 1-2 epochs are enough to finish; in testing, the resulting model also fits the newly added data better.
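For the "shuffle the new data together with the old" step, here is a minimal sketch assuming both datasets are single TFRecord files small enough to hold in memory (the file names are hypothetical; with 6M+ images you would shuffle file lists or use interleaved reads instead):

import random
import tensorflow as tf

def merge_and_shuffle(old_record, new_record, merged_record):
    # collect every serialized example from both TFRecord files
    examples = []
    for path in (old_record, new_record):
        for serialized in tf.python_io.tf_record_iterator(path):
            examples.append(serialized)
    # mix old and new samples, then rewrite them as one training file
    random.shuffle(examples)
    with tf.python_io.TFRecordWriter(merged_record) as writer:
        for serialized in examples:
            writer.write(serialized)

merge_and_shuffle('train_old.tfrecord', 'train_new.tfrecord', 'train_merged.tfrecord')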