【Block总结】门控轴向注意力Gated Axial-Attention|即插即用

论文信息

标题: Medical Transformer: Gated Axial-Attention for Medical Image Segmentation
论文链接: https://arxiv.org/pdf/2102.10662
GitHub链接: https://github.com/jeya-maria-jose/Medical-Transformer
在这里插入图片描述

创新点

  1. 门控轴向注意力机制: 该机制通过引入可学习的门控参数,增强了模型对长距离依赖关系的捕捉能力,克服了传统自注意力模型的不足。

  2. 局部-全局(LoGo)训练策略: 结合局部特征和全局上下文信息的学习,提升了模型在小样本数据集上的表现。

  3. 适应性强: MedT在多个医学图像分割任务中表现优异,尤其是在样本数量较少的情况下,显示出较强的适应性。
    在这里插入图片描述

方法

  1. 模型架构: MedT结合了Transformer的优势与医学图像分割的需求,采用了门控轴向注意力机制来处理输入特征图。

  2. 门控机制: 通过四个门控参数(Go, Gk, Gv1, Gv2)控制位置编码对键、查询和值的影响,从而增强模型对空间信息的敏感性。

  3. 训练策略: 采用局部-全局训练策略,优化了特征提取和分割性能,特别适合小样本数据集。

门控机制模块

在论文《Medical Transformer: Gated Axial-Attention for Medical Image Segmentation》中,门控机制模块是一个重要的创新,旨在提升医学图像分割的性能。以下是对该模块的详细解读:

1. 门控机制的基本概念

门控机制通过引入可学习的参数来控制信息流的权重,使模型能够根据输入数据的特征动态调整不同信息的影响力。这种机制在处理复杂数据时,尤其是在医学图像分割中,能够显著提高模型的灵活性和适应性。

2. 门控轴向注意力的实现

在MedT模型中,门控机制主要体现在门控轴向注意力(Gated Axial-Attention)模块中,其实现过程包括以下几个关键步骤:

  • 引入门控参数: 该机制使用四个门控参数(Go, Gk, Gv1, Gv2)来控制位置编码对键(Key)、查询(Query)和值(Value)的影响。这使得模型能够根据不同数据集的特性进行调整。

  • 修改自注意力公式: 在自注意力计算中加入门控机制,以调整位置编码的权重,从而优化对非局部上下文的编码能力。

  • 多头自注意力分组: 输入特征被划分为多个组,每个组独立计算注意力。这种设计允许模型在不同特征子空间中关注局部关系和长距离依赖。

  • 轴向分解的注意力机制: 该机制只在一个轴(宽度或高度)上计算注意力,避免了全局自注意力的高计算量。这种分解将计算复杂度从二维降低到一维,显著提高了效率。

  • 相对位置编码: 通过引入相对位置编码,增强了模型对位置信息的敏感性,使得每个特征位置能够感知其在输入序列中的相对位置。

3. 门控机制的优势

  • 提高模型的鲁棒性: 动态调整信息流的权重使得模型在面对不同类型的医学图像时能够更好地适应,从而提高了整体的鲁棒性。

  • 优化计算效率: 门控机制的引入使得模型在处理大规模数据时能够更高效地学习,减少了计算资源的消耗。

效果

实验结果表明,MedT在多个医学图像分割数据集上均取得了优异的性能,尤其是在处理小样本数据时,其表现明显优于传统的卷积神经网络(CNN)和其他Transformer架构。

实验结果

  1. 数据集: 论文中使用了多个医学图像分割数据集进行验证,包括CT和MRI图像。

  2. 性能评估: 通过与现有的最先进模型进行对比,MedT在分割准确性、鲁棒性和计算效率等方面均表现出色。

  3. 结果展示: 实验结果通过定量指标(如Dice系数、IoU等)和定性分析(分割结果可视化)进行了全面展示,证明了模型的有效性。

总结

该论文通过提出医学Transformer(MedT)及其门控轴向注意力机制,为医学图像分割提供了一种新的解决方案。MedT在多个医学图像分割任务中表现优异,尤其在小样本数据集上显示出强大的适应性和有效性。未来的研究可以进一步探索该模型在其他医学图像任务中的应用,以及如何优化其在大规模数据集上的表现。

代码

import torch
import torch.nn.functional
import torch.nn.functional as F
from torch import nn
import math



def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class AxialAttention(nn.Module):
    def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,stride=1, bias=False, width=False):
        assert (in_planes % groups == 0) and (out_planes % groups == 0)
        super(AxialAttention, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.groups = groups
        self.group_planes = out_planes // groups
        self.kernel_size = kernel_size
        self.stride = stride
        self.bias = bias
        self.width = width

        # Multi-head self attention
        self.qkv_transform = nn.Conv1d(in_planes, out_planes * 2, kernel_size=1, stride=1,
                                           padding=0, bias=False)
        self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
        self.bn_similarity = nn.BatchNorm2d(groups * 3)

        self.bn_output = nn.BatchNorm1d(out_planes * 2)

        # Position embedding
        self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
        query_index = torch.arange(kernel_size).unsqueeze(0)
        key_index = torch.arange(kernel_size).unsqueeze(1)
        relative_index = key_index - query_index + kernel_size - 1
        self.register_buffer('flatten_index', relative_index.view(-1))
        if stride > 1:
            self.pooling = nn.AvgPool2d(stride, stride=stride)

        self.reset_parameters()

    def forward(self, x):
        # pdb.set_trace()
        if self.width:
            x = x.permute(0, 2, 1, 3)
        else:
            x = x.permute(0, 3, 1, 2)  # N, W, C, H
        N, W, C, H = x.shape
        x = x.contiguous().view(N * W, C, H)

        # Transformations
        qkv = self.bn_qkv(self.qkv_transform(x))
        q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H),
                              [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)

        # Calculate position embedding
        all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2,
                                                                                       self.kernel_size,
                                                                                       self.kernel_size)
        q_embedding, k_embedding, v_embedding = torch.split(all_embeddings,
                                                            [self.group_planes // 2, self.group_planes // 2,
                                                             self.group_planes], dim=0)

        qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
        kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)

        qk = torch.einsum('bgci, bgcj->bgij', q, k)

        stacked_similarity = torch.cat([qk, qr, kr], dim=1)
        stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
        # stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk)
        # (N, groups, H, H, W)
        similarity = F.softmax(stacked_similarity, dim=3)
        sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
        sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
        stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
        output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)

        if self.width:
            output = output.permute(0, 2, 1, 3)
        else:
            output = output.permute(0, 2, 3, 1)

        if self.stride > 1:
            output = self.pooling(output)

        return output

    def reset_parameters(self):
        self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
        # nn.init.uniform_(self.relative, -0.1, 0.1)
        nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))


if __name__ == "__main__":
    dim=64
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, channels,height, width)
    x = torch.randn(2,dim,40,40).to(device)
    # 初始化 FullyAttentionalBlock 模块

    block = AxialAttention(dim,dim,kernel_size=40) # kernel_size为height或者width
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)

输出结果:
在这里插入图片描述


http://www.niftyadmin.cn/n/5841095.html

相关文章

word2vec 实战应用介绍

Word2Vec 是一种由 Google 在 2013 年推出的重要词嵌入模型,通过将单词映射为低维向量,实现了对自然语言处理任务的高效支持。其核心思想是利用深度学习技术,通过训练大量文本数据,将单词表示为稠密的向量形式,从而捕捉单词之间的语义和语法关系。以下是关于 Word2Vec 实战…

Ubuntu 22.04系统安装部署Kubernetes v1.29.13集群

Ubuntu 22.04系统安装部署Kubernetes v1.29.13集群 简介Kubernetes 的工作流程概述Kubernetes v1.29.13 版本Ubuntu 22.04 系统安装部署 Kubernetes v1.29.13 集群 1 环境准备1.1 集群IP规划1.2 初始化步骤(各个节点都需执行)1.2.1 主机名与IP地址解析1.…

TypeScript 运算符

TypeScript 运算符 TypeScript 作为 JavaScript 的超集,在 JavaScript 的基础上增加了静态类型系统,使得开发大型应用更加容易和维护。在 TypeScript 中,运算符是执行特定数学或逻辑运算的符号。本文将详细介绍 TypeScript 中常见的运算符,并对其使用方法进行详细阐述。 …

GPIO配置通用输出,推挽输出,开漏输出的作用,以及输出上下拉起到的作用

通用输出说明: ①输出原理: 对输出数据寄存器的对应位写0 或 1,就可以控制对应编号的IO口输出低/高电平 ②输出类型 推挽输出:IO口可以输出高电平,也可以输出低电平 开漏输出:IO口只能输出低电平 所以…

Haskell语言的多线程编程

Haskell语言的多线程编程 Haskell是一种基于函数式编程范式的编程语言,以其强大的类型系统和懒惰求值著称。近年来,随着多核处理器的发展,多线程编程变得日益重要。虽然Haskell最初并不是为了多线程而设计,但它的设计理念和工具集…

在 Ubuntu 中使用 FastAPI 创建一个简单的 Web 应用程序

FastAPI 是一个现代、快速且基于 Python 的 Web 框架,特别适合构建 API。本文将指导你如何在 Ubuntu 系统中安装 FastAPI 并创建一个简单的“Hello World”应用。 1. 安装必要的软件和依赖 在开始之前,请确保你的系统已经安装了以下工具: P…

java SSM框架 商城系统源码(含数据库脚本)

商城购物功能,项目代码,mysql脚本,html等静态资源在压缩包里面 注册界面 登陆界面 商城首页 文件列表 shop/.classpath , 1768 shop/.project , 1440 shop/.settings/.jsdtscope , 639 shop/.settings/org.eclipse.core.resources.prefs , …

如果通过认证方式调用Sf的api

导读 OAuth 2.0:是一个开放的授权框架,当用户想要访问Service Provider提供的资源时,OAuth客户端可以从IdP(Identity Provider)获得授权而不需要获取用户名和密码就可以访问该资源题。 作者:vivi,来源:osinnovation …