目前官方训练完成后自带的 predict.py 只能识别图片,不能识别视频。

多次百度、踩过各种坑之后,今天终于写成!本文使用的 opencv-python 版本是 4.5.5.64,比较新,所以不同版本的语法可能会不一样。请注意根据自己的版本修改,或者直接升级 opencv-python。

# 脚本运行依赖paddlex
# pip install paddlex
import cv2
import os
import paddlex as pdx
import argparse
import json


def pred(args):
    """Run the exported PaddleX model on one image or on sampled video frames.

    :param args: parsed options; this function reads ``use_gpu``,
        ``model_dir``, ``img_file``, ``video_file``, ``fps`` and ``save_dir``
    :return: the value returned by the last successful ``predictor.predict``
        call, or ``None`` when no prediction succeeded (or no input was given)
    """
    if not args.use_gpu:
        predictor = pdx.deploy.Predictor(args.model_dir)
    else:
        predictor = pdx.deploy.Predictor(args.model_dir, use_gpu=True)

    # Bound up front: previously a failed prediction left ``result`` unassigned
    # and the final ``return`` raised UnboundLocalError.
    result = None

    if args.img_file is not None:
        try:
            result = predictor.predict(img_file=args.img_file)
            print("result:" + json.dumps(result))
            pdx.det.visualize(args.img_file, result, threshold=0.5, save_dir=args.save_dir)
        except Exception as e:
            # Best-effort: prediction sometimes fails (e.g. "list index out of
            # range" when nothing is recognised), so report and carry on.
            print(e)
    elif args.video_file is not None:
        images, frame_height, frame_width = video_to_image(args.video_file, args.fps)

        # Make sure the target directory exists before the writer opens the file.
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
        # os.path.join works whether or not save_dir ends with a separator.
        out_path = os.path.join(args.save_dir, getfilename(args.video_file))

        fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
        # VideoWriter expects the frame size as (width, height).
        video_writer = cv2.VideoWriter(
            out_path, fourcc, int(args.fps),
            (int(frame_width), int(frame_height)), True)
        try:
            for image in images:
                try:
                    result = predictor.predict(img_file=image)
                    print("result:" + json.dumps(result))
                    image_res = pdx.det.visualize(image, result, threshold=0.5, save_dir=None)

                    video_writer.write(image_res)
                except Exception as e:
                    # Keep the video continuous: fall back to the raw frame
                    # when this frame could not be predicted or annotated.
                    video_writer.write(image)
                    print(e)
        finally:
            # Always finalise the output file, even if the loop aborts.
            video_writer.release()
    return result

def video_to_image(video_path, nfps):
    """Decode a video and keep a subsample of its frames in memory.

    :param video_path: path of the video file to read
    :param nfps: number of frames to keep per second of source video
    :return: tuple ``(images, frame_height, frame_width)`` where ``images`` is
        the list of kept frames and the two sizes are the values reported by
        the capture's FRAME_HEIGHT / FRAME_WIDTH properties
    :raises ValueError: if ``nfps`` is not positive
    :raises IOError: if the video cannot be opened
    """
    if nfps <= 0:
        raise ValueError("nfps must be positive, got {0}".format(nfps))

    # Open the video file
    camera = cv2.VideoCapture(video_path)
    try:
        if not camera.isOpened():
            raise IOError("cannot open video: {0}".format(video_path))

        # Frame rate (frames per second) of the source video
        fps = camera.get(cv2.CAP_PROP_FPS)
        # Total number of frames
        frames = camera.get(cv2.CAP_PROP_FRAME_COUNT)
        # Video height
        frame_height = camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
        # Video width
        frame_width = camera.get(cv2.CAP_PROP_FRAME_WIDTH)
        # Keep every Nth frame. Clamp to 1 so that asking for more frames per
        # second than the source has (or a reported fps of 0) keeps every
        # frame instead of raising ZeroDivisionError in the modulo below.
        frame_frequency = max(1, int(fps / nfps))

        print("帧数:"+str(fps))
        print("总帧数:"+str(frames))
        if fps > 0:
            # The duration is only computable when the frame rate is known.
            print("视屏总时长:"+"{0:.2f}".format(frames/fps)+"秒")

        images = []
        # 1-based index of the frame currently being read
        times = 0
        while True:
            times += 1
            res, image = camera.read()
            if not res:
                break
            if times % frame_frequency == 0:
                images.append(image)
    finally:
        # Release the capture handle on every exit path.
        camera.release()
    return images, frame_height, frame_width

def getfilename(path):
    """Return the last component of *path*, e.g. ``1.mp4``."""
    _directory, tail = os.path.split(path)
    return tail

def parse_args():
    """Parse the command-line options of the prediction script.

    :return: ``argparse.Namespace`` with ``use_gpu``, ``model_dir``,
        ``gpu_id``, ``img_file``, ``video_file``, ``fps`` and ``save_dir``
    :raises SystemExit: with status 1 when neither an image nor a video is given
    """
    def _str_to_bool(value):
        # ``type=bool`` would turn every non-empty string (even "False") into
        # True, so the textual value is parsed explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {0!r}'.format(value))

    parser = argparse.ArgumentParser()
    # ``--use_gpu`` alone enables the GPU; ``--use_gpu false`` disables it.
    parser.add_argument('--use_gpu', default=False, type=_str_to_bool,
                        nargs='?', const=True)
    parser.add_argument('--model_dir', default='./inference_model', type=str)
    # The flag used to carry a trailing space, which hid it from args.gpu_id.
    parser.add_argument('--gpu_id', default=0, type=int)
    parser.add_argument('--img_file', default=None, type=str)
    parser.add_argument('--video_file', default=None, type=str)
    # FPS of the output video; a small value keeps recognition time reasonable.
    parser.add_argument('--fps', default=5, type=int)
    parser.add_argument('--save_dir', default='./output/', type=str)
    args = parser.parse_args()

    # Convenience for pressing F5 in VS Code: fall back to the local sample
    # video only when no input was given and the file actually exists, so that
    # an explicit --img_file / --video_file is never overwritten.
    debug_video = "C:\\Users\\hurui\\Desktop\\1.mp4"
    if (args.img_file is None and args.video_file is None
            and os.path.isfile(debug_video)):
        args.video_file = debug_video

    if args.img_file is None and args.video_file is None:
        parser.print_usage()
        # Same exit status as sys.exit(1), without needing ``sys``, which this
        # file never imports.
        raise SystemExit(1)

    return args

if __name__ == '__main__':
    # Parse the options, run the prediction, then report whether it produced
    # any result.
    outcome = pred(parse_args())
    print('Does not any objects!' if outcome is None else 'Done!')

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐