How to filter out a class at prediction time with Faster R-CNN

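I am doing object detection with Faster R-CNN on the VOC2012 dataset, which has 21 classes. How can I remove one class of objects in the prediction code? For example, ignore (not predict) the aeroplane class at prediction time, or at least not display the bounding boxes for that class.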

import os
import time
import json

import torch
import torchvision
from PIL import Image
import matplotlib.pyplot as plt

from torchvision import transforms
from network_files import FasterRCNN, FastRCNNPredictor, AnchorsGenerator
from backbone import resnet50_fpn_backbone, MobileNetV2
from draw_box_utils import draw_objs


def create_model(num_classes):

    # ResNet50 + FPN + Faster R-CNN
    # Note: norm_layer here must match the one used in the training script
    backbone = resnet50_fpn_backbone(norm_layer=torch.nn.BatchNorm2d)
    model = FasterRCNN(backbone=backbone, num_classes=num_classes, rpn_score_thresh=0.5)

    return model


def time_synchronized():
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    return time.time()


def main():
    # get devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model
    model = create_model(num_classes=21)


    # load train weights
    weights_path = "./fasterrcnn_voc2012.pth"
    assert os.path.exists(weights_path), "{} file does not exist.".format(weights_path)
    weights_dict = torch.load(weights_path, map_location='cpu')
    weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
    model.load_state_dict(weights_dict)
    model.to(device)

    # read class_indict
    label_json_path = './pascal_voc_classes.json'
    assert os.path.exists(label_json_path), "json file {} does not exist.".format(label_json_path)
    with open(label_json_path, 'r') as f:
        class_dict = json.load(f)

    category_index = {str(v): str(k) for k, v in class_dict.items()}

    # load image
    original_img = Image.open("./correction_image./2007_004459_fisheye_12.jpg")

    # from pil image to tensor, do not normalize image
    data_transform = transforms.Compose([transforms.ToTensor()])
    img = data_transform(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    model.eval()  # switch to evaluation mode
    with torch.no_grad():
        # init
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        predictions = model(img.to(device))[0]
        t_end = time_synchronized()
        print("inference+NMS time: {}".format(t_end - t_start))

        predict_boxes = predictions["boxes"].to("cpu").numpy()
        predict_classes = predictions["labels"].to("cpu").numpy()
        predict_scores = predictions["scores"].to("cpu").numpy()

        if len(predict_boxes) == 0:
            print("No objects detected!")


        plot_img = draw_objs(original_img,
                             predict_boxes,
                             predict_classes,
                             predict_scores,
                             category_index=category_index,
                             box_thresh=0.5,
                             line_thickness=3,
                             font='arial.ttf',
                             font_size=20)
        plt.imshow(plot_img)
        plt.show()
        # save the image with the predictions drawn on it
        plot_img.save("test_result.jpg")

if __name__ == '__main__':
    main()
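
One way to hide a class without retraining is to filter the detections after inference and before drawing. The sketch below is a minimal example under stated assumptions, not the project's own API: `drop_classes` is a hypothetical helper, the arrays are assumed to have the shapes produced in the script above (`predict_boxes` is N x 4, `predict_classes` and `predict_scores` have length N), and the label ids in the toy data are made up.

```python
import numpy as np


def drop_classes(boxes, labels, scores, excluded_ids):
    """Return boxes, labels and scores without detections of the excluded classes.

    Hypothetical helper for illustration; not part of the original project.
    """
    # True for every detection whose label is NOT in excluded_ids
    keep = ~np.isin(labels, list(excluded_ids))
    return boxes[keep], labels[keep], scores[keep]


if __name__ == "__main__":
    # Toy detections standing in for the model output (values are made up)
    boxes = np.array([[10, 10, 50, 50], [20, 30, 80, 90], [5, 5, 15, 15]],
                     dtype=np.float32)
    labels = np.array([1, 15, 1])
    scores = np.array([0.9, 0.8, 0.6], dtype=np.float32)

    # Assumption: label id 1 is the class to hide
    boxes, labels, scores = drop_classes(boxes, labels, scores, excluded_ids=[1])
    print(labels)
```

In the script above, the call would probably go right after the three `predict_*` arrays are created and before `draw_objs`, for example `predict_boxes, predict_classes, predict_scores = drop_classes(predict_boxes, predict_classes, predict_scores, [class_dict["aeroplane"]])`. This assumes that `class_dict` maps class names to the integer label ids the model outputs and that the key in pascal_voc_classes.json is spelled "aeroplane". The model still predicts the class internally; the detections are only discarded afterwards, so the boxes of the other classes are unchanged.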



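In case this question has not been solved yet:
  • The "2. Building with the pretrained models" section of the blog post "Training your own dataset with Faster RCNN" may help. You can read the excerpt below or go to the original post:
    • Create the folders

    (Note: this post renames the original folder to faster-rcnn.) Create a data folder inside it: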

    cd faster-rcnn && mkdir data

    Create a pretrained_model folder inside the data folder:

    mkdir data/pretrained_model
    • Download the pretrained VGG16 and ResNet-101 models

    Pretrained VGG16 model:

    VGG16

    Pretrained ResNet-101 model: ResNet-101

    Put the downloaded pretrained models in the pretrained_model folder

    • Run the build
    cd lib
    python setup.py build develop
    cd ..

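    The build is then complete (the original post shows a screenshot here).

    If training on your own dataset still fails after the build with: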

    ImportError: cannot import name '_mask'

    then the COCO API is missing, and the following commands are needed:

    cd data
    git clone https://github.com/pdollar/coco.git
    cd coco/PythonAPI
    make
    cd ../../..

    The screenshot in the original post shows that '_mask.o' was compiled successfully.

    • Downgrade Scipy

    List the installed Python packages with pip:

    pip list

    The list shows scipy==1.5.4 and Pillow==8.2.0. Because of changes in Scipy itself, Scipy has to be downgraded; otherwise training fails with:

    ImportError: cannot import name 'imread' 

    First uninstall both of these versions:

    pip uninstall scipy
    
    pip uninstall pillow

    Then install the pinned versions:

    pip install scipy==1.2.1
    pip install pillow==6.1.0

If you have already solved the problem, please consider writing up your solution as a blog post and leaving the link in the comments to help others ^-^