
离线 TA的专栏
- 打卡等级:热心大叔
- 打卡总天数:224
- 打卡月天数:0
- 打卡总奖励:3517
- 最近打卡:2025-06-30 06:53:44
|
加载模型并查看网络
加载模型,以vgg19为例。
打开终端
- > python
- Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
- (AMD64)] on win32
- Type "help", "copyright", "credits" or "license" for more information.
- >>> from torchvision import models
- >>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
- >>> print(model.model)
复制代码结果: - VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace) (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (17): ReLU(inplace) (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace) (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (24): ReLU(inplace) (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (26): ReLU(inplace) (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace) (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (31): ReLU(inplace) (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (33): ReLU(inplace) (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (35): ReLU(inplace) (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (classifier): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace) (2): Dropout(p=0.5) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace) (5): Dropout(p=0.5) (6): Linear(in_features=4096, out_features=1000, bias=True) ))
复制代码注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。
神经网络_模型的保存,模型的加载
模型的保存(torch.save)
方式1(模型结构+模型参数)
参数:保存位置 - # 创建模型
- vgg16 = torchvision.models.vgg16(pretrained=False)
- # 保存方式1——模型结构+模型参数
- torch.save(vgg16, "vgg16_method1.pth")
复制代码 方式2(模型参数)- # 保存方式2 模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
- torch.save(vgg16.state_dict(), "vgg16_method2.pth")
复制代码 模型的加载(torch.load)
对应保存方式1
参数:模型路径 - # 方式1 --》 保存方式1
- model1 = torch.load("vgg16_method1.pth")
复制代码 对应保存方式2- vgg16.load_state_dict("vgg16_method2.pth")
复制代码输出为字典形式。若要回复网络,采用以下形式: - model2 = torch.load("vgg16_method2.pth") #输出是字典形式
- # 恢复网络结构
- vgg16 = torchvision.models.vgg16(pretrained=False)
- vgg16.load_state_dict(model2)
复制代码 方式1存储,加载时需注意事项
新建自己的网络: - class test(nn.Module):
- def __init__(self):
- super(lh, self).__init__()
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
- def forward(self, x):
- x = self.conv1(x)
- return x
复制代码保存自己的网络: - Test = test()
- # 保存自己定义的网络
- torch.save(Test, "Test_method1.pth")
复制代码加载自己的网络: - model3 = torch.load("Test_method1.pth")
复制代码会报错!!!!!!
解决办法(需要注意):
将定义的网络复制到加载的python文件中: - class test(nn.Module):
- def __init__(self):
- super(test, self).__init__()
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
- def forward(self, x):
- x = self.conv1(x)
- return x
- model3 = torch.load("Test_method1.pth")
复制代码以上为个人经验,希望能给大家一个参考,也希望大家多多支持晓枫资讯。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |
晓枫资讯-科技资讯社区-免责声明
免责声明:以上内容为本网站转自其它媒体,相关信息仅为传递更多信息之目的,不代表本网观点,亦不代表本网站赞同其观点或证实其内容的真实性。
1、注册用户在本社区发表、转载的任何作品仅代表其个人观点,不代表本社区认同其观点。
2、管理员及版主有权在不事先通知或不经作者准许的情况下删除其在本社区所发表的文章。
3、本社区的文章部分内容可能来源于网络,仅供大家学习与参考,如有侵权,举报反馈:  进行删除处理。
4、本社区一切资源不代表本站立场,并不代表本站赞同其观点和对其真实性负责。
5、以上声明内容的最终解释权归《晓枫资讯-科技资讯社区》所有。
|