设为首页收藏本站
网站公告 | 这是第一条公告
     

 找回密码
 立即注册
缓存时间09 现在时间09 缓存数据 我们所有的努力所有的奋斗,都是为了拥有一个美好的未来。和遇见更好的自己。请把努力当成一种习惯,而不是三分钟热度。每一个你羡慕的收获,都是努力用心拼来的。早安!

我们所有的努力所有的奋斗,都是为了拥有一个美好的未来。和遇见更好的自己。请把努力当成一种习惯,而不是三分钟热度。每一个你羡慕的收获,都是努力用心拼来的。早安!

查看: 436|回复: 2

利用Pytorch实现获取特征图的方法详解

[复制链接]

  离线 

TA的专栏

  • 打卡等级:热心大叔
  • 打卡总天数:205
  • 打卡月天数:0
  • 打卡总奖励:3099
  • 最近打卡:2023-08-27 09:30:00
等级头衔

等級:晓枫资讯-上等兵

在线时间
0 小时

积分成就
威望
0
贡献
388
主题
360
精华
0
金钱
4237
积分
773
注册时间
2022-12-25
最后登录
2025-5-28

发表于 2023-2-10 22:04:33 | 显示全部楼层 |阅读模式
简单加载官方预训练模型

torchvision.models预定义了很多公开的模型结构
如果pretrained参数设置为False,那么仅仅设定模型结构;如果设置为True,那么会启动一个下载流程,下载预训练参数
如果只想调用模型,不想训练,那么设置model.eval()和model.requires_grad_(False)
想查看模型参数可以使用modules和named_modules,其中named_modules是一个长度为2的tuple,第一个变量是name,第二个变量是module本身。
  1. # -*- coding: utf-8 -*-
  2. from torch import nn
  3. from torchvision import models

  4. # load model. If pretrained is True, there will be a downloading process
  5. model = models.vgg19(pretrained=True)
  6. model.eval()
  7. model.requires_grad_(False)

  8. # get model component
  9. features = model.features
  10. modules = features.modules()
  11. named_modules = features.named_modules()

  12. # print modules
  13. for module in modules:
  14.     if isinstance(module, nn.Conv2d):
  15.         weight = module.weight
  16.         bias = module.bias
  17.         print(module, weight.shape, bias.shape,
  18.               weight.requires_grad, bias.requires_grad)
  19.     elif isinstance(module, nn.ReLU):
  20.         print(module)

  21. print()
  22. for named_module in named_modules:
  23.     name = named_module[0]
  24.     module = named_module[1]
  25.     if isinstance(module, nn.Conv2d):
  26.         weight = module.weight
  27.         bias = module.bias
  28.         print(name, module, weight.shape, bias.shape,
  29.               weight.requires_grad, bias.requires_grad)
  30.     elif isinstance(module, nn.ReLU):
  31.         print(name, module)
复制代码
图片预处理

使用opencv和pil读图都可以使用transforms.ToTensor()把原本[H, W, 3]的数据转成[3, H, W]的tensor。但opencv要注意把数据改成RGB顺序。
vgg系列模型需要做normalization,建议配合torchvision.transforms来实现。
mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
参考:https://pytorch.org/hub/pytorch_vision_vgg/
  1. # -*- coding: utf-8 -*-
  2. from PIL import Image
  3. import cv2
  4. import torch
  5. from torchvision import transforms

  6. # transforms for preprocess
  7. preprocess = transforms.Compose([
  8.     transforms.ToTensor(),
  9.     transforms.Normalize(mean=[0.485, 0.456, 0.406],
  10.                          std=[0.229, 0.224, 0.225])
  11. ])

  12. # load image using cv2
  13. image_cv2 = cv2.imread('lena_std.bmp')
  14. image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
  15. image_cv2 = preprocess(image_cv2)

  16. # load image using pil
  17. image_pil = Image.open('lena_std.bmp')
  18. image_pil = preprocess(image_pil)

  19. # check whether image_cv2 and image_pil are same
  20. print(torch.all(image_cv2 == image_pil))
  21. print(image_cv2.shape, image_pil.shape)
复制代码
提取单个特征图

如果只提取单层特征图,可以把模型截断,以节省算力和显存消耗。
下面索引之所以有+1是因为pytorch预训练模型里面第一个索引的module总是完整模块结构,第二个才开始子模块。
  1. # -*- coding: utf-8 -*-
  2. from PIL import Image
  3. from torchvision import models
  4. from torchvision import transforms

  5. # load model. If pretrained is True, there will be a downloading process
  6. model = models.vgg19(pretrained=True)
  7. model = model.features[:16 + 1]  # 16 = conv3_4
  8. model.eval()
  9. model.requires_grad_(False)
  10. model.to('cuda')
  11. print(model)

  12. # load and preprocess image
  13. preprocess = transforms.Compose([
  14.     transforms.ToTensor(),
  15.     transforms.Normalize(mean=[0.485, 0.456, 0.406],
  16.                          std=[0.229, 0.224, 0.225]),
  17.     transforms.Resize(size=(224, 224))
  18. ])
  19. image = Image.open('lena_std.bmp')
  20. image = preprocess(image)
  21. inputs = image.unsqueeze(0)  # add batch dimension
  22. inputs = inputs.cuda()

  23. # forward
  24. output = model(inputs)
  25. print(output.shape)
复制代码
提取多个特征图

第一种方式:逐层运行model,如果碰到了需要保存的feature map就存下来。
第二种方式:使用register_forward_hook,使用这种方式需要用一个类把feature map以成员变量的形式缓存下来。
两种方式的运行效率差不多
第一种方式简单直观,但是只能处理类似VGG这种没有跨层连接的网络;第二种方式更加通用。
  1. # -*- coding: utf-8 -*-
  2. from PIL import Image
  3. import torch
  4. from torchvision import models
  5. from torchvision import transforms

  6. # load model. If pretrained is True, there will be a downloading process
  7. model = models.vgg19(pretrained=True)
  8. model = model.features[:16 + 1]  # 16 = conv3_4
  9. model.eval()
  10. model.requires_grad_(False)
  11. model.to('cuda')

  12. # check module name
  13. for named_module in model.named_modules():
  14.     name = named_module[0]
  15.     module = named_module[1]
  16.     print('-------- %s --------' % name)
  17.     print(module)
  18.     print()

  19. # load and preprocess image
  20. preprocess = transforms.Compose([
  21.     transforms.ToTensor(),
  22.     transforms.Normalize(mean=[0.485, 0.456, 0.406],
  23.                          std=[0.229, 0.224, 0.225]),
  24.     transforms.Resize(size=(224, 224))
  25. ])
  26. image = Image.open('lena_std.bmp')
  27. image = preprocess(image)
  28. inputs = image.unsqueeze(0)  # add batch dimension
  29. inputs = inputs.cuda()

  30. # forward - 1
  31. layers = [2, 7, 8, 9, 16]
  32. layers = sorted(set(layers))
  33. feature_maps = {}
  34. feature = inputs
  35. for i in range(max(layers) + 1):
  36.     feature = model[i](feature)
  37.     if i in layers:
  38.         feature_maps[i] = feature
  39. for key in feature_maps:
  40.     print(key, feature_maps.get(key).shape)


  41. # forward - 2
  42. class FeatureHook:
  43.     def __init__(self, module):
  44.         self.inputs = None
  45.         self.output = None
  46.         self.hook = module.register_forward_hook(self.get_features)

  47.     def get_features(self, module, inputs, output):
  48.         self.inputs = inputs
  49.         self.output = output


  50. layer_names = ['2', '7', '8', '9', '16']
  51. hook_modules = []
  52. for named_module in model.named_modules():
  53.     name = named_module[0]
  54.     module = named_module[1]
  55.     if name in layer_names:
  56.         hook_modules.append(module)

  57. hooks = [FeatureHook(module) for module in hook_modules]
  58. output = model(inputs)
  59. features = [hook.output for hook in hooks]
  60. for feature in features:
  61.     print(feature.shape)

  62. # check correctness
  63. for i, layer in enumerate(layers):
  64.     feature1 = feature_maps.get(layer)
  65.     feature2 = features[i]
  66.     print(torch.all(feature1 == feature2))
复制代码
使用第二种方式(register_forward_hook),resnet特征图也可以顺利拿到。
而由于resnet的model已经不可以用model的形式索引,所以无法使用第一种方式。
  1. # -*- coding: utf-8 -*-
  2. from PIL import Image
  3. from torchvision import models
  4. from torchvision import transforms

  5. # load model. If pretrained is True, there will be a downloading process
  6. model = models.resnet18(pretrained=True)
  7. model.eval()
  8. model.requires_grad_(False)
  9. model.to('cuda')

  10. # check module name
  11. for named_module in model.named_modules():
  12.     name = named_module[0]
  13.     module = named_module[1]
  14.     print('-------- %s --------' % name)
  15.     print(module)
  16.     print()

  17. # load and preprocess image
  18. preprocess = transforms.Compose([
  19.     transforms.ToTensor(),
  20.     transforms.Normalize(mean=[0.485, 0.456, 0.406],
  21.                          std=[0.229, 0.224, 0.225]),
  22.     transforms.Resize(size=(224, 224))
  23. ])
  24. image = Image.open('lena_std.bmp')
  25. image = preprocess(image)
  26. inputs = image.unsqueeze(0)  # add batch dimension
  27. inputs = inputs.cuda()


  28. class FeatureHook:
  29.     def __init__(self, module):
  30.         self.inputs = None
  31.         self.output = None
  32.         self.hook = module.register_forward_hook(self.get_features)

  33.     def get_features(self, module, inputs, output):
  34.         self.inputs = inputs
  35.         self.output = output


  36. layer_names = [
  37.     'conv1',
  38.     'layer1.0.relu',
  39.     'layer2.0.conv1'
  40. ]

  41. hook_modules = []
  42. for named_module in model.named_modules():
  43.     name = named_module[0]
  44.     module = named_module[1]
  45.     if name in layer_names:
  46.         hook_modules.append(module)

  47. hooks = [FeatureHook(module) for module in hook_modules]
  48. output = model(inputs)
  49. features = [hook.output for hook in hooks]
  50. for feature in features:
  51.     print(feature.shape)
复制代码
问题来了,resnet这种类型的网络结构怎么截断?
使用如下命令就可以,print查看需要截断到哪里,然后用nn.Sequential重组即可。
需注意重组后网络的module_name会发生变化。
  1. print(list(model.children())
  2. model = torch.nn.Sequential(*list(model.children())[:6])
复制代码
以上就是利用Pytorch实现获取特征图的方法详解的详细内容,更多关于Pytorch获取特征图的资料请关注晓枫资讯其它相关文章!

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
晓枫资讯-科技资讯社区-免责声明
免责声明:以上内容为本网站转自其它媒体,相关信息仅为传递更多信息之目的,不代表本网观点,亦不代表本网站赞同其观点或证实其内容的真实性。
      1、注册用户在本社区发表、转载的任何作品仅代表其个人观点,不代表本社区认同其观点。
      2、管理员及版主有权在不事先通知或不经作者准许的情况下删除其在本社区所发表的文章。
      3、本社区的文章部分内容可能来源于网络,仅供大家学习与参考,如有侵权,举报反馈:点击这里给我发消息进行删除处理。
      4、本社区一切资源不代表本站立场,并不代表本站赞同其观点和对其真实性负责。
      5、以上声明内容的最终解释权归《晓枫资讯-科技资讯社区》所有。
http://bbs.yzwlo.com 晓枫资讯--游戏IT新闻资讯~~~

  离线 

TA的专栏

等级头衔

等級:晓枫资讯-列兵

在线时间
0 小时

积分成就
威望
0
贡献
0
主题
0
精华
0
金钱
16
积分
12
注册时间
2022-12-29
最后登录
2022-12-29

发表于 2023-2-10 23:04:54 | 显示全部楼层
感谢楼主分享~~~~~
http://bbs.yzwlo.com 晓枫资讯--游戏IT新闻资讯~~~

  离线 

TA的专栏

  • 打卡等级:即来则安
  • 打卡总天数:25
  • 打卡月天数:0
  • 打卡总奖励:289
  • 最近打卡:2025-03-31 22:20:37
等级头衔

等級:晓枫资讯-列兵

在线时间
0 小时

积分成就
威望
0
贡献
0
主题
0
精华
0
金钱
337
积分
56
注册时间
2023-1-7
最后登录
2025-3-31

发表于 3 天前 | 显示全部楼层
感谢楼主,顶。
http://bbs.yzwlo.com 晓枫资讯--游戏IT新闻资讯~~~
严禁发布广告,淫秽、色情、赌博、暴力、凶杀、恐怖、间谍及其他违反国家法律法规的内容。!晓枫资讯-社区
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

1楼
2楼
3楼

手机版|晓枫资讯--科技资讯社区 本站已运行

CopyRight © 2022-2025 晓枫资讯--科技资讯社区 ( BBS.yzwlo.com ) . All Rights Reserved .

晓枫资讯--科技资讯社区

本站内容由用户自主分享和转载自互联网,转载目的在于传递更多信息,并不代表本网赞同其观点和对其真实性负责。

如有侵权、违反国家法律政策行为,请联系我们,我们会第一时间及时清除和处理! 举报反馈邮箱:点击这里给我发消息

Powered by Discuz! X3.5

快速回复 返回顶部 返回列表