加入收藏

yolov5实战之模型剪枝_环球热点

2023-06-28 03:22:23 来源:博客园

续yolov5实战之二维码检测

目录前沿为什么要做轻量化什么是剪枝稀疏化训练剪枝微调结语模型下载前沿

在上一篇yolov5的博客中,我们用yolov5训练了一个二维码检测器,可以用来检测图像中是否有二维码,后续可以接一个二维码解码器,就可以解码出二维码的信息了(后续可以聊聊)。这篇博客再讲讲另一个方面:模型轻量化,具体的是轻量化中的模型剪枝。


(资料图片仅供参考)

为什么要做轻量化

我们训练的模型不仅仅会用在GPU这种算力高的硬件上,也有可能用在嵌入式CPU或者NPU上,这类硬件算力往往较低,尽管在这些设备上运行模型时,我们可以将模型量化为int8,可以大大降低计算量,但有时候只靠这一方式也是不够的。比较直观能想到的提升模型运行速度的方式是裁剪模型,比如减少通道数或模型的深度,这种方式是以牺牲模型精度为代价的。这就促使我们寻找更好的模型轻量化方法,剪枝就是一种使用比较广泛的模型轻量化方法。

什么是剪枝

模型剪枝(Model Pruning)是一种通过减少神经网络模型中的冗余参数和连接来优化模型的方法。它旨在减小模型的大小、内存占用和计算复杂度,同时尽可能地保持模型的性能。

模型剪枝的基本思想是通过识别和删除对模型性能影响较小的参数或连接,以达到模型精简和优化的目的。方法包括剪枝后的参数微调、重新训练和微调整体网络结构等。直观的理解就是像下图这样。  模型剪枝可以在不显著损失模型性能的情况下,大幅度减少模型的参数量和计算量,从而提高模型的部署效率和推理速度。它特别适用于嵌入式设备、移动设备和边缘计算等资源受限的场景,以及需要部署在较小存储空间或带宽受限环境中的应用。本文选择的模型剪枝方法:Learning Efficient Convolutional Networks through Network Slimming源代码:https://github.com/foolwood/pytorch-slimming这个方法基于的想法是通过稀疏化训练,通过BN层的参数,自动得到权重较小通道,去掉这些通道,从而达到模型裁剪的目的。

稀疏化训练

如上文述,为了达到剪枝的目的,我们要使用稀疏化训练,以使得让模型权重更紧凑,能够去掉一些权重较小的通道,达到模型裁剪的目的。为了进行稀疏化训练,引入一个稀疏化稀疏参数,这个参数越大,模型越稀疏,能够裁剪的比例越大,需要在实际中调整,参数过大,模型性能可能会下降较多,参数过小,能够裁剪的比例又会过小。  为了进行稀疏化训练,首先汇总模型的所有BN层:

if opt.sl > 0:        print("Sparse Learning Model!")        print("===> Sparse learning rate is ", hyp["sl"])        prunable_modules = []        prunable_module_type = (nn.BatchNorm2d, )        for i, m in enumerate(model.modules()):            if isinstance(m, prunable_module_type):                prunable_modules.append(m)

在训练loss中增加稀疏化loss:

def compute_pruning_loss(p, prunable_modules, model, loss):    """    Compute the pruning loss    :param p: predicted output    :param prunable_modules: list of prunable modules    :param model: model    :param loss: original yolo loss    :return: loss    """    float_tensor = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor    sl_loss = float_tensor([0])    hyp = model.hyp  # hyperparameters    red = "mean"  # Loss reduction (sum or mean)    if prunable_modules is not None:        for m in prunable_modules:            sl_loss += m.weight.norm(1)        sl_loss /= len(prunable_modules)    sl_loss *= hyp["sl"]    bs = p[0].shape[0]  # batch size    loss += sl_loss * bs    return loss
# Forward            with amp.autocast(enabled=cuda):                pred = model(imgs)  # forward                loss, loss_items = compute_loss(pred, targets.to(device), model)  # loss scaled by batch_size                # Sparse Learning                if opt.sl > 0:                    loss = compute_pruning_loss(pred, prunable_modules, model, loss)                if rank != -1:                    loss *= opt.world_size  # gradient averaged between devices in DDP mode

设置合适的稀疏化稀疏进行训练,这一过程和普通的yolov5模型训练一样。

剪枝

pruning.py

#!/usr/bin/env python# -*- coding: utf-8 -*-"""Copyright (c) 2019 luozw, Inc. All Rights ReservedAuthors: luozhiwang(luozw1994@outlook.com)Date: 2020/9/7"""import osimport argparseimport numpy as npimport torchimport torch.nn as nnimport torch_pruning as tpimport copyimport matplotlib.pyplot as pltfrom models.yolo import Modelimport mathdef load_model(cfg="models/mobile-yolo5l_voc.yaml", weights="./outputs/mvoc/weights/best_mvoc.pt"):    restor_num = 0    ommit_num = 0    model = Model(cfg).to(device)    ckpt = torch.load(weights, map_location=device)  # load checkpoint    names = ckpt["model"].names    dic = {}    for k, v in ckpt["model"].float().state_dict().items():        if k in model.state_dict() and model.state_dict()[k].shape == v.shape:            dic[k] = v            restor_num += 1        else:            ommit_num += 1    print("Build model from", cfg)    print("Resotre weight from", weights)    print("Restore %d vars, ommit %d vars" % (restor_num, ommit_num))    ckpt["model"] = dic    model.load_state_dict(ckpt["model"], strict=False)       del ckpt    model.float()    model.model[-1].export = True    return model, namesdef bn_analyze(prunable_modules, save_path=None):    bn_val = []    max_val = []    for layer_to_prune in prunable_modules:        # select a layer        weight = layer_to_prune.weight.data.detach().cpu().numpy()        max_val.append(max(weight))        bn_val.extend(weight)    bn_val = np.abs(bn_val)    max_val = np.abs(max_val)    bn_val = sorted(bn_val)    max_val = sorted(max_val)    plt.hist(bn_val, bins=101, align="mid", log=True, range=(0, 1.0))    if save_path is not None:        if os.path.isfile(save_path):            os.remove(save_path)        plt.savefig(save_path)    return bn_val, max_valdef channel_prune(ori_model, example_inputs, output_transform, pruned_prob=0.3, thres=None, rules=1):    model = copy.deepcopy(ori_model)    model.cpu().eval()    prunable_module_type = (nn.BatchNorm2d)    ignore_idx = [] #[230, 260, 290]    prunable_modules = []    for i, m in enumerate(model.modules()):        if i in ignore_idx:            continue        if isinstance(m, nn.Upsample):            continue        if isinstance(m, prunable_module_type):            prunable_modules.append(m)    ori_size = tp.utils.count_params(model)    DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs,                                               output_transform=output_transform)    bn_val, max_val = bn_analyze(prunable_modules, "render_img/before_pruning.jpg")    if thres is None:        thres_pos = int(pruned_prob * len(bn_val))        thres_pos = min(thres_pos, len(bn_val)-1)        thres_pos = max(thres_pos, 0)        thres = bn_val[thres_pos]    print("Min val is %f, Max val is %f, Thres is %f" % (bn_val[0], bn_val[-1], thres))    for layer_to_prune in prunable_modules:        # select a layer        weight = layer_to_prune.weight.data.detach().cpu().numpy()        if isinstance(layer_to_prune, nn.Conv2d):            if layer_to_prune.groups > 1:                prune_fn = tp.prune_group_conv            else:                prune_fn = tp.prune_conv            L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))        elif isinstance(layer_to_prune, nn.BatchNorm2d):            prune_fn = tp.prune_batchnorm            L1_norm = np.abs(weight)        pos = np.array([i for i in range(len(L1_norm))])        pruned_idx_mask = L1_norm < thres        prun_index = pos[pruned_idx_mask].tolist()        if rules != 1:            prune_channel_nums = len(L1_norm) - max(rules, int((len(L1_norm) - pruned_idx_mask.sum())/rules + 0.5)*rules)            _, index = torch.topk(torch.tensor(L1_norm), prune_channel_nums, largest=False)            prun_index = index.numpy().tolist()                    if len(prun_index) == len(L1_norm):            del prun_index[np.argmax(L1_norm)]        plan = DG.get_pruning_plan(layer_to_prune, prune_fn, prun_index)        plan.exec()    bn_analyze(prunable_modules, "render_img/after_pruning.jpg")    with torch.no_grad():        out = model(example_inputs)        if output_transform:            out = output_transform(out)        print("  Params: %s => %s" % (ori_size, tp.utils.count_params(model)))        if isinstance(out, (list, tuple)):            for o in out:                print("  Output: ", o.shape)        else:            print("  Output: ", out.shape)        print("------------------------------------------------------\n")    return modelif __name__ == "__main__":    parser = argparse.ArgumentParser()    parser.add_argument("--cfg", default="models/yolov5s_voc.yaml", type=str, help="*.cfg path")    parser.add_argument("--weights", default="runs/exp7_sl-2e-3-yolov5s/weights/last.pt", type=str, help="*.data path")    parser.add_argument("--save-dir", default="runs/exp7_sl-2e-3-yolov5s/weights", type=str, help="*.data path")    parser.add_argument("-r", "--rate", default=1, type=int, help="通道数为rate的倍数")    parser.add_argument("-p", "--prob", default=0.5, type=float, help="pruning prob")    parser.add_argument("-t", "--thres", default=0, type=float, help="pruning thres")    opt = parser.parse_args()    cfg = opt.cfg    weights = opt.weights    save_dir = opt.save_dir    device = torch.device("cpu")    model, names = load_model(cfg, weights)    example_inputs = torch.zeros((1, 3, 64, 64), dtype=torch.float32).to()    output_transform = None    # for prob in [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]:    if opt.thres != 0:        thres = opt.thres        prob = "p.auto"    else:        thres = None        prob = opt.prob    pruned_model = channel_prune(model, example_inputs=example_inputs,                                 output_transform=output_transform, pruned_prob=prob, thres=thres,rules=opt.rate)    pruned_model.model[-1].export = False    pruned_model.names = names    save_path = os.path.join(save_dir, "pruned_"+str(prob).split(".")[-1] + ".pt")    print(pruned_model)    torch.save({"model": pruned_model.module if hasattr(pruned_model, "module") else pruned_model}, save_path)

可以按比例剪枝, 如剪枝比例0.5:

python prune.py --cfg models/yolov5s_voc.yaml --weights runs/exp7_sl-2e-3-yolov5s/weights/last.pt --save-dir runs/exp7_sl-2e-3-yolov5s/weights/ --prob 0.5

还可以按权重大小剪枝,比如小于0.01权重的通道剪:

python prune.py --cfg models/yolov5s_voc.yaml --weights runs/exp7_sl-2e-3-yolov5s/weights/last.pt --save-dir runs/exp7_sl-2e-3-yolov5s/weights/ --thres 0.01

往往通道是8的倍数时,神经网络推理较快:

python prune.py --cfg models/yolov5s_voc.yaml --weights runs/exp7_sl-2e-3-yolov5s/weights/last.pt --save-dir runs/exp7_sl-2e-3-yolov5s/weights/ --thres 0.01 --rate 8

执行剪枝后,模型将会变小。

微调

剪枝后,模型性能会下降,此时我们需要再微调剪枝后的模型,其训练过程与剪枝前训练方式一致。一般情况下,可以接近剪枝前的性能。

结语

通过剪枝可以在精度损失较小的情况下,加快模型的推理速度,在我们需要做实时分析的任务中非常有用。

模型下载

轻量级二维码检测模型:模型下载

关键词:

相关新闻

资讯

与你息息相关!这些新规即将施行
与你息息相关!这些新规即将施行

7月起,一批新规将陆续施行涉及出行、医保、快递等与......更多>

市科技局选派干部参加甘肃省“三区”科技人才培训-世界即时看
市科技局选派干部参加甘肃省“三区”科技人才培训-世界即时看

6月26日,记者从市科技局获悉,根据省科技厅的统一部......更多>

今日聚焦!减亏容易增收难:新能源短期功率预测的价值
今日聚焦!减亏容易增收难:新能源短期功率预测的价值

减亏容易增收难:新能源短期功率预测的价值减亏容易增......更多>

电动自行车头盔新国标7月实施:没这些标志上路有危险!-全球聚焦
电动自行车头盔新国标7月实施:没这些标志上路有危险!-全球聚焦

据了解,为降低道路交通事故中驾乘人员头部伤害,国家......更多>

央视新闻报道!唐山LNG储罐项目接收站一阶段工程顺利投产
央视新闻报道!唐山LNG储罐项目接收站一阶段工程顺利投产

6月21日,河北建投唐山LNG项目一阶段工程举行投产仪式......更多>

样本量的计算公式中p是什么意思(样本量的计算方法) 天天快资讯
样本量的计算公式中p是什么意思(样本量的计算方法) 天天快资讯

来为大家解答以上的问题。样本量的计算公式中p是什么......更多>

斩鬼的姬武者
斩鬼的姬武者

扶桑皇国,北渡岛,在这边陲的边陲,一个身穿兽皮大衣......更多>

每日快讯!埃利奥特:文班问我们如何在客场有充足的睡眠 多特别的年轻人啊
每日快讯!埃利奥特:文班问我们如何在客场有充足的睡眠 多特别的年轻人啊

近日,马刺名宿埃利奥特在《DanPatrickShow》节目中谈......更多>

每日播报!TechInsights:预计6.18期间智能手机的销量为1340万部 同比下降7%
每日播报!TechInsights:预计6.18期间智能手机的销量为1340万部 同比下降7%

智通财经APP获悉,TechInsights称,智能手机是6 18大......更多>

观点:美方以涉芬太尼为由起诉中国企业,中方坚决反对,强烈谴责
观点:美方以涉芬太尼为由起诉中国企业,中方坚决反对,强烈谴责

美方采用“钓鱼执法”方式,非法获取所谓“证据”,起......更多>

关注

全球热资讯!梅州蕉岭:开展“6·26”国际禁毒日宣传活动,持续掀起禁毒宣传教育热潮
全球热资讯!梅州蕉岭:开展“6·26”国际禁毒日宣传活动,持续掀起禁毒宣传教育热潮
为广泛动员社会力量参与禁毒斗争,持续掀起禁毒宣传教... 更多>
全球热资讯!梅州蕉岭:开展“6·26”国际禁毒日宣传活动,持续掀起禁毒宣传教育热潮
为广泛动员社会力量参与禁毒斗争,持续掀起禁毒宣传教... 更多>
计算机行业2023年中期投资策略:迎接AI行情从供给迈向应用的拐点
计算机行业2023年中期投资策略:迎接AI行情从供给迈向... 更多>
世界快播:汽车“飞”上天?广汽做到了!
广汽研究院院长吴坚表示,广汽飞行汽车GOVE采用分离式... 更多>
中信股份:金属香港为金属国际分别向三家银行申请的授信额度提供连带责任保证
中信股份(00267)发布公告,为满足金属国际经营资金需... 更多>
资讯:亚马逊云科技顾凡:中国生成式AI大模型,不会一家通吃
亚马逊云表示,中国生成式AI基础模型,不会是一家通吃... 更多>
热消息:端午节期间江西省九江市12315机构共接收投诉举报咨询各类诉求681件
2023年端午节期间(6月22日至6月24日),江西省九江市... 更多>
世界短讯!中国代表在人权理事会就香港国安法问题阐明严正立场
当地时间6月26日,联合国人权理事会第53届会议举行与... 更多>
电磁流量计的工作原理图_电磁流量计的工作原理
1、电磁流量计是根据法拉第电磁感应定律进行流量测量... 更多>
省会研发,县域制造,“借巢孵蛋”加快科研成果转化
“创新突破,产业突围。”这是6月25日中南大学与津市... 更多>
一文看亮点!2023夏季达沃斯27日至29日举办_每日热门
据世界经济论坛消息,2023年世界经济论坛新领军者年会... 更多>
easyMarkets:欧佩克组织认为到2045年全球石油需求将增至1.1亿桶/日|天天最资讯
周一,石油输出国组织(OC)秘书长HaithaAlGhai表示,... 更多>
湖北发布“民企融资十条” 今年将新增民营企业贷款2500亿元|全球热资讯
湖北发布“民企融资十条”今年将新增民营企业贷款2500... 更多>