Grad-CAM Heatmap Visualization


Theory

Grad-CAM heatmap visualization belongs to the field of deep learning interpretability. Deep learning is often treated as a black box, and for image classification we want to explain what the final prediction is actually based on. For the dog and the cat in the figure above, we want to draw a class-specific region for each category. For the dog, the highlighted region of the heatmap is exactly where the dog is: the red areas are the parts the model attends to most, the blue areas the least. For the cat, the heatmap focuses on the region where the cat is. This shows intuitively which regions of the image led to a particular (correct) classification for a given class. It also lets us check whether the regions the model attends to actually match the target, and improve the algorithm accordingly.

The figure above shows action classification: brushing teeth on the left, cutting a tree on the right. For tooth brushing, the regions we expect the model to focus on are the toothbrush and the mouth; for tree cutting, the chainsaw and the region around the person's head.

CAM (Class Activation Mapping)

Before Grad-CAM there was already CAM, proposed by researchers at MIT.

CAM highlights the image regions that the classification relies on: the dog's head, the hen's head, the weight plates at both ends of a barbell, and the clock on a clock tower can all be highlighted.

CAM requires modifying the original network: the fully connected layers are replaced with global average pooling, and the model has to be retrained. As the figure above shows, global average pooling is applied to each channel of the final feature map, so every channel is reduced to a single value: the green feature map corresponds to the green neuron, the red feature map to the red neuron, and the blue feature map to the blue neuron. A classifier is then trained on these pooled values, which yields the weights W1, W2, ..., Wn. The channels of the feature map are then weighted and summed using exactly these weights. Because the final feature map is usually smaller than the input image, the result still has to be upsampled to the original resolution; overlaying it on the original image gives the class activation map (CAM).
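
In formula form (a reconstruction following the CAM paper), let $f_k(x, y)$ be the activation of channel $k$ of the last convolutional feature map at location $(x, y)$, and $w_k^c$ the classifier weight connecting channel $k$ to class $c$. The class activation map for class $c$ is

$$M_c(x, y) = \sum_k w_k^c \, f_k(x, y)$$

which is then upsampled to the input resolution and overlaid on the image.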

Grad-CAM (Visual Explanations from Deep Networks via Gradient-based Localization)

Grad-CAM is more flexible and convenient: it requires neither a global pooling layer nor retraining of the network. "Grad" stands for gradient, so the method can be read as gradient-weighted class activation mapping.

In the figure above, Input is the input image. After the CNN we obtain the output feature maps, followed by a task-specific head (image classification, image captioning, visual question answering, etc.) that produces pre-softmax class scores. The gradient of the desired class is set to 1 and the gradients of all other classes are set to 0, and this is backpropagated to the rectified convolutional feature maps of interest. Pooling these gradients over the spatial dimensions gives a weight for each feature-map channel; the channels are then combined with a weighted sum and passed through a ReLU, yielding the Grad-CAM heatmap, which indicates where the model has to look in order to make that particular decision. Optionally, the heatmap can be multiplied pointwise with guided backpropagation to obtain the high-resolution, concept-specific Guided Grad-CAM visualization; this last step can be skipped.

  • The Grad-CAM gradients

As described above, we start from the pre-softmax score of a specific class c and take its partial derivative with respect to the activations of the last convolutional layer's feature map (here k denotes the channel, i.e. the k-th channel of the feature map, and i, j index the pixels of that feature map); this gradient can be obtained by backpropagation. A global average pooling is then applied: the gradients are summed over all i and j and divided by the area Z of the feature map (Z = feature-map width * height). The result is a weight that reflects how important channel k of the feature map is for the target class c.

After obtaining the weights for the k channels, all channels (i.e. the feature maps of the k-th channel) are combined with a weighted sum and passed through a ReLU to obtain the final values. The resulting heatmap has the same size as the feature map, so it still needs to be upsampled to the original image size and overlaid on the original image to get the final visualization.

The ReLU is introduced here to keep only the regions that have a positive influence; without it, regions that are not actually important for the target object might also be drawn, which is unnecessary.
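
Putting the two steps together (reconstructing the formulas from the Grad-CAM paper), with $y^c$ the pre-softmax score of class $c$, $A^k_{ij}$ the activation of channel $k$ at position $(i, j)$, and $Z$ the number of spatial positions:

$$\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}}, \qquad L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left(\sum_k \alpha_k^c A^k\right)$$

This is exactly what the gradcam.py implementation in the hands-on section computes: the weights correspond to alpha = gradients.view(b, k, -1).mean(2) and the heatmap to F.relu((weights * activations).sum(1, keepdim=True)).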

In the figure above, the leftmost column (a) is the original image; the top row is the visualization for the cat and the bottom row for the dog. (b) and (h) are Guided Backprop, (c) and (i) are Grad-CAM, and multiplying (b) with (c) (or (h) with (i)) gives (d) (or (j)). (c) and (d) (and likewise (i) and (j)) are computed for VGG16, while (f) and (l) are for ResNet. (e) and (k) are the occlusion-sensitivity maps corresponding to (c) and (i). In practice we usually only need (c) and (i).

Hands-On

Training on Your Own Dataset

The dataset can be downloaded here: https://pan.baidu.com/s/1LH7wWR63tMNh-w2pUx_6Lg extraction code: 21mv
The download contains a training set and a test set; it is an image dataset for detecting Messi and the football.

Now we convert the VOC-format dataset into YOLOv5 training and validation sets.
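
Each image will get a matching .txt label file in YOLO format: one line per object, class_id x_center y_center width height, with all coordinates normalized by the image width and height; this is exactly what the convert function below produces. For example, a hypothetical line for a messi box (class index 1) could look like: 1 0.512 0.431 0.160 0.420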

import xml.etree.ElementTree as ET
import os
import random
from shutil import copyfile

# classes=["ball"]
classes = ["ball", "messi"]

TRAIN_RATIO = 80  # percentage of images assigned to the training split

def clear_hidden_files(path):
    dir_list = os.listdir(path)
    for i in dir_list:
        abspath = os.path.join(os.path.abspath(path), i)
        if os.path.isfile(abspath):
            if i.startswith("._"):
                os.remove(abspath)
        else:
            clear_hidden_files(abspath)

def convert(size, box):
    # convert a VOC box (xmin, xmax, ymin, ymax) into a normalized YOLO box (x_center, y_center, w, h)
    dw = 1. / size[0]
    dh = 1. / size[1]
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)

def convert_annotation(image_id):
    in_file = open('VOCdevkit/VOC2007/Annotations/%s.xml' % image_id)
    out_file = open('VOCdevkit/VOC2007/YOLOLabels/%s.txt' % image_id, 'w')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w,h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
    in_file.close()
    out_file.close()

wd = os.getcwd()
data_base_dir = os.path.join(wd, "VOCdevkit/")
if not os.path.isdir(data_base_dir):
    os.mkdir(data_base_dir)
work_space_dir = os.path.join(data_base_dir, "VOC2007/")
if not os.path.isdir(work_space_dir):
    os.mkdir(work_space_dir)
annotation_dir = os.path.join(work_space_dir, "Annotations/")
if not os.path.isdir(annotation_dir):
    os.mkdir(annotation_dir)
clear_hidden_files(annotation_dir)
image_dir = os.path.join(work_space_dir, "JPEGImages/")
if not os.path.isdir(image_dir):
    os.mkdir(image_dir)
clear_hidden_files(image_dir)
yolo_labels_dir = os.path.join(work_space_dir, "YOLOLabels/")
if not os.path.isdir(yolo_labels_dir):
    os.mkdir(yolo_labels_dir)
clear_hidden_files(yolo_labels_dir)
yolov5_images_dir = os.path.join(data_base_dir, "images/")
if not os.path.isdir(yolov5_images_dir):
    os.mkdir(yolov5_images_dir)
clear_hidden_files(yolov5_images_dir)
yolov5_labels_dir = os.path.join(data_base_dir, "labels/")
if not os.path.isdir(yolov5_labels_dir):
    os.mkdir(yolov5_labels_dir)
clear_hidden_files(yolov5_labels_dir)
yolov5_images_train_dir = os.path.join(yolov5_images_dir, "train/")
if not os.path.isdir(yolov5_images_train_dir):
    os.mkdir(yolov5_images_train_dir)
clear_hidden_files(yolov5_images_train_dir)
yolov5_images_test_dir = os.path.join(yolov5_images_dir, "val/")
if not os.path.isdir(yolov5_images_test_dir):
    os.mkdir(yolov5_images_test_dir)
clear_hidden_files(yolov5_images_test_dir)
yolov5_labels_train_dir = os.path.join(yolov5_labels_dir, "train/")
if not os.path.isdir(yolov5_labels_train_dir):
    os.mkdir(yolov5_labels_train_dir)
clear_hidden_files(yolov5_labels_train_dir)
yolov5_labels_test_dir = os.path.join(yolov5_labels_dir, "val/")
if not os.path.isdir(yolov5_labels_test_dir):
    os.mkdir(yolov5_labels_test_dir)
clear_hidden_files(yolov5_labels_test_dir)

train_file = open(os.path.join(wd, "yolov5_train.txt"), 'w')
test_file = open(os.path.join(wd, "yolov5_val.txt"), 'w')
list_imgs = os.listdir(image_dir) # list image files
for i in range(0, len(list_imgs)):
    path = os.path.join(image_dir, list_imgs[i])
    if not os.path.isfile(path):
        continue
    image_path = image_dir + list_imgs[i]
    voc_path = list_imgs[i]
    (nameWithoutExtention, extention) = os.path.splitext(os.path.basename(image_path))
    annotation_name = nameWithoutExtention + '.xml'
    annotation_path = os.path.join(annotation_dir, annotation_name)
    label_name = nameWithoutExtention + '.txt'
    label_path = os.path.join(yolo_labels_dir, label_name)
    # randomly assign each image to the train or val split according to TRAIN_RATIO
    prob = random.randint(1, 100)
    print("Probability: %d" % prob)
    if prob < TRAIN_RATIO:  # train dataset
        if os.path.exists(annotation_path):
            train_file.write(image_path + '\n')
            convert_annotation(nameWithoutExtention)  # convert label
            copyfile(image_path, yolov5_images_train_dir + voc_path)
            copyfile(label_path, yolov5_labels_train_dir + label_name)
    else:  # val dataset
        if os.path.exists(annotation_path):
            test_file.write(image_path + '\n')
            convert_annotation(nameWithoutExtention)  # convert label
            copyfile(image_path, yolov5_images_test_dir + voc_path)
            copyfile(label_path, yolov5_labels_test_dir + label_name)
train_file.close()
test_file.close()

After the split, the result looks like the figure below.

Two new files, yolov5_train.txt and yolov5_val.txt, also appear in the YOLOv5 root directory.
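
Based on what the script creates, the resulting layout should look roughly like this:

VOCdevkit/
    VOC2007/
        Annotations/   # original VOC XML files
        JPEGImages/    # original images
        YOLOLabels/    # converted YOLO-format txt labels
    images/
        train/
        val/
    labels/
        train/
        val/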

In the data directory, copy VOC.yaml to VOC_bm.yaml and modify it as follows:

# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ./VOCdevkit
train: # train images (relative to 'path')
  - images/train/
val: # val images (relative to 'path')
  - images/val/
test: # test images (optional)

# Classes
nc: 2  # number of classes
names: ['ball', 'messi'] # class names
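
Note that the images/ and labels/ directory names matter: YOLOv5 derives the label path for each image by substituting "labels" for "images" in the image path, which is exactly why the conversion script above copies files into VOCdevkit/images/{train,val} and VOCdevkit/labels/{train,val}.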

In the models directory, copy yolov5s.yaml to yolov5s_bm.yaml and change the following:

nc: 2  # number of classes

Training

Add the following run arguments to train.py:

--data data/VOC_bm.yaml --cfg models/yolov5s_bm.yaml --weights weights/yolov5s.pt --batch-size 16 --epochs 100 --workers 4 --name yolov5sbaseline
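
If you prefer launching from the command line rather than an IDE run configuration, the equivalent invocation is roughly:

python train.py --data data/VOC_bm.yaml --cfg models/yolov5s_bm.yaml --weights weights/yolov5s.pt --batch-size 16 --epochs 100 --workers 4 --name yolov5sbaseline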

Visualizing the Training Process

tensorboard --logdir=./runs/train/yolov5sbaseline

Then open http://localhost:6006/ in a browser.

Testing

Add the following run arguments to detect.py:

--source ./VOCdevkit/images/val/img00003.jpg --weights runs/train/yolov5sbaseline/weights/best.pt

Performance Statistics

Add the following run arguments to val.py:

--data data/VOC_bm.yaml --weights runs/train/yolov5sbaseline/weights/best.pt --batch-size 16

Results:

Class     Images  Instances          P          R     mAP@.5  mAP@.5:.95: 100%|██████████| 19/19 [00:01<00:00, 10.70it/s]
  all        299        317      0.901      0.942      0.948        0.69
 ball        299        290      0.973      0.983      0.988        0.81
messi        299         27      0.829      0.901      0.909        0.57

Heatmap Visualization

Modify the forward method of the Detect class in models/yolo.py as follows:

def forward(self, x):
    z = []  # inference output
    logits_ = []
    # x holds the predictions of the three pyramid levels: (n, 255, 80, 80), (n, 255, 40, 40), (n, 255, 20, 20)
    for i in range(self.nl):
        # apply the detection convolution of this anchor level
        x[i] = self.m[i](x[i])  # conv
        # get the batch size and the grid dimensions
        bs, _, ny, nx = map(int, x[i].shape)  # x(bs,255,20,20) to x(bs,3,20,20,85)
        # move the per-anchor outputs to the last dimension
        # (n, 255, _, _) -> (n, 3, nc+5, ny, nx) -> (n, 3, ny, nx, nc+5)
        # i.e. the three levels predict on 80*80, 40*40 and 20*20 grids, with 3 boxes per cell
        x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

        if not self.training:  # inference
            if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                # build the grid for this level
                self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
            # keep only the class scores, dropping the box and objectness outputs
            logits = x[i][..., 5:]
            y = x[i].sigmoid()
            # modify the data in place
            if self.inplace:
                # grid[i] = (3, 20, 20, 2), y = [n, 3, 20, 20, nc+5]
                # grid is the reference position of each cell, and y[..., 0:2] is the offset relative to that cell
                # anchor_grid is the reference box (the anchor), and y[..., 2:4] is the adjustment to that box
                y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
            else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)
                # stride is the actual size of one grid cell
                # after the sigmoid the values lie in (0, 1); the next line maps them to (-0.5, 1.5),
                # i.e. the box center may shift up to 0.5 cells in each direction; more than that would
                # overlap with the responsibility of neighbouring cells
                # the center of the predicted box therefore stays close to its grid cell
                xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                # the width/height range becomes (0, 4) times the anchor; 4x is chosen because the receptive
                # field of the next level is twice that of the current one, so 4x is a reasonable compromise
                wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                y = torch.cat((xy, wh, y[..., 4:]), -1)
            # z.append(y.view(bs, -1, self.no))
            z.append(y.view(-1, int(y.size(1) * y.size(2) * y.size(3)), self.no))
            # (n, 3, ny, nx, nc) -> (n, 3*ny*nx, nc): collect the class logits of every level
            logits_.append(logits.view(bs, -1, self.no - 5))

    # return x if self.training else (torch.cat(z, 1), x)
    return x if self.training else (torch.cat(z, 1), torch.cat(logits_, 1), x)
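
With this change the inference output becomes a 3-tuple instead of the usual 2-tuple, so any code calling the model directly has to unpack it accordingly (the detector wrapper further below does exactly this); a minimal sketch:

prediction, logits, _ = model(img, augment=False)  # (n, anchors, nc+5), (n, anchors, nc), raw per-level outputs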

The unmodified version of this code is discussed in the YOLOv5 walkthrough in YOLO系列介绍(二) (Part 2 of the YOLO series). Next comes the box-handling code; since it is imported below as from utils.boxes import Box, save it as utils/boxes.py.

from typing import Union, Iterable, Sequence
import numpy as np
from enum import Enum


class Point:
    class PointSource(Enum):
        Torch = 'Torch'
        TF = "TF"
        CV = 'CV'
        Numpy = 'Numpy'

    @staticmethod
    def point2point(point, in_source, to_source, in_relative=None, to_relative=None, shape=None, shape_source=None):
        if point is None or len(point) == 0:
            pass
        elif isinstance(point[0], (tuple, list, np.ndarray)):
            point = [Point._point2point(p, in_source=in_source, to_source=to_source,
                                        in_relative=in_relative, to_relative=to_relative,
                                        shape=shape, shape_source=shape_source) for p in point]
        else:
            point = Point._point2point(point, in_source=in_source, to_source=to_source,
                                       in_relative=in_relative, to_relative=to_relative,
                                       shape=shape, shape_source=shape_source)
        return point

    @staticmethod
    def _point2point(point, in_source, to_source, in_relative=None, to_relative=None, shape=None, shape_source=None):
        if isinstance(in_source, Point.PointSource):
            in_source = in_source.value
        if isinstance(to_source, Point.PointSource):
            to_source = to_source.value

        if (in_source in [Point.PointSource.Torch.value, Point.PointSource.CV.value] and to_source in [
            Point.PointSource.TF.value, Point.PointSource.Numpy.value]) \
                or (in_source in [Point.PointSource.TF.value, Point.PointSource.Numpy.value] and to_source in [
            Point.PointSource.Torch.value, Point.PointSource.CV.value]):
            point = (point[1], point[0])
        elif (in_source is None and to_source is None) or in_source == to_source \
                or (in_source in [Point.PointSource.Torch.value, Point.PointSource.CV.value] and to_source in [
            Point.PointSource.CV.value, Point.PointSource.Torch.value]) \
                or (in_source in [Point.PointSource.TF.value, Point.PointSource.Numpy.value] and to_source in [
            Point.PointSource.TF.value, Point.PointSource.Numpy.value]):
            pass
        else:
            raise Exception(
                f'Conversion from {in_source} to {to_source} is not supported.'
                f' Supported types: {Box._get_enum_names(Point.PointSource)}')
        if to_source is not None and shape_source is not None and shape is not None:
            img_w, img_h = Point.point2point(shape, in_source=shape_source, to_source=to_source)
            if not in_relative and to_relative:
                p1, p2 = point
                point = [p1 / img_w, p2 / img_h]
            elif in_relative and not to_relative:
                p1, p2 = point
                point = [p1 * img_w, p2 * img_h]
        return point

    @staticmethod
    def _put_point(img, point, radius, color=(0, 255, 0), thickness=None, lineType=None, shift=None, in_source="Numpy"):
        import cv2
        if not isinstance(point, int):
            point = (int(point[0]), int(point[1]))
        point = Point.point2point(point, in_source=in_source, to_source="CV")
        return cv2.circle(img, point, radius, color, thickness, lineType, shift)

    @staticmethod
    def put_point(img, point, radius, color=(0, 255, 0), thickness=None, lineType=None, shift=None, in_source="Numpy"):
        if point is None or len(point) == 0:
            pass
        elif isinstance(point[0], (tuple, list, np.ndarray)):
            for p in point:
                img = Point._put_point(img, p, radius, color, thickness, lineType, shift, in_source)
        else:
            img = Point._put_point(img, point, radius, color, thickness, lineType, shift, in_source)
        return img

    @staticmethod
    def sort_points(pts: Union[list, tuple]):
        """
        Sort a list of 4 points based on upper-left, upper-right, down-right, down-left
        :param pts:
        :return:
        """
        top_points = sorted(pts, key=lambda l: l[0])[:2]
        top_left = min(top_points, key=lambda l: l[1])
        top_right = max(top_points, key=lambda l: l[1])
        pts.remove(top_left)
        pts.remove(top_right)
        down_left = min(pts, key=lambda l: l[1])
        down_right = max(pts, key=lambda l: l[1])
        return top_left, top_right, down_right, down_left


class Box:
    class BoxFormat(Enum):
        XYWH = "XYWH"
        XYXY = "XYXY"
        XCYC = "XCYC"

    class BoxSource(Enum):
        Torch = 'Torch'
        TF = "TF"
        CV = 'CV'
        Numpy = 'Numpy'

    class OutType(Enum):
        Numpy = np.array
        List = list
        Tuple = tuple

    @staticmethod
    def box2box(box,
                in_format=None,
                to_format=None,
                in_source=BoxSource.Numpy,
                to_source=BoxSource.Numpy,
                in_relative=None,
                to_relative=None,
                shape=None,
                shape_source=None,
                out_type=None,
                return_int=None):
        if box is None or len(box) == 0:
            pass
        elif isinstance(box[0], (tuple, list, np.ndarray)):
            box = [Box._box2box(b, in_format=in_format, to_format=to_format, in_source=in_source, to_source=to_source,
                                in_relative=in_relative, to_relative=to_relative, shape=shape,
                                shape_source=shape_source, out_type=out_type, return_int=return_int)
                   for b in box]

        else:
            box = Box._box2box(box, in_format=in_format, to_format=to_format, in_source=in_source, to_source=to_source,
                               in_relative=in_relative, to_relative=to_relative, shape=shape,
                               shape_source=shape_source,
                               out_type=out_type, return_int=return_int)
        return box

    @staticmethod
    def _box2box(box,
                 in_format=None,
                 to_format=None,
                 in_source=None,
                 to_source=None,
                 in_relative=None,
                 to_relative=None,
                 shape=None,
                 shape_source=None,
                 out_type=None,
                 return_int=None):
        """

        :param box:
        :param in_format:
        :param to_format:
        :param in_source:
        :param to_source:
        :param relative:
        :param img_w:
        :param img_h:
        :param out_type: output type of the box. Supported types: list, tuple, numpy
        :return:
        """
        if isinstance(in_format, Box.BoxFormat):
            in_format = in_format.value
        if isinstance(to_format, Box.BoxFormat):
            to_format = to_format.value

        if isinstance(in_source, Box.BoxSource):
            in_source = in_source.value
        if isinstance(to_source, Box.BoxSource):
            to_source = to_source.value

        if in_format == Box.BoxFormat.XYWH.value and to_format == Box.BoxFormat.XYXY.value:
            x1, y1, w, h = box
            x2, y2 = x1 + w, y1 + h
            box = [x1, y1, x2, y2]
        elif in_format == Box.BoxFormat.XYXY.value and to_format == Box.BoxFormat.XYWH.value:
            x1, y1, x2, y2 = box
            w, h = x2 - x1, y2 - y1
            box = [x1, y1, w, h]
        elif in_format == Box.BoxFormat.XYXY.value and to_format == Box.BoxFormat.XCYC.value:
            x1, y1, x2, y2 = box
            w, h = x2 - x1, y2 - y1
            xc, yc = (x1 + x2) / 2, (y1 + y2) / 2
            box = [xc, yc, w, h]
        elif in_format == Box.BoxFormat.XCYC.value and to_format == Box.BoxFormat.XYXY.value:
            xc, yc, w, h = box
            x1, y1, x2, y2 = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
            box = [x1, y1, x2, y2]
        elif in_format == Box.BoxFormat.XYWH.value and to_format == Box.BoxFormat.XCYC.value:
            x1, y1, w, h = box
            x2, y2 = x1 + w, y1 + h
            xc, yc = (x1 + x2) / 2, (y1 + y2) / 2
            box = [xc, yc, w, h]
        elif in_format == Box.BoxFormat.XCYC.value and to_format == Box.BoxFormat.XYWH.value:
            xc, yc, w, h = box
            x1, y1 = xc - w / 2, yc - h / 2
            box = [x1, y1, w, h]
        elif (in_format is None and to_format is None) or in_format == to_format:
            pass
        else:
            raise Exception(
                f'Conversion from {in_format} to {to_format} is not supported.'
                f' Supported types: {Box._get_enum_names(Box.BoxFormat)}')

        if (in_source in [Box.BoxSource.Torch.value, Box.BoxSource.CV.value] and to_source in [
            Box.BoxSource.TF.value, Box.BoxSource.Numpy.value]) \
                or (in_source in [Box.BoxSource.TF.value, Box.BoxSource.Numpy.value] and to_source in [
            Box.BoxSource.Torch.value, Box.BoxSource.CV.value]):
            box = [box[1], box[0], box[3], box[2]]
        elif (in_source is None and to_source is None) or in_source == to_source \
                or (in_source in [Box.BoxSource.Torch.value, Box.BoxSource.CV.value] and to_source in [
            Box.BoxSource.CV.value, Box.BoxSource.Torch.value]) \
                or (in_source in [Box.BoxSource.TF.value, Box.BoxSource.Numpy.value] and to_source in [
            Box.BoxSource.TF.value, Box.BoxSource.Numpy.value]):
            pass
        else:
            raise Exception(
                f'Conversion from {in_source} to {to_source} is not supported.'
                f' Supported types: {Box._get_enum_names(Box.BoxSource)}')
        if to_source is not None and shape_source is not None and shape is not None:
            img_w, img_h = Point.point2point(shape, in_source=shape_source, to_source=to_source)
            if not in_relative and to_relative:
                b1, b2, b3, b4 = box
                box = [b1 / img_w, b2 / img_h, b3 / img_w, b4 / img_h]
            elif in_relative and not to_relative:
                b1, b2, b3, b4 = box
                box = [b1 * img_w, b2 * img_h, b3 * img_w, b4 * img_h]

        box = Box.get_type(box, out_type)
        if return_int:
            box = [int(b) for b in box]
        return box

    @staticmethod
    def get_type(in_, out_type):
        if out_type is not None:
            try:
                in_ = out_type(in_)
            except:
                raise Exception(
                    f'{out_type} is not Supported. Supported types: {Box._get_enum_names(Box.OutType)}')
        return in_

    @staticmethod
    def _get_enum_names(in_):
        return [n.name for n in in_]

    @staticmethod
    def _put_box(img, box, copy=False,
                 color=(0, 255, 0),
                 thickness=1,
                 lineType=None,
                 shift=None,
                 in_relative=False,
                 in_format="XYXY",
                 in_source='Numpy'):
        import cv2
        box = Box.box2box(box,
                          in_format=in_format,
                          to_format=Box.BoxFormat.XYXY,
                          in_source=in_source,
                          to_source=Box.BoxSource.CV,
                          in_relative=in_relative,
                          to_relative=False,
                          shape=img.shape[:2],
                          shape_source='Numpy')

        if not isinstance(img, np.ndarray):
            img = np.array(img).astype(np.uint8)
        else:
            img = img.astype(np.uint8)
        box = [int(point) for point in box]
        if copy:
            img = img.copy()
        img = cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), color, thickness)
        #img = cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), color=color, thickness=thickness,
        #                    lineType=lineType,
        #                    shift=shift)

        return img

    @staticmethod
    def put_box(img, box, copy=False,
                color=(0, 255, 0),
                thickness=1,
                lineType=None,
                shift=None,
                in_relative=False,
                in_format=BoxFormat.XYXY,
                in_source=BoxSource.Numpy):
        if box is None or len(box) == 0:
            pass
        elif isinstance(box[0], (tuple, list, np.ndarray)):
            for b in box:
                img = Box._put_box(img, box=b, copy=copy, color=color, thickness=thickness, lineType=lineType,
                                   shift=shift, in_format=in_format, in_source=in_source, in_relative=in_relative)
        else:
            img = Box._put_box(img, box=box, copy=copy, color=color, thickness=thickness, lineType=lineType,
                               shift=shift, in_format=in_format, in_source=in_source, in_relative=in_relative)

        return img

    @staticmethod
    def _get_box_img(img, bbox, box_format=BoxFormat.XYXY, box_source=BoxSource.Numpy):
        bbox = Box.box2box(bbox, in_format=box_format, to_format=Box.BoxFormat.XYXY, in_source=box_source,
                           to_source=Box.BoxSource.Numpy, return_int=True)
        img_part = img[bbox[0]:bbox[2], bbox[1]:bbox[3]]
        return img_part

    @staticmethod
    def get_box_img(img, bbox, box_format=BoxFormat.XYXY, box_source=BoxSource.Numpy):
        if len(img.shape) != 3:
            raise Exception('The input image must have 3 dimensions (H, W, C)')

        img_part = []
        if bbox is None or len(bbox) == 0:
            pass
        elif isinstance(bbox[0], (tuple, list, np.ndarray)):
            img_part = [Box._get_box_img(img, b, box_format, box_source) for b in bbox]
        else:
            img_part = Box._get_box_img(img, bbox, box_format, box_source)
        return img_part

    @staticmethod
    def _put_text(img, text, org, fontFace=None, fontScale=1, color=(0, 255, 0),
                  thickness=1, lineType=None, bottomLeftOrigin=None, org_source='Numpy'):
        import cv2
        org = (int(org[0]), int(org[1]))
        org = Point.point2point(org, in_source=org_source, to_source=Point.PointSource.CV)
        font_face = cv2.FONT_HERSHEY_PLAIN if fontFace is None else fontFace
        #img = cv2.putText(img, text, org, font_face, fontScale, color, thickness, lineType, bottomLeftOrigin)
        img = cv2.putText(img, text, org, font_face, fontScale, color, thickness, 0, 0)

        return img

    @staticmethod
    def put_text(img,
                 text,
                 org,
                 fontFace=None,
                 fontScale: float = 1,
                 color=(0, 255, 0),
                 thickness=1,
                 lineType=None,
                 bottomLeftOrigin=None,
                 org_source='Numpy'):
        if text is None or len(text) == 0 or org is None or len(org) == 0:
            pass
        elif isinstance(text, (tuple, list, np.ndarray)):
            for t, o in zip(text, org):
                img = Box._put_text(img, t, o, fontFace, fontScale, color, thickness, lineType, bottomLeftOrigin,
                                    org_source=org_source)
        else:
            img = Box._put_text(img, text, org, fontFace, fontScale, color, thickness, lineType, bottomLeftOrigin,
                                org_source=org_source)
        return img

    @staticmethod
    def get_biggest(box,
                    in_format=BoxFormat.XYXY,
                    in_source=BoxSource.Numpy,
                    get_index=False,
                    inputs: Union[None, dict] = None,
                    reverse=False):
        if len(box) == 0 or box is None:
            return
        box = Box.box2box(box,
                          in_format=in_format,
                          in_source=in_source,
                          to_source=Box.BoxSource.Numpy,
                          to_format=Box.BoxFormat.XYWH
                          )
        if reverse:
            chosen_box = min(box, key=lambda b: b[2] * b[3])
        else:
            chosen_box = max(box, key=lambda b: b[2] * b[3])
        index = box.index(chosen_box)
        if inputs is not None:
            inputs = {k: v[index] for k, v in inputs.items()}
            return inputs
        chosen_box = Box.box2box(chosen_box, in_format=Box.BoxFormat.XYWH, to_format=Box.BoxFormat.XYXY)
        if get_index:
            return chosen_box, index
        return chosen_box

    @staticmethod
    def get_area(box,
                 in_format=BoxFormat.XYXY,
                 in_source=BoxSource.Numpy):
        box = Box.box2box(box,
                          in_format=in_format,
                          in_source=in_source,
                          to_source=Box.BoxSource.Numpy,
                          to_format=Box.BoxFormat.XYWH
                          )
        area = box[2] * box[3]
        return area

    @staticmethod
    def fill_box(img,
                 box,
                 value,
                 in_format=BoxFormat.XYXY,
                 in_source=BoxSource.Numpy):
        """
        Fill the selected box with the specified value.
        :param img: The input image
        :param box: the box that should be filled
        :param value: the value with which the box will be filled
        :param in_format: box input format
        :param in_source: box input source
        :return: the filled box
        """
        bbox = Box.box2box(box,
                           in_format=in_format,
                           in_source=in_source,
                           to_source=Box.BoxSource.Numpy,
                           to_format=Box.BoxFormat.XYXY
                           )
        img[bbox[0]:bbox[2], bbox[1]:bbox[3]] = value
        return img

    @staticmethod
    def fill_outer_box(img,
                       box,
                       value: int = 0,
                       in_format=BoxFormat.XYXY,
                       in_source=BoxSource.Numpy):
        """
        Fill the outer area of the selected box with the specified value.
        :param img: The input image
        :param box: the box that should remain fixed
        :param value: the value with which the outer box will be filled, default is zero
        :param in_format: box input format
        :param in_source: box input source
        :return: the filled box
        """
        import cv2
        bbox = Box.box2box(box,
                           in_format=in_format,
                           in_source=in_source,
                           to_source=Box.BoxSource.Numpy,
                           to_format=Box.BoxFormat.XYXY
                           )
        mask = np.ones_like(img, dtype=np.uint8) * value
        mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1
        img = cv2.multiply(img, mask)
        return img

    @staticmethod
    def _put_box_text(img, box, label, color=(128, 128, 128), txt_color=(255, 255, 255), thickness=2):
        import cv2
        p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
        img = Box.put_box(img, box, color=color, thickness=thickness, lineType=cv2.LINE_AA, in_source=Box.BoxSource.CV)
        text_font = max(thickness - 1, 1)  # font thickness
        w, h = cv2.getTextSize(label, 0, fontScale=thickness / 3, thickness=text_font)[0]  # text width, height
        outside = p1[1] - h - 3 >= 0  # label fits outside box
        p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
        img = Box.put_box(img, [*p1, *p2], color=color, thickness=-1, lineType=cv2.LINE_AA, in_source=Box.BoxSource.CV)
        x0, x1 = p1[0], p1[1] - 2 if outside else p1[1] + h + 2
        img = Box.put_text(img, label, (x0, x1), fontFace=0, fontScale=thickness / 3, color=txt_color, thickness=text_font,
                           lineType=cv2.LINE_AA, org_source="CV")
        return img

    @staticmethod
    def put_box_text(img: Union[Sequence, np.ndarray], box: Union[Sequence], label: Union[Sequence, str],
                     color=(128, 128, 128), txt_color=(255, 255, 255), thickness=2):
        """
            :param img:
            :param box: It should be in numpy source!
            :param label:
            :param color:
            :param txt_color:
            :param thickness:
            :return:
            """
        box = Box.box2box(box, in_source=Box.BoxSource.Numpy, to_source=Box.BoxSource.CV)
        if isinstance(box, Sequence) and isinstance(box[0], Sequence) and isinstance(label, Sequence):
            if isinstance(color, Sequence) and isinstance(color[0], Sequence):
                if isinstance(txt_color, Sequence) and isinstance(txt_color[0], Sequence):
                    for b, l, c, t_c in zip(box, label, color, txt_color):
                        img = Box._put_box_text(img, b, l, c, t_c, thickness)
                else:
                    for b, l, c in zip(box, label, color):
                        img = Box._put_box_text(img, b, l, c, txt_color, thickness)
            else:
                for b, l in zip(box, label):
                    img = Box._put_box_text(img, b, l, color, txt_color, thickness)
        else:
            img = Box._put_box_text(img, box, label, color, txt_color, thickness)
        return img


if __name__ == '__main__':
    print(Box.BoxFormat.XYXY is Box.BoxFormat)
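
A quick way to sanity-check the conversion helpers (assuming the file is saved as utils/boxes.py; the box values here are made up for illustration):

from utils.boxes import Box

# center-format (xc, yc, w, h) -> corner-format (x1, y1, x2, y2)
box = Box.box2box([50, 60, 20, 10], in_format=Box.BoxFormat.XCYC, to_format=Box.BoxFormat.XYXY)
print(box)  # [40.0, 55.0, 60.0, 65.0]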

Create a new model file; since main.py below imports it as from models.yolo_v5_object_detector import YOLOV5TorchObjectDetector, save it as models/yolo_v5_object_detector.py.

import numpy as np
from utils.boxes import Box
import torch
from models.experimental import attempt_load
from utils.general import xywh2xyxy
from utils.datasets import letterbox
import cv2
import time
import torchvision
import torch.nn as nn
from utils.metrics import box_iou

class YOLOV5TorchObjectDetector(nn.Module):
    def __init__(self,
                 model_weight,
                 device,
                 img_size,
                 names=None,
                 mode='eval',
                 confidence=0.4,
                 iou_thresh=0.45,
                 agnostic_nms=False):
        super(YOLOV5TorchObjectDetector, self).__init__()
        self.device = device
        self.model = None
        self.img_size = img_size
        self.mode = mode
        self.confidence = confidence
        self.iou_thresh = iou_thresh
        self.agnostic = agnostic_nms
        self.model = attempt_load(model_weight, map_location=device, inplace=False, fuse=False)
        print("[INFO] Model is loaded")
        self.model.requires_grad_(True)
        self.model.to(device)

        if self.mode == 'train':
            self.model.train()
        else:
            self.model.eval()
        # fetch the names
        if names is None:
            print('[INFO] fetching names from coco file')

            self.names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
                          'traffic light',
                          'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
                          'cow',
                          'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
                          'frisbee',
                          'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
                          'surfboard',
                          'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',
                          'apple',
                          'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
                          'couch',
                          'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                          'keyboard', 'cell phone',
                          'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
                          'teddy bear',
                          'hair drier', 'toothbrush']
        else:
            self.names = names

        # preventing cold start
        img = torch.zeros((1, 3, *self.img_size), device=device)
        self.model(img)

    @staticmethod
    def non_max_suppression(prediction, logits, conf_thres=0.3, iou_thres=0.45, classes=None, agnostic=False,
                            multi_label=False, labels=(), max_det=300):
        """Runs Non-Maximum Suppression (NMS) on inference and logits results

        Returns:
             list of detections, on (n,6) tensor per image [xyxy, conf, cls] and pruned input logits (n, number-classes)
        """

        nc = prediction.shape[2] - 5  # number of classes
        xc = prediction[..., 4] > conf_thres  # candidates

        # Checks
        assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
        assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

        # Settings
        min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
        max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
        time_limit = 10.0  # seconds to quit after
        redundant = True  # require redundant detections
        multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
        merge = False  # use merge-NMS

        t = time.time()
        output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
        logits_output = [torch.zeros((0, nc), device=logits.device)] * logits.shape[0]
        #logits_output = [torch.zeros((0, 80), device=logits.device)] * logits.shape[0]
        for xi, (x, log_) in enumerate(zip(prediction, logits)):  # image index, image inference
            # Apply constraints
            #x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
            x = x[xc[xi]]  # confidence
            log_ = log_[xc[xi]]
            # Cat apriori labels if autolabelling
            if labels and len(labels[xi]):
                l = labels[xi]
                v = torch.zeros((len(l), nc + 5), device=x.device)
                v[:, :4] = l[:, 1:5]  # box
                v[:, 4] = 1.0  # conf
                v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
                x = torch.cat((x, v), 0)

            # If none remain process next image
            if not x.shape[0]:
                continue

            # Compute conf
            x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
            # Box (center x, center y, width, height) to (x1, y1, x2, y2)
            box = xywh2xyxy(x[:, :4])

            # Detections matrix nx6 (xyxy, conf, cls)
            if multi_label:
                i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
                x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
            else:  # best class only
                conf, j = x[:, 5:].max(1, keepdim=True)
                x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
                log_ = log_[conf.view(-1) > conf_thres]
            # Filter by class
            if classes is not None:
                x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

            # Check shape
            n = x.shape[0]  # number of boxes
            if not n:  # no boxes
                continue
            elif n > max_nms:  # excess boxes
                x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

            # Batched NMS
            c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
            boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
            if i.shape[0] > max_det:  # limit detections
                i = i[:max_det]
            if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
                # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
                iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
                weights = iou * scores[None]  # box weights
                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
                if redundant:
                    i = i[iou.sum(1) > 1]  # require redundancy

            output[xi] = x[i]
            logits_output[xi] = log_[i]
            assert log_[i].shape[0] == x[i].shape[0]
            if (time.time() - t) > time_limit:
                print(f'WARNING: NMS time limit {time_limit}s exceeded')
                break  # time limit exceeded

        return output, logits_output

    @staticmethod
    def yolo_resize(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):

        return letterbox(img, new_shape=new_shape, color=color, auto=auto, scaleFill=scaleFill, scaleup=scaleup)

    def forward(self, img):
        prediction, logits, _ = self.model(img, augment=False)
        prediction, logits = self.non_max_suppression(prediction, logits, self.confidence, self.iou_thresh,
                                                      classes=None,
                                                      agnostic=self.agnostic)
        self.boxes, self.class_names, self.classes, self.confidences = [[[] for _ in range(img.shape[0])] for _ in
                                                                        range(4)]
        for i, det in enumerate(prediction):  # detections per image
            if len(det):
                for *xyxy, conf, cls in det:
                    bbox = Box.box2box(xyxy,
                                       in_source=Box.BoxSource.Torch,
                                       to_source=Box.BoxSource.Numpy,
                                       return_int=True)
                    self.boxes[i].append(bbox)
                    self.confidences[i].append(round(conf.item(), 2))
                    cls = int(cls.item())
                    self.classes[i].append(cls)
                    if self.names is not None:
                        self.class_names[i].append(self.names[cls])
                    else:
                        self.class_names[i].append(cls)
        return [self.boxes, self.classes, self.class_names, self.confidences], logits

    def preprocessing(self, img):
        if len(img.shape) != 4:
            img = np.expand_dims(img, axis=0)
        im0 = img.astype(np.uint8)
        img = np.array([self.yolo_resize(im, new_shape=self.img_size)[0] for im in im0])
        img = img.transpose((0, 3, 1, 2)) 
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(self.device)
        img = img / 255.0
        return img


if __name__ == '__main__':
    model_path = './weights/yolov5s.pt'
    img_path = './images/eagle.jpg'
    model = YOLOV5TorchObjectDetector(model_path, 'cpu', img_size=(640, 640)).to('cpu')
    img = np.expand_dims(cv2.imread(img_path)[..., ::-1], axis=0)
    img = model.preprocessing(img)
    a = model(img)
    print(model._modules)

Next, add a new gradcam.py under models/ (main.py imports it as from models.gradcam import YOLOV5GradCAM):

import time
import torch
import torch.nn.functional as F


def find_yolo_layer(model, layer_name):
    """Find yolov5 layer to calculate GradCAM and GradCAM++

    Args:
        model: yolov5 model.
        layer_name (str): the name of layer with its hierarchical information.

    Return:
        target_layer: found layer
    """
    hierarchy = layer_name.split('_')
    target_layer = model.model._modules[hierarchy[0]]

    for h in hierarchy[1:]:
        target_layer = target_layer._modules[h]
    return target_layer

class YOLOV5GradCAM:

    def __init__(self, model, layer_name, img_size=(640, 640)):
        self.model = model
        self.gradients = dict()
        self.activations = dict()

        def backward_hook(module, grad_input, grad_output):
            self.gradients['value'] = grad_output[0]
            return None

        def forward_hook(module, input, output):
            self.activations['value'] = output
            return None

        target_layer = find_yolo_layer(self.model, layer_name)
        target_layer.register_forward_hook(forward_hook)
        #target_layer.register_backward_hook(backward_hook)
        target_layer.register_full_backward_hook(backward_hook)

        device = 'cuda' if next(self.model.model.parameters()).is_cuda else 'cpu'
        self.model(torch.zeros(1, 3, *img_size, device=device))
        print('[INFO] saliency_map size :', self.activations['value'].shape[2:])

    def forward(self, input_img, class_idx=True):
        """
        Args:
            input_img: input image with shape of (1, 3, H, W)
        Return:
            mask: saliency map of the same spatial dimension with input
            logit: model output
            preds: The object predictions
        """
        saliency_maps = []
        b, c, h, w = input_img.size()
        tic = time.time()
        preds, logits = self.model(input_img)
        print("[INFO] model-forward took: ", round(time.time() - tic, 4), 'seconds')
        for logit, cls, cls_name in zip(logits[0], preds[1][0], preds[2][0]):
            if class_idx:
                score = logit[cls]
            else:
                score = logit.max()
            self.model.zero_grad()
            tic = time.time()
            score.backward(retain_graph=True)
            print(f"[INFO] {cls_name}, model-backward took: ", round(time.time() - tic, 4), 'seconds')
            gradients = self.gradients['value']
            activations = self.activations['value']
            b, k, u, v = gradients.size()
            alpha = gradients.view(b, k, -1).mean(2)
            weights = alpha.view(b, k, 1, 1)
            saliency_map = (weights * activations).sum(1, keepdim=True)
            saliency_map = F.relu(saliency_map)
            #saliency_map = F.upsample(saliency_map, size=(h, w), mode='bilinear', align_corners=False)
            saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)
            saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
            saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data
            saliency_maps.append(saliency_map)
        return saliency_maps, logits, preds

    def __call__(self, input_img):
        return self.forward(input_img)
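
The layer_name argument is an underscore-separated path through the model's module hierarchy, which find_yolo_layer resolves via _modules. For example, the default 'model_23_cv3_act' used in main.py below should resolve to roughly the following (for a standard yolov5s, layer 23 is the last C3 block before the Detect head; here model is the YOLOV5TorchObjectDetector wrapper):

# 'model_23_cv3_act' -> the activation of the cv3 conv block inside layer 23
target_layer = model.model._modules['model']._modules['23']._modules['cv3']._modules['act']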

Finally, add a main.py:

import os
import time
import argparse
import numpy as np
from models.gradcam import YOLOV5GradCAM
from models.yolo_v5_object_detector import YOLOV5TorchObjectDetector
import cv2
from utils.boxes import Box 

def split_extension(path, extension=None, suffix=None):
    remain, extension_ = os.path.splitext(path)
    if extension and suffix:
        return remain + suffix + extension
    elif extension is None and suffix:
        return remain + suffix + extension_
    elif extension:
        return remain + extension
    return remain, extension_
    
# Arguments
parser = argparse.ArgumentParser()
parser.add_argument('--model-path', type=str, default="yolov5s.pt", help='Path to the model')
parser.add_argument('--img-path', type=str, default='images/', help='input image path')
parser.add_argument('--output-dir', type=str, default='outputs', help='output dir')
parser.add_argument('--img-size', type=int, default=640, help="input image size")
parser.add_argument('--target-layer', type=str, default='model_23_cv3_act',
                    help='The layer hierarchical address to which gradcam will be applied,'
                         ' the names should be separated by underscores')
parser.add_argument('--method', type=str, default='gradcam', help='gradcam method')
parser.add_argument('--device', type=str, default='cpu', help='cuda or cpu')
parser.add_argument('--names', type=str, default=None,
                    help='The names of the classes. The default is None, which falls back to the COCO classes. Provide your custom names as follows: object1,object2,object3')
parser.add_argument('--no_text_box', action='store_true',
                        help='do not show label and box on the heatmap')
args = parser.parse_args()


def get_res_img(bbox, mask, res_img):
    mask = mask.squeeze(0).mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).detach().cpu().numpy().astype(
        np.uint8)
    heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
    #n_heatmat = (Box.fill_outer_box(heatmap, bbox) / 255).astype(np.float32)
    n_heatmat = (heatmap / 255).astype(np.float32)
    res_img = res_img / 255
    res_img = cv2.add(res_img, n_heatmat)
    res_img = (res_img / res_img.max())
    return res_img, n_heatmat


def put_text_box(bbox, cls_name, res_img, no_text_box):
    x1, y1, x2, y2 = bbox
    # this is a bug in cv2. It does not put box on a converted image from torch unless it's buffered and read again!
    cv2.imwrite('temp.jpg', (res_img * 255).astype(np.uint8))
    res_img = cv2.imread('temp.jpg')
    if not no_text_box:
        res_img = Box.put_box(res_img, bbox)
        res_img = Box.put_text(res_img, cls_name, (x1, y1))
    return res_img


def concat_images(images):
    h, w = images[0].shape[:2]
    # stack all result images horizontally on one canvas
    base_img = np.zeros((h, w * len(images), 3), dtype=np.uint8)
    for i, img in enumerate(images):
        base_img[:, w * i:w * (i + 1), ...] = img
    return base_img


def main(img_path):
    device = args.device
    input_size = (args.img_size, args.img_size)
    img = cv2.imread(img_path)
    print('[INFO] Loading the model')
    model = YOLOV5TorchObjectDetector(args.model_path, device, img_size=input_size,
                                      names=None if args.names is None else args.names.strip().split(","))
    torch_img = model.preprocessing(img[..., ::-1])
    if args.method == 'gradcam':
        saliency_method = YOLOV5GradCAM(model=model, layer_name=args.target_layer, img_size=input_size)
    tic = time.time()
    masks, logits, [boxes, _, class_names, _] = saliency_method(torch_img)
    print("total time:", round(time.time() - tic, 4))
    result = torch_img.squeeze(0).mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).detach().cpu().numpy()
    result = result[..., ::-1]  # convert to bgr
    images = [result]
    for i, mask in enumerate(masks):
        res_img = result.copy()
        bbox, cls_name = boxes[0][i], class_names[0][i]
        res_img, heat_map = get_res_img(bbox, mask, res_img)
        res_img = put_text_box(bbox, cls_name, res_img, args.no_text_box)
        images.append(res_img)
    final_image = concat_images(images)
    img_name = split_extension(os.path.split(img_path)[-1], suffix='-res')
    output_path = f'{args.output_dir}/{img_name}'
    os.makedirs(args.output_dir, exist_ok=True)
    print(f'[INFO] Saving the final image at {output_path}')
    cv2.imwrite(output_path, final_image)


if __name__ == '__main__':
    if os.path.isdir(args.img_path):
        img_list = os.listdir(args.img_path)
        print(img_list)
        for item in img_list:
            main(os.path.join(args.img_path, item))
    else:
        main(args.img_path)
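
To run it on the model trained above, the run arguments would look something like this (an example invocation; adjust the paths to your own setup):

python main.py --model-path runs/train/yolov5sbaseline/weights/best.pt --img-path ./VOCdevkit/images/val/ --img-size 640 --target-layer model_23_cv3_act --device cpu --names ball,messi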

Results
