目标检测评估指标mAP的计算-python

目标检测性能评估指标mAP介绍

为解决不同场景下对目标检测的位置偏差的需求不同,通常给定一个IOU阈值,超过此阈值则视为检测成功。以及考虑到类别平衡的问题,通常分别求每一个类别的性能,再进行类别间求平均。

那么给定一个IOU阈值以及一个特定的类别,如何求这个类别的AP值呢?

首先对所有的检测结果排序,得分越高越靠前,然后依次判断检测是否成功。先求出检测结果dets和真实目标gts的iou值,并找出每一个det有最大iou值的gts,若这个最大的iou超过了一定的阈值,则det对这个gt匹配,视为TP,注意每一个gt只能被匹配一次。若某个det对所有的gt的iou都没有超过阈值,或者最大iou超过阈值的gt已经被匹配,则视为FP。得分从高到低遍历完所有的det之后,若仍有gt没有匹配,则视为FN。

得到TP、FP、FN之后,通过precision=TP/(TP+FP)和recall=TP/(TP+FN)可以计算出精确率和召回率。但是不同任务对于精确度和召回率的要求不一样,有些任务要求更高的recall,“错检”是可以接受的,而有些任务要求更高的precision,“漏检”是可以接受的。为了对recall和precision进行整体的评估,我们选取排好序的det前n个(n=1,2,3…),从而得到一条recall-precision曲线图,通过计算这条曲线下的面积得到AP值。

对每个类别的AP值取平均即可得到mAP。

具体的计算步骤:
  1. 首先将每一张图片对应的某个类别的预测值和标注信息提取出来。
  2. 对于每一张图片求其TP和FP。将bboxes按照分数进行排序,然后计算bboxes与gts之间的iou值,找出与bbox有最大iou值的gt框,从分数最高的bbox开始,若bbox与gt的最大iou值大于阈值,则将gt与这个bbox匹配,tp[gt]=1,(注意,每一个gt框只能被一个bbox匹配)。其余情况下,fp[gt]=1。
  3. 得到所有图片的某一类别的tp和fp(由0,1组成的数组,对应每个bbox是tp还是fp)合并成一个数组,并将其按照bbox的分数从大到小进行排序,再将排序后的tp和fp进行累加得到中间结果数组(保存了限制det的不同数量下的tp和fp数量)。
  4. 根据recalls=tp/gts_num、precisions=tp/(tp+fp)计算出召回率和精确率的数组(不同dets数量的限制)。
  5. 使用11points(11点插值法)或者area方法求AP值。
  6. 将计算得到的各个类别的AP值取平均得到mAP值。
详细代码
import numpy as np
from multiprocessing import Pool
def eval_map(det_results, annotations, iou_thr=0.5, nproc=4):
    '''
    Params:
        det_results(List[List]): [[cls1_det, cls2_det, ...], ...].      #包含每一张图片每一个类别的预测结果
        annotations(List[dict]):                                        #包含每一张图片的数据标注
            dict = {
                'bboxes':(ndarray) shape(n,4)
                'labels':(ndarray) shape(n)
            }
    '''
    assert len(det_results) == len(annotations)
    
    num_imgs = len(det_results)
    num_scales = 1
    num_classes = len(det_results[0])
    area_ranges = None

    pool = Pool(nproc)
    eval_results = []
    for i in range(num_classes):
        cls_dets, cls_gts = get_cls_results(det_results, annotations, i)
        tpfp = pool.starmap(
            get_tpfp, zip(cls_dets, cls_gts, [iou_thr for _ in range(num_imgs)])
        )
        tp,fp = tuple(zip(*tpfp))
        num_gts = np.zeros(num_scales, dtype=int)
        for j, bbox in enumerate(cls_gts):
            num_gts[0] += bbox.shape[0]
        cls_dets = np.vstack(cls_dets)
        num_dets = cls_dets.shape[0]
        sort_inds = np.argsort(-cls_dets[:, -1])
        tp = np.hstack(tp)[sort_inds]               #按分数从大到小排序
        fp = np.hstack(fp)[sort_inds]

        tp = np.cumsum(tp)                          #数组累加
        fp = np.cumsum(fp)
        eps = np.finfo(np.float32).eps
        recalls = tp/np.maximum(num_gts[0], eps)
        precisions = tp/np.maximum((tp+fp), eps)
        num_gts = num_gts.item()
        #print(recalls, precisions)
        ap = average_precisions(recalls, precisions, mode='11points')
        eval_results.append({
            'num_gts': num_gts,
            'num_dets': num_dets,
            'recall': recalls,
            'precision': precisions,
            'ap': ap 
        })
    pool.close()
    aps = []
    for cls_result in eval_results:
        if cls_result['num_gts'] > 0:
            aps.append(cls_result['ap'])
    mean_ap = np.array(aps).mean().item() if aps else 0.0

    return mean_ap, eval_results

def get_cls_results(det_results, annotations, class_id):
    '''
    Params:
        det_results(List[List]):[[cls1_det, cls2_det, ...], ...].
        annotations(List[dict])
    Return:

    '''
    cls_dets = [img_res[class_id] for img_res in det_results]
    cls_gts = []
    for ann in annotations:
        gt_inds = ann['labels'] == class_id+1
        cls_gts.append(ann['bboxes'][gt_inds, :])
    return cls_dets, cls_gts

def get_tpfp(det_bboxes, gt_bboxes, iou_thr=0.5):           #每一张图片的tpfp
    '''
    Params:
        det_bboxes(ndarray): shape(m,5)                     #前4个是坐标,最后1个是分数
        gt_bboxes(ndarray): shape(n,4)
    '''
    num_dets = det_bboxes.shape[0]
    num_gts = gt_bboxes.shape[0]
    
    tp = np.zeros((num_dets), dtype=np.float32)
    fp = np.zeros((num_dets), dtype=np.float32)

    if gt_bboxes.shape[0] == 0:
        fp[...] = 1
        return tp, fp

    ious = bbox_overlaps(det_bboxes[:,:-1], gt_bboxes)
    ious_max = ious.max(axis=1)                 #找出每个预测框有最大IoU值的真实框
    ious_argmax = ious.argmax(axis=1)
    sort_inds = np.argsort(-det_bboxes[:,-1])
    gt_covered = np.zeros(num_gts, dtype=bool)
    for i in sort_inds:
        if ious_max[i] >= iou_thr:
            matched_gt = ious_argmax[i]         #匹配对应的真实框
            if not gt_covered[matched_gt]:      #若真实框没有被匹配,则匹配之
                gt_covered[matched_gt] = True
                tp[i] = 1
            else:
                fp[i] = 1
        else:
            fp[i] = 1
    
    return tp, fp
    
def bbox_overlaps(bboxes1, bboxes2, eps=1e-6):
    """Calculate the ious between each bbox of bboxes1 and bboxes2.
    Args:
        bboxes1(ndarray): shape (n, 4)
        bboxes2(ndarray): shape (k, 4)
    Returns:
        ious(ndarray): shape (n, k)
    """
    bboxes1 = bboxes1.astype(np.float32)
    bboxes2 = bboxes2.astype(np.float32)
    rows = bboxes1.shape[0]
    cols = bboxes2.shape[0]
    ious = np.zeros((rows, cols), dtype=np.float32)
    if rows * cols == 0:
        return ious
    exchange = False
    if bboxes1.shape[0] > bboxes2.shape[0]:
        bboxes1, bboxes2 = bboxes2, bboxes1
        ious = np.zeros((cols, rows), dtype=np.float32)
        exchange = True
    area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
    area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
    for i in range(bboxes1.shape[0]):
        x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
        y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
        x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
        y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
        overlap = np.maximum(x_end - x_start, 0) * np.maximum(
            y_end - y_start, 0)
        union = area1[i] + area2 - overlap
        union = np.maximum(union, eps)
        ious[i, :] = overlap / union
    if exchange:
        ious = ious.T
    return ious

def average_precisions(recalls, precisions, mode='area'):
    '''
    Params:
        recalls(ndarray): shape(num_dets)
        precisions(ndarray): shape(num_dets)
    '''
    assert recalls.shape == precisions.shape
    ap = np.zeros(1,dtype=np.float32)
    if mode == 'area':
        zero = np.zeros(1, dtype=recalls.dtype)
        one = np.ones(1, dtype=recalls.dtype)
        mrec = np.hstack((zero, recalls, one))
        mpre = np.hstack((zero, precisions, zero))
        for i in range(mpre.shape[0]-1, 0, -1):
            mpre[i-1] = np.max(mpre[i-1], mpre[i])
        ind = np.where(mrec[1:] != mrec[:-1])[0]
        ap[0] = np.sum((mrec[ind+1] - mrec[ind]) * mpre[ind+1])
    elif mode == '11points': 
        for thr in np.arange(0, 1+1e-3, 0.1):
            precs = precisions[recalls >= thr]
            prec = precs.max() if precs.size > 0 else 0
            ap[0] += prec
        ap[0] /= 11
    return ap[0]

def show_mAP_table(eval_results, mAP):
    label_len = len('classes')
    dets_len = len('dets')
    gts_len = len('gts')
    for i, res in enumerate(eval_results):
        label_len = max(len(labels[i]), label_len)
        dets_len = max(len(str(res['num_dets'])), dets_len)
        gts_len = max(len(str(res['num_gts'])), gts_len)
    s1 = '+' + (label_len+2)*'-' + '+' + (dets_len+2)*'-' + '+' + (gts_len+2)*'-' + '+' + 7*'-' + '+'
    header = '| classes' + (label_len-6)*' ' + '|' + ' dets' + (dets_len-3)*' ' + '|' + ' gts' + (gts_len-2)*' ' + '|' +'  mAP  ' + '|'
    print(s1)
    print(header)
    print(s1)
    for i ,res in enumerate(eval_results):
        l_len = len(labels[i])
        d_len = len(str(res['num_dets']))
        g_len = len(str(res['num_gts']))
        ap = "{:.3f}".format(res['ap'])
        content = '| ' + labels[i] + (label_len-l_len+1)*' ' + '| ' + str(res['num_dets']) + (dets_len-d_len+1)*' ' + '| ' + str(res['num_gts']) + (gts_len-g_len+1)*' ' + '| ' + str(ap) + ' |'
        print(content)
    print(s1)
    mAP = "{:.3f}".format(mAP)
    content = '| mAP' + (label_len-2)*' ' + '|' + (dets_len+2)*' ' + '|' + (gts_len+2)*' ' + '| ' + str(mAP) + ' |'
    print(content)
    print(s1)
Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐