【深度学习】论文笔记:空间变换网络(Spatial Transformer Networks)

在这里插入图片描述

  • 博主简介:努力学习的22级计算机科学与技术本科生一枚🌸
  • 博主主页: @Yaoyao2024
  • 往期回顾: 【机器学习】有监督学习·由浅入深讲解分类算法·Fisher算法讲解
  • 每日一言🌼: 今天不想跑,所以才去跑,这才是长距离者的思维。
    ——村上春树

本文是对Google DeepMind 团队2015年发表的空间变换网络STN的详细讲解,作为初学者也是参考了很多博客,都在本文末尾给出,感谢前辈们的努力。

空间变换网络(Spatial Transformer Networks,简称STN)是一种深度学习模型,旨在增强网络对几何变换的适应能力。STN是由Max Jaderberg等人在2015年提出的,其核心思想是在传统的卷积神经网络(CNN)中嵌入一个可学习的模块,该模块能够显式地对输入图像进行空间变换,从而使得网络能够对输入图像的几何变形具有更好的适应性。STN的引入使得网络能够自动进行图像的校正,例如旋转、缩放、剪切等,这在很多视觉任务中是非常有用的,如图像识别、目标检测和图像分割等。

一、为什么提出(Why)

  1. 一个理想中的模型:我们希望鲁棒的图像处理模型具有空间不变性,当目标发生某种转化后,模型依然能给出同样的正确的结果

  2. 什么是空间不变性:举例来说,如下图所示,假设一个模型能准确把左图中的人物分类为凉宫春日,当这个目标做了放大、旋转、平移后,模型仍然能够正确分类,我们就说这个模型在这个任务上具有尺度不变性,旋转不变性,平移不变性

    在这里插入图片描述

  3. CNN在这方面的能力是不足的:maxpooling的机制给了CNN一点点这样的能力,当目标在池化单元内任意变换的话,激活的值可能是相同的,这就带来了一点点的不变性。但是池化单元一般都很小(一般是2*2),只有在深层的时候特征被处理成很小的feature map的时候这种情况才会发生

  4. Spatial Transformer:本文提出的空间变换网络STN(Spatial Transformer Networks)STN可以使模型学习平移、缩放、旋转和更通用的扭曲的不变性。(二维空间变换网络)

二、STN是什么(What)

  1. STN对feature map(包括输入图像)进行空间变换,输出一张新的图像。
  2. 我们希望STN对feature map进行变换后能把图像纠正到成理想的图像,然后丢进NN去识别,举例来说,如下图所示,输入模型的图像可能是摆着各种姿势,摆在不同位置的凉宫春日,我们希望STN把它纠正到图像的正中央,放大,占满整个屏幕,然后再丢进CNN去识别。
  3. 这个网络可以作为单独的模块,可以在CNN的任何地方插入(即插即用),所以STN的输入不止是输入图像,可以是CNN中间层的feature map
    在这里插入图片描述

三、STN是怎么做的(How)

STN可以通过为每个输入样本生成适当的变换来主动对图像(或特征图)进行空间变换。然后在整个特征图上(非局部)执行变换,并且可以包括缩放、裁剪、旋转以及非刚性变形。这使得包含空间变换器的网络不仅可以选择图像中最相关(注意力)的区域,还可以将这些区域转换为规范的预期姿势,以简化后续层中的推理。

在这里插入图片描述

如上图所示,STN的输入为 U U U,输出为 V V V,因为输入可能是中间层的feature map,所以画成了立方体(多channel),STN主要分为下述三个步骤

  1. 定位网络(Localization Network):这一部分是STN的核心,其任务是学习输入图像的空间变换参数。定位网络可以是任意的网络结构,它接受输入图像,并输出空间变换所需的参数。这些参数定义了一个变换矩阵,用于调整图像的空间位置。(是一个自己定义的网络,它输入 U U U,输出变化参数 Θ \Theta Θ,这个参数用来映射 U U U V V V的坐标关系)。
  2. 网格生成器(Grid Generator):接收定位网络输出的变换参数,并生成一个对应于输出图像的坐标网格。这个坐标网格对应于输入图像中的每一个像素位置。根据 V V V中的坐标点和变化参数 Θ \Theta Θ,计算出 U U U中的坐标点。这里是因为 V V V的大小是自己先定义好的,当然可以得到 V V V的所有坐标点,而填充 V V V 中每个坐标点的像素值的时候,要从 U U U中去取,所以根据 V V V中每个坐标点和变化参数 Θ \Theta Θ进行运算,得到一个坐标。在sampler中就是根据这个坐标去 U U U中找到像素值,这样子来填充 V V V
  3. Sampler:要做的是填充 V V V,根据Grid generator得到的一系列坐标和原图 U U U(因为像素值要从 U U U中取)来填充,因为计算出来的坐标可能为小数,要用另外的方法来填充,比如双线性插值。从输入图像中采样像素来产生变换后的输出图像。这一步骤确保了图像的空间变换是可微分的,从而可以通过反向传播算法进行训练。

下面针对每个模块阐述一下

1、Localisation net

这个模块就是输入 U U U,输出一个变换参数 Θ \Theta Θ,那么这个 Θ \Theta Θ具体是指什么呢?

我们知道线性代数里,图像的平移,旋转和缩放都可以用矩阵运算来做

举例来说,如果想放大图像中的目标,可以这么运算,把(x,y)中的像素值填充到(x’,y’)上去,比如把原来(2,2)上的像素点,填充到(4,4)上去。
[ x ′ y ′ ] = [ 2 0 0 2 ] [ x y ] + [ 0 0 ] \begin{bmatrix}x^{'}\\y^{'}\end{bmatrix}=\begin{bmatrix}2&0\\0&2\end{bmatrix}\begin{bmatrix}x\\y\end{bmatrix}+\begin{bmatrix}0\\0\end{bmatrix} [xy]=[2002][xy]+[00]

如果想旋转图像中的目标,可以这么运算(可以在极坐标系中推出来,证明放到最后的附录)
[ x ′ y ′ ] = [ c o s Θ − s i n Θ s i n Θ c o s Θ ] [ x y ] + [ 0 0 ] \begin{bmatrix}x^{'}\\y^{'}\end{bmatrix}=\begin{bmatrix}cos\Theta&-sin\Theta\\sin\Theta&cos\Theta\end{bmatrix}\begin{bmatrix}x\\y\end{bmatrix}+\begin{bmatrix}0\\0\end{bmatrix} [xy]=[cosΘsinΘsinΘcosΘ][xy]+[00]

这些都是属于仿射变换(affine transformation)

[ x ′ y ′ ] = [ a b c d ] [ x y ] + [ e f ] \begin{bmatrix}x^{^{\prime}}\\y^{^{\prime}}\end{bmatrix}=\begin{bmatrix}a&b\\c&d\end{bmatrix}\begin{bmatrix}x\\y\end{bmatrix}+\begin{bmatrix}e\\f\end{bmatrix} [xy]=[acbd][xy]+[ef]

在仿射变化中,变化参数就是这6个变量, Θ = { a , b , c , d , e , f } (此 Θ 跟上述旋转变化里的角度 Θ 无关) \Theta=\{a,b,c,d,e,f\}\text{(此}\Theta\text{跟上述旋转变化里的角度}\Theta\text{无关)} Θ={a,b,c,d,e,f}(Θ跟上述旋转变化里的角度Θ无关)

这6个变量就是用来映射输入图和输出图之间的坐标点的关系的,我们在第二步grid generator就要根据这个变化参数,来获取原图的坐标点

总结如下:

  1. 功能:定位网络的主要任务是预测空间变换的参数。根据输入图像,这个网络会输出一组参数,这些参数定义了一个空间变换,可以是平移、旋转、缩放等或者更复杂的仿射变换或者非线性变换。
  2. 结构:定位网络通常是一个小型的卷积神经网络全连接网络,其具体结构可以根据任务的复杂度和输入数据的特性来定制。网络的输出大小是固定的,对应于特定变换所需的参数数量。

2、Grid generator

有了第一步的变化参数,这一步是做个矩阵运算,这个运算是 以目标图 V V V的所有坐标点为自变量,以为参数做一个矩阵运算,得到输入图 U U U的坐标点

( x i s y i s ) = Θ ( x i t y i t 1 ) = [ Θ 11 Θ 12 Θ 13 Θ 21 Θ 22 Θ 23 ] ( x i t y i t 1 ) \begin{pmatrix}x_i^s\\y_i^s\end{pmatrix}=\Theta\begin{pmatrix}x_i^t\\y_i^t\\1\end{pmatrix}=\begin{bmatrix}\Theta_{11}&\Theta_{12}&\Theta_{13}\\\Theta_{21}&\Theta_{22}&\Theta_{23}\end{bmatrix}\begin{pmatrix}x_i^t\\y_i^t\\1\end{pmatrix} (xisyis)=Θ xityit1 =[Θ11Θ21Θ12Θ22Θ13Θ23] xityit1

其中 ( x i t , y i ) 记为输出图 V 中的第 i 个坐标点, V 中的长宽可以和 U 不一样,自己定义的,所以这里用 i 来标识第几个坐标点 ( x i s , y i ) {(x_{i}{t},y_{i})} 记 为 输 出 图 V 中 的 第 i 个 坐 标 点 , V 中 的 长 宽 可 以 和 U 不 一 样 , 自 己 定 义 的 , 所 以 这 里 用 i 来 标 识 第 几 个 坐 标 点 {(x_{i}{s},y_{i})} (xit,yi)记为输出图V中的第i个坐标点,V中的长宽可以和U不一样,自己定义的,所以这里用i来标识第几个坐标点(xis,yi)

  • 功能:网格生成器接收定位网络预测的变换参数,并生成一个坐标网格,该网格代表了输入图像中每个像素映射到输出图像中的新位置
  • 原理:对于每个输出图像的像素位置,网格生成器使用变换参数来计算对应的输入图像中的坐标。这一过程通常涉及到矩阵运算,用于实现平移、旋转、缩放等仿射变换。
    在这里插入图片描述

3、Sampler

由于在第二步计算出了V中每个点对应到U的坐标点,在这一步就可以直接根据V的坐标点取得对应到U中坐标点的像素值来进行填充,而不需要经过矩阵运算。需要注意的是,填充并不是直接填充,首先计算出来的坐标可能是小数,要处理一下,其次填充的时候往往要考虑周围的其它像素值。填充根据的公式如下。

V i = ∑ n ∑ m U n m ∗ k ( x i s − m ; ϕ x ) ∗ k ( y i s − n ; ϕ y ) V_i=\sum_n\sum_mU_{nm}*k(x_i^s-m;\phi_x)*k(y_i^s-n;\phi_y) Vi=nmUnmk(xism;ϕx)k(yisn;ϕy)

举例来说,我要填充目标图V中的(2,2)这个点的像素值,经过以下计算得到(1.6,2.4)

( x i s y i s ) = [ Θ 11 Θ 12 Θ 13 Θ 21 Θ 22 Θ 23 ] ( x i t y i t 1 ) ( 1.6 2.4 ) = [ 0 0.5 0.6 1 0 0.4 ] ( 2 2 1 ) \begin{gathered}\begin{pmatrix}x_i^s\\y_i^s\end{pmatrix}=\begin{bmatrix}\Theta_{11}&\Theta_{12}&\Theta_{13}\\\Theta_{21}&\Theta_{22}&\Theta_{23}\end{bmatrix}\begin{pmatrix}x_i^t\\y_i^t\\1\end{pmatrix}\\\begin{pmatrix}1.6\\2.4\end{pmatrix}=\begin{bmatrix}0&0.5&0.6\\1&0&0.4\end{bmatrix}\begin{pmatrix}2\\2\\1\end{pmatrix}\end{gathered} (xisyis)=[Θ11Θ21Θ12Θ22Θ13Θ23] xityit1 (1.62.4)=[010.500.60.4] 221

如果四舍五入后直接填充,则难以做梯度下降。

我们知道做梯度下降时,梯度的表现就是权重发生一点点变化的时候,输出的变化会如何。

如果用四舍五入后直接填充,那么(1.6,2.4)四舍五入后变成(2,2)当 Θ \Theta Θ(我们求导的时候是需要对 Θ \Theta Θ求导的)有一点点变化的时候,(1.6,2.4)可能变成了(1.9,2.1)四舍五入后还是变成(2,2),输出并没有变化,对 Θ \Theta Θ的梯度没有改变,这个时候没法用梯度下降来优化 Θ \Theta Θ

如果采用上面双线性插值的公式来填充,在这个例子里就会考虑(2,2)周围的四个点来填充,这样子,当 Θ \Theta Θ有一点点变化的时,式子的输出就会有变化,因为 ( x i s , y i ) (x_{i}{s},y_{i}) (xis,yi)的变化会引起V的变化。注意下式中U的下标,第一个下标是纵坐标,第二个下标才是横坐标。

V = U 21 ( 1 − 0.6 ) ( 1 − 0.4 ) + U 22 ( 1 − 0.4 ) ( 1 − 0.4 ) + U 31 ( 1 − 0.6 ) ( 1 − 0.6 ) + U 32 ( 1 − 0.4 ) ( 1 − 0.6 ) V=U_{21}(1-0.6)(1-0.4)+U_{22}(1-0.4)(1-0.4)+U_{31}(1-0.6)(1-0.6)+U_{32}(1-0.4)(1-0.6) V=U21(10.6)(10.4)+U22(10.4)(10.4)+U31(10.6)(10.6)+U32(10.4)(10.6)

4、STN小结

简单总结一下,如下图所示
在这里插入图片描述

  1. Localization net根据输入图,计算得到一个Θ
  2. Grid generator根据输出图的坐标点和Θ,计算出输入图的坐标点,举例来说想知道输出图上(2,2)应该填充什么坐标点,则跟Θ 运算,得到(1.6,2.4)
  3. Sampler根据自己定义的填充规则(一般用双线性插值)来填充,比如(2,2)坐标对应到输入图上的坐标为(1.6,2.4),那么就要根据输入图上(1.6,2.4)周围的四个坐标点(1,2),(1,3),(2,2),(2,3)的像素值来填充。

四、STN模块的pytorch实现

这里我们假设Mnist数据集作为网络输入:

(1)首先定义Localisation net特征提取部分,为两个Conv层后接Maxpool和Relu操作:

在这里插入图片描述
(2)定义Localisation net的变换参数θ回归部分,为两层全连接层内接Relu:
在这里插入图片描述
(3)在nn.module的继承类中定义完整的STN模块操作:

在这里插入图片描述

五、空间变换网络的实际应用

在这里插入图片描述

1、STN作为网络的第一层

在这里插入图片描述

2、STN插入CNN 的中间层

在这里插入图片描述

六、评价

思想非常巧妙,因为卷积神经网络中的池化层(pooling layer)直接用一些max pooling 或者average pooling 的方法,将图片信息压缩,减少运算量提升准确率。

作者认为之前pooling的方法太过于暴力,直接将信息合并会导致关键信息无法识别出来,所以提出了一个叫空间转换器(spatial transformer)的模块,将图片中的的空间域信息做对应的空间变换,从而能将关键的信息提取出来。

Unlike pooling layers, where the receptive fields are fixed and local, the spatial transformer module is a dynamic mechanism that can actively spatially transform an image (or a feature map) by producing an appropriate transformation for each input sample.

在这里插入图片描述
空间转换器模型直观的实验图:

(a)列是原始的图片信息,其中第一个手写数字7没有做任何变换,第二个手写数字5,做了一定的旋转变化,而第三个手写数字6,加上了一些噪声信号;这些变化都是随机的
(b)列中的彩色边框是学习到的spatial transformer的框盒(bounding box),每一个框盒其实就是对应图片学习出来的一个spatial transformer

🪧©列中是通过spatial transformer转换之后的特征图,可以看出7的关键区域被选择出来,5被旋转成为了正向的图片,6的噪声信息没有被识别进入。

(d)列最终可以通过这些转换后的特征图来预测出中手写数字的数值。

🌱spatial transformer其实就是注意力机制的实现,因为训练出的spatial transformer能够找出图片信息中需要被关注的区域,同时这个transformer又能够具有旋转、缩放变换的功能,这样图片局部的重要信息能够通过变换而被框盒提取出来。🌱

参考

  • 原文链接:https://www.cnblogs.com/liaohuiqiang/p/9226335.html
  • https://blog.csdn.net/qq_43700729/article/details/136601998
  • 李弘毅讲 STN 网络:https://www.youtube.com/watch?v=SoCywZ1hZak
  • 知乎:https://zhuanlan.zhihu.com/p/41738716
  • https://blog.csdn.net/Rosemary_tu/article/details/84069878
  • https://ddelephant.blog.csdn.net/article/details/111303416?fromshare=blogdetail&sharetype=blogdetail&sharerId=111303416&sharerefer=PC&sharesource=Yaoyao2024&sharefrom=from_link

本人能力有限,上述内容如有理解不当的地方,欢迎与我讨论!


http://www.niftyadmin.cn/n/5744188.html

相关文章

RK3568 关于python依赖Miniconda3虚拟环境自启动

有关如何安装Miniconda3可以查看博客:RK3568 安装Miniconda3_miniconda3 aarch64 linux-CSDN博客 然后目前有个需求是需要开机自启动python脚本,但是需要依赖于虚拟环境,也就是说一起来就要打开虚拟环境并运行python脚本,一旦没有虚拟环境,python脚本就无法运行 解决办法…

Go语言的常用内置函数

文章目录 一、Strings包字符串处理包定义Strings包的基本用法Strconv包中常用函数 二、Time包三、Math包math包概述使用math包 四、随机数包(rand) 一、Strings包 字符串处理包定义 Strings包简介: 一般编程语言包含的字符串处理库功能区别…

使用Python简单实现客户端界面

服务端实现 import threading import timeimport wx from socket import socket, AF_INET, SOCK_STREAMclass LServer(wx.Frame):def __init__(self):wx.Frame.__init__(self, None, id1002, titleL服务器端界面, poswx.DefaultPosition, size(400, 450))# 窗口中添加面板pl …

mysql常见的一些配置项

MySQL 有许多配置选项,可以用来调整其行为以满足特定的需求。以下是一些常见的配置选项,除了大小写敏感之外,这些配置选项也经常被调整: 1. 字符集和排序规则 character_set_server: 设置服务器的默认字符集。collation_server:…

基于STM32的贪吃蛇游戏教学

引言 贪吃蛇是一款经典的电脑和手机游戏,它的简单性和趣味性使其成为很多人童年记忆的一部分。在本教程中,我们将创建一个基于STM32的贪吃蛇游戏项目。本项目将使用一个OLED显示屏来展示游戏画面,并使用按键来控制蛇的移动。通过本教程&#…

智能化健身房管理:Spring Boot与Vue的创新解决方案

作者介绍:✌️大厂全栈码农|毕设实战开发,专注于大学生项目实战开发、讲解和毕业答疑辅导。 🍅获取源码联系方式请查看文末🍅 推荐订阅精彩专栏 👇🏻 避免错过下次更新 Springboot项目精选实战案例 更多项目…

笔记--(网络3)、交换机、VLAN

交换机 交换机(Switch)意为“开关”是一种用于电(光)信号转发的网络设备。它可以为接入交换机的任意两个网络节点提供独享的电信号通路。最常见的交换机是以太网交换机。其他常见的还有电话语音交换机、光纤交换机等。 交换机的…

Docker使用相关记录

文章目录 查看本地镜像查看本地容器进入某个容器内部查看 Docker 运行中的日志查看本地镜像 docker images查看本地容器 # 查看本地运行的容器 docker ps # 查看本地所有的容器 docker ps -a进入某个容器内部 docker exec -it <name或id> /bin/bash退出bash, 输入exit…