文章首发及后续更新:https://mwhls.top/2294.html
新的更新内容请到mwhls.top查看。
无图/无目录/格式错误/更多相关请到上方的文章首发页面查看。

基于NAIE平台的YOLOv5识别超超超可爱的逢坂大河

主要步骤
  • 模型训练时,将训练好的模型使用joblib / pickle 库保存成pkl文件。
  • 训练结束后,将结果归档,归档按钮为图1右下角删除按键右边的按钮。
  • 模型验证时,创建一个新的验证服务,写入验证代码后,选择刚刚归档的模型,用pkl文件恢复训练好的模型,将验证集加入验证。
图1:模型训练中的模型归档
说明
  • 模型训练中,保存模型时如果出现UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte
    • 以二进制形式写入文件即可,即将 'w' 改成 'wb'
  • 模型验证中,导入模型时如果出现[Errno 2] No such file or directory: 'model/model.pkl'
    • 可以用os.system('ls')的方式查看工作目录下的所有文件,看下模型文件到底保存在哪。
    • 也可以在模型管理中把模型下载下来,然后将pkl模型文件上传至模型验证的工作目录中。
  • 模型验证结束的指标如何体现还不知道,按照示例做了也没有效果。
    • 直接print输出也能看到结果就是了。
    • 但好像不能下载下来。
模型训练代码修改
  • 改动不大,只需要将模型训练完成后的模型用joblib / pickle保存下来即可。
  • 但因为之前我的代码训练了两个模型,一个LDA一个QDA,所以删了LDA。
  • 下面是需要改动的地方,后面会附上完整代码
#原函数
def train_and_test(train_result, fetures_list, test_list):
    #调用sklearn进行训练及预测
    letter_x = np.array(fetures_list)
    letter_y = np.array(train_result)
    clf = QDA()
    clf.fit(letter_x, letter_y)
    
    return list(clf.predict(test_list))

#修改后
def train_and_test(train_result, fetures_list, test_list):
    #调用sklearn进行训练及预测
    letter_x = np.array(fetures_list)
    letter_y = np.array(train_result)
    clf = QDA()
    clf.fit(letter_x, letter_y)
    
            #!!增加了保存模型,以便验证时读取。
    save_model(clf)
    
    return list(clf.predict(test_list))

def save_model(model):
    with open(os.path.join(Context.get_output_path(), 'model22.pkl'), 'wb') as ff:
        pickle.dump(model, ff)
模型验证代码流程
  1. 调入验证集数据
    • 像训练时一样即可。
  2. 使用pkl文件导入训练好的模型
    1. 官方教程的硬盘检测提供了一个方式,其自带的模板也提供了一个方式。
      • 改了多遍,都不能用。
    2. 最不用费脑子的方法,是将训练好的模型归档,然后下载下来,把里面的pkl文件拿出来,然后塞进模型验证代码所在的文件夹里,再调用它。
      • 这里本来写了很多吐槽的东西,但是还是算了,内容太长影响观感。
    3. 最后使用的方式是下方的代码
      • 是我第二天起床的时候想到的,可以用ls看一下文件夹情况,最后发现model.pkl实际上是在当前文件夹,而不是在model文件夹里。
      • 淦。
  3. 使用模型验证验证集
    • 像训练时一样即可。
model_path =  './model22.pkl'
with open(model_path, 'rb') as modelf:
    model = pickle.load(modelf)
模型训练代码
  • main.py
# -*- coding: utf-8 -*-
from __future__ import print_function  # do not delete this line if you want to save your log file.

import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA, QuadraticDiscriminantAnalysis as QDA
from naie.datasets import get_data_reference
from naie.metrics import report
import pickle
from naie.context import Context
import os
import joblib

'''
    代码框架见QDA.py,大致是 加载数据,模型训练,测试集预测,日志上传四个部分。
    补上了日志上传的代码,即score_model()
    补上了模型训练部分中模型保存的代码,即save_model()
'''


def get_list_letter(char, features, path):
    #数据集处理部分
    with open(path, 'r') as letter_train:
        content_letter_train = letter_train.readlines()
    for line in content_letter_train:
        temp = line.split(',')
        temp[-1] = list(temp[-1])[0]
        char.append(temp[0])
        features.append((temp[1::]))
        
        
def convert_int(str_list):
    #数据集处理部分
    for row in range(0, len(str_list)):
        for col in range(0, len(str_list[row])):
            str_list[row][col] = int(str_list[row][col])


def train_and_test(train_result, fetures_list, test_list):
    #调用sklearn进行训练及预测
    letter_x = np.array(fetures_list)
    letter_y = np.array(train_result)
    clf = QDA()
    clf.fit(letter_x, letter_y)
    
            #!!增加了保存模型,以便验证时读取。
    save_model(clf)
    
    return list(clf.predict(test_list))


def analysis_accuracy(judge_result, test_char):
    #分析预测精确度
    sum = 0
    right_num = 0
    for pos in range(0, len(judge_result)):
        sum += 1
        if judge_result[pos] == test_char[pos]:
            right_num += 1
    return right_num / sum


def score_model(accuracy):
            #    !!日志生成函数
    #NAIE官方模板,用于给出报告,下面的注释就是示例。
    
    with report(True) as log_report:
        log_report.log_property('Accuracy', accuracy)
        
    """
    ====================== YOUR CODE HERE ======================
    there are several api you can use here.
    Example:
    from naie.metrics import report
    with report(True) as log_report:
        log_report.log_property("score", accuracy_score(y_validation, model.predict(x_validation)))
    ============================================================
    """
    return accuracy


def save_model(model):
            #    !!保存模型函数,用pickle/joblib保存模型,方便模型验证时使用。
    """
    ====================== YOUR CODE HERE ======================
    write model to the specific model path of train job

    e.g.
    from naie.context import Context
    with open(os.path.join(Context.get_output_path(), 'model.pkl'), 'w') as ff:
        pickle.dump(clf, ff)
    or
    tf.estimator.Estimator(model_dir=Context.get_output_path())  # using tensorflow Estimator
    ============================================================
    """
    with open(os.path.join(Context.get_output_path(), 'model22.pkl'), 'wb') as ff:
        pickle.dump(model, ff)


def main():
    
    # letter数据集初始化
            #    !!这里的路径获取改了
    letter_train_path = get_data_reference(dataset="Default", dataset_entity="letter_train").get_files_paths()[0]
    letter_train_class = []
    letter_train_features = []
    letter_test_path = get_data_reference(dataset="Default", dataset_entity="letter_test").get_files_paths()[0]
    letter_test_class = []
    letter_test_features = []
    get_list_letter(letter_train_class, letter_train_features, letter_train_path)
    get_list_letter(letter_test_class, letter_test_features, letter_test_path)
    convert_int(letter_train_features)
    convert_int(letter_test_features)
    
    # Letter数据集学习
    letter_QDA_judge_result = train_and_test(letter_train_class, letter_train_features, letter_test_features)
    letter_QDA_judge_accuracy = analysis_accuracy(letter_QDA_judge_result, letter_test_class)
    print('使用QDA对letter的', len(letter_train_features), '份数据学习后,对',
        len(letter_test_features), '份测试数据分类的准确率为:', letter_QDA_judge_accuracy)
    
        #    !!增加了日志生成
    score = score_model(letter_QDA_judge_accuracy)
    
    return score


if __name__ == "__main__":
    main()
  • requirements.txt
    • 但似乎只要有naie就好,其它自带了。
#name [condition] [version]
#condition    ==, >=, <=, >, <
#tensorflow==1.8.1
naie
scikit-learn
numpy
pickle
joblib
模型验证代码
  • really_validation.py
    • 注:requirements.txt不用改,这些库似乎都自带了。
# -*- coding: utf-8 -*-
from __future__ import print_function
from naie.context import Context
from naie.datasets import get_data_reference
from naie.metrics import report
import os
import pickle


def get_list_letter(char, features, path):
    #数据集处理部分
    with open(path, 'r') as letter_train:
        content_letter_train = letter_train.readlines()
    for line in content_letter_train:
        temp = line.split(',')
        temp[-1] = list(temp[-1])[0]
        char.append(temp[0])
        features.append((temp[1::]))
        
        
def convert_int(str_list):
    #数据集处理部分
    for row in range(0, len(str_list)):
        for col in range(0, len(str_list[row])):
            str_list[row][col] = int(str_list[row][col])

            
def analysis_accuracy(judge_result, test_char):
    #分析预测精确度
    sum = 0
    right_num = 0
    for pos in range(0, len(judge_result)):
        sum += 1
        if judge_result[pos] == test_char[pos]:
            right_num += 1
    return right_num / sum
            
            
def model_validation():
    #验证集处理(实际上用的是测试集,因为没有验证集)
    letter_test_path = get_data_reference(dataset="Default", dataset_entity="letter_test").get_files_paths()[0]
    letter_test_class = []
    letter_test_features = []
    get_list_letter(letter_test_class, letter_test_features, letter_test_path)
    convert_int(letter_test_features)
    
    
    #有很多想说的话,但是还是算了,做文明人。
    #恢复模型,并用其验证
    model_path =  './model22.pkl'
    with open(model_path, 'rb') as modelf:
        model = pickle.load(modelf)
        judge_result = model.predict(letter_test_features)
        
        acc = analysis_accuracy(judge_result, letter_test_class)
        with report(True) as logs:
            print('log.')
            logs.log_property("acc", acc)
    
    print("accuracy: ", ' - ', acc, type(acc), "")
    if(acc > 0.7):
        print('success')
        return 1
    else:
        print('fail')
        return 0
Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐