动手学深度学习(Pytorch版)代码实践 -卷积神经网络-29残差网络ResNet

29残差网络ResNet

在这里插入图片描述

import torch  
from torch import nn  
from torch.nn import functional as F 
import liliPytorch as lp  
import matplotlib.pyplot as plt

# 定义一个继承自nn.Module的残差块类
class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        # 第一个卷积层,使用3x3的卷积核,填充为1,步幅为指定值
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
        # 第二个卷积层,使用3x3的卷积核,填充为1
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        
        # 可选的1x1卷积层,用于匹配输入输出通道数和步幅
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        
        # 批量归一化层
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        # 为什么需要两个不同的批量归一化层?
        # 1.不同的位置,不同的输入特征
        # 2.独立的参数和统计数据
    
    def forward(self, X):
        # 先通过第一个卷积层、批量归一化层和ReLU激活函数
        Y = F.relu(self.bn1(self.conv1(X)))
        # 然后通过第二个卷积层和批量归一化层
        Y = self.bn2(self.conv2(Y))
        # 如果定义了conv3,则通过conv3调整X
        if self.conv3:
            X = self.conv3(X)
        # 将输入X加到输出Y上实现残差连接
        Y += X
        # 通过ReLU激活函数
        return F.relu(Y)

# 创建一个包含输入和输出形状一致的残差块实例,并测试其输出形状
# blk = Residual(3, 3)
# X = torch.rand(4, 3, 6, 6)
# Y = blk(X)
# print(Y.shape)  # 预期输出形状:torch.Size([4, 3, 6, 6])

# 创建一个包含1x1卷积和步幅为2的残差块实例,并测试其输出形状
# blk = Residual(3, 6, use_1x1conv=True, strides=2)
# print(blk(X).shape)  # 预期输出形状:torch.Size([4, 6, 3, 3])

# 定义一个包含初始卷积层、批量归一化层、ReLU激活函数和最大池化层的顺序容器
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)

# 定义一个函数,用于创建由多个残差块组成的模块
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    blk = []
    for i in range(num_residuals):
        # 如果是第一个残差块且不是第一个模块,则使用1x1卷积和步幅为2
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk

# 创建由残差块组成的各个模块
# *符号有多种用途,但在函数调用时,*符号主要用于将列表或元组解包。
# *resnet_block()的作用是将列表中的元素逐个传递给nn.Sequential
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

# 创建整个ResNet模型
net = nn.Sequential(
    b1, 
    b2, 
    b3, 
    b4, 
    b5,
    nn.AdaptiveAvgPool2d((1, 1)),  # 自适应平均池化层
    nn.Flatten(),  # 展平层
    nn.Linear(512, 176)  # 全连接层,输出10类
)

# 测试整个网络的输出形状
X = torch.rand(size=(1, 1, 96, 96))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 128, 12, 12])
# Sequential output shape:         torch.Size([1, 256, 6, 6])
# Sequential output shape:         torch.Size([1, 512, 3, 3])
# AdaptiveAvgPool2d output shape:  torch.Size([1, 512, 1, 1])
# Flatten output shape:    torch.Size([1, 512])
# Linear output shape:     torch.Size([1, 10])

# 设置训练参数
lr, num_epochs, batch_size = 0.05, 10, 256
# 加载训练和测试数据
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size, resize=96)
# 训练模型
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
# 显示训练结果
plt.show()

# loss 0.009, train acc 0.998, test acc 0.920
# 2306.3 examples/sec on cuda:0

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/755684.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI副业赚钱攻略:掌握数字时代的机会

前言 最近国产大模型纷纷上线,飞入寻常百姓家。AI副业正成为许多人寻找额外收入的途径。无论您是想提高家庭收入还是寻求职业发展,这里有一个变现,帮助您掌握AI兼职副业的机会。 1. 了解AI的基础知识 在开始之前,了解AI的基础…

【笔记】Spring Cloud Gateway 实现 gRPC 代理

Spring Cloud Gateway 在 3.1.x 版本中增加了针对 gRPC 的网关代理功能支持,本片文章描述一下如何实现相关支持.本文主要基于 Spring Cloud Gateway 的 官方文档 进行一个实践练习。有兴趣的可以翻看官方文档。 由于 Grpc 是基于 HTTP2 协议进行传输的,因此 Srping …

zabbix监控进阶:如何分时段设置不同告警阈值(多阈值告警)

作者 乐维社区(forum.lwops.cn)乐乐 在生产环境中,企业的业务系统状态并不是一成不变的。在业务高峰时段,如节假日、促销活动或特定时间段,系统负载和用户访问量会大幅增加,此时可能需要设置更高的告警阈值…

vscode 使用正则将/deep/ 替换成 :deep()

在VSCODE编辑器的SEARCH中按上图书写即可,正则表达式如下:(\/deep\/)(.*?)(?\{) 替换操作如下::deep($2) 如果有用,号隔开的用:(\/deep\/)(.*?)(?,)替换操作如下::deep($2) 即可实现快速替换所有/deep/写法; 同理…

Cyuyanzhong的内存函数

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、memcpy函数的使用与模拟实现二、memmove函数的使用和模拟实现三、memset函数与memcmp函数的使用(一)、memset函数(内存块…

一文速览Google的Gemma:从gemma1到gemma2

前言 如此文《七月论文审稿GPT第3.2版和第3.5版:通过paper-review数据集分别微调Mistral、gemma》所讲 Google作为曾经的AI老大,我司自然紧密关注,所以当Google总算开源了一个gemma 7b,作为有技术追求、技术信仰的我司&#xff0…

大模型ReAct:思考与工具协同完成复杂任务推理

ReAct: Synergizing Reasoning and Acting in Language Models Github:https://github.com/ysymyth/ReAct 一、动机 人类的认知通常具备一定的自我调节(self-regulation)和策略制定(strategization)的能力&#xff0…

福昕阅读器再打开PDF文件时,总是单页显示,如何设置打开后就自动显示单页连续的模式呢

希望默认进入连续模式 设置方法 参考链接 如何设置使福昕阅读器每次启动时不是阅读模式 每次启动后都要退出阅读模式 麻烦_百度知道 (baidu.com)https://zhidao.baidu.com/question/346796551.html#:~:text%E5%9C%A8%E3%80%90%E5%B7%A5%E5%85%B7%E3%80%91%E9%87%8C%E6%9C%89%E…

Springboot下使用Redis管道(pipeline)进行批量操作

之前有业务场景需要批量插入数据到Redis中,做的过程中也有一些感悟,因此记录下来,以防忘记。下面的内容会涉及到 分别使用for、管道处理批量操作,比较其所花费时间。 分别使用RedisCallback、SessionCallback进行Redis pipeline …

从零开始学Spring Boot系列-集成Spring Security实现用户认证与授权

在Web应用程序中,安全性是一个至关重要的方面。Spring Security是Spring框架的一个子项目,用于提供安全访问控制的功能。通过集成Spring Security,我们可以轻松实现用户认证、授权、加密、会话管理等安全功能。本篇文章将指导大家从零开始&am…

昇思25天学习打卡营第11天|基于MindSpore通过GPT实现情感分类

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com) 基于MindSpore通过GPT实现情感分类 %%capture captured_output # 实验环境已经预装了mindspore2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号 !pip uninsta…

Mysql常用SQL:日期转换成周_DAYOFWEEK(date)

有时候需要将查询出来的日期转换成周几,Mysql本身语法就是支持这种转换的,就是DAYOFWEEK()函数 语法格式:DAYOFWEEK(date) (date:可以是指定的具体日期( 如2024-06-29 ),也可以是日期…

一个项目学习IOS开发---创建一个IOS开发项目

前提: 由于IOS开发只能在MacOS上开发,所以黑苹果或者购买一台MacBook Pro是每个IOS开发者必备的技能或者工具之一 Swift开发工具一般使用MacOS提供的Xcode开发工具 首先Mac Store下载Xcode工具 安装之后打开会提醒你安装IOS的SDK,安装好之…

媒体宣发套餐的概述及推广方法-华媒舍

在今天的数字化时代,对于产品和服务的宣传已经变得不可或缺。媒体宣发套餐作为一种高效的宣传方式,在帮助企业塑造品牌形象、扩大影响力方面扮演着重要角色。本文将揭秘媒体宣发套餐,为您呈现一条通往成功的路。 1. 媒体宣发套餐的概述 媒体…

使用Tailwindcss之后,vxe-table表头排序箭头高亮消失的问题解决

环境 vue2.7.8 vxe-table3.5.9 tailwindcss/postcss7-compat2.2.17 postcss7.0.39 autoprefixer9.8.8 问题 vxe-table 表格表头 th 的排序箭头在开启正序或逆序排序时,会显示蓝色高亮来提示用户表格数据处在排序情况下。在项目开启运行了tailwindcss之后&#xff0…

Kafka入门-基础概念及参数

一、Kafka术语 Kafka属于分布式的消息引擎系统,它的主要功能是提供一套完备的消息发布与订阅解决方案。可以为每个业务、每个应用甚至是每类数据都创建专属的主题。 Kafka的服务器端由被称为Broker的服务进程构成,即一个Kafka集群由多个Broker组成&#…

dledger原理源码分析系列(二)-心跳

简介 dledger是openmessaging的一个组件, raft算法实现,用于分布式日志,本系列分析dledger如何实现raft概念,以及dledger在rocketmq的应用 本系列使用dledger v0.40 本文分析dledger的心跳 关键词 Raft Openmessaging 心跳/…

Android14之RRO资源文件替换策略(二百二十一)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒体系统工程师系列【原创干货持续更新中……】🚀 优质视频课程:AAOS车载系统+AOSP…

.NET 一款利用内核驱动关闭AV/EDR的工具

01阅读须知 此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等(包括但不限于)进行检测或维护参考,未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失&#xf…

微服务 | Springboot整合GateWay+Nacos实现动态路由

1、简介 路由转发 执行过滤器链。 ​ 网关,旨在为微服务架构提供一种简单有效的统一的API路由管理方式。同时,基于Filter链的方式提供了网关的基本功能,比如:鉴权、流量控制、熔断、路径重写、黑白名单、日志监控等。 基本功能…