Attention Mechanisms in Computer Vision
一、The attention mechanism
The attention mechanism can be seen as a way of allocating resources: instead of spreading resources evenly, they are redistributed according to how important each attended object is, so the important units receive more and the unimportant or uninformative ones receive less. In deep neural network design, the resource being allocated is essentially the weights.
Visual attention comes in several forms, but the core idea is the same: exploit the correlations within the original data to highlight certain important features. There is channel attention, pixel (spatial) attention, multi-stage attention, and so on, and the self-attention from NLP has also been brought into vision.
二、Self-attention
Reference: Attention Is All You Need
Further reading: https://zhuanlan.zhihu.com/p/48508221
GitHub: https://github.com/huggingface/transformers
Self-attention, sometimes called intra-attention, is an attention mechanism that relates different positions of a single sequence in order to compute a representation of that sequence. It matters here because the transformer encoder/decoder is itself permutation-invariant with respect to position, and in DETR each pixel carries not only its feature values but also positional information that must be preserved.
All encoders are structurally identical, but they do not share parameters. Each encoder can be decomposed into two sub-layers:
In the transformer, each encoder layer is composed of multi-head self-attention and a position-wise feed-forward network (FFN).
Each input word is turned into a word vector by the embedding layer, encoded by self-attention, and then passed through the FFN to produce that layer's encoding.
The decoder is likewise a stack of identical layers. Each decoder layer has the same multi-head self-attention and position-wise FFN as the encoder, plus an extra attention sub-layer that attends to the relevant parts of the input sentence.
Self-Attention
Self-attention is the core of the Transformer. It can be understood as mapping a query and a set of key-value pairs to an output: the output is a weighted sum of the values, and the weights are produced by the self-attention itself.
The concrete steps are as follows:
In self-attention, each word has three different vectors: a Query vector, a Key vector, and a Value vector, each of length 64. They are obtained by multiplying the embedding vector X by three different weight matrices, all of the same size, 512×64.
- Convert the input words into embedding vectors;
- Derive the three vectors q, k, v from each embedding;
- Compute a score for each pair of positions: score = q · k;
- For gradient stability, the Transformer normalizes the scores by dividing by sqrt(d_k);
- Apply softmax to the scores;
- Multiply the softmax weights by the value vectors v, giving a weighted value for each input;
- Sum these weighted values to obtain the final output z.
In matrix form, the computation is:
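From the paper, the matrix form is Attention(Q, K, V) = softmax(QKᵀ / √d_k)·V. Below is a minimal single-head sketch of this scaled dot-product attention; the tensor sizes (5 tokens, embedding 512, head dimension 64) are only illustrative.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (seq_len, d_k) -- one head, no batch dimension, for clarity
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (seq_len, seq_len), scaled by sqrt(d_k)
    weights = F.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ V                              # weighted sum of the value vectors

# Toy example: 5 "words", embedding dim 512, per-head dim 64
x = torch.randn(5, 512)
W_q, W_k, W_v = (torch.randn(512, 64) for _ in range(3))
z = scaled_dot_product_attention(x @ W_q, x @ W_k, x @ W_v)
print(z.shape)  # torch.Size([5, 64])
```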
As the paper shows, multi-head self-attention simply runs several scaled dot-product attention heads in parallel and fuses (concatenates and projects) their outputs; combined with the FFN this gives the transformer layer.
For a concrete implementation, see models/transformer.py in DETR:
```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.

Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
```
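As a quick sanity check of the shapes, here is a hypothetical usage sketch of the module above (the sizes below are illustrative; DETR itself builds this through build_transformer with hidden_dim=256 and 100 object queries):

```python
# Assumes the Transformer classes above are defined/imported; sizes are illustrative.
import torch
from torch import nn

transformer = Transformer(d_model=256, nhead=8, num_encoder_layers=2,
                          num_decoder_layers=2, return_intermediate_dec=True)
bs, c, h, w = 2, 256, 16, 16
src = torch.randn(bs, c, h, w)                   # backbone feature map (c must equal d_model)
mask = torch.zeros(bs, h, w, dtype=torch.bool)   # padding mask: False = valid pixel
query_embed = nn.Embedding(100, 256).weight      # learned object queries
pos_embed = torch.randn(bs, c, h, w)             # positional encoding for every pixel

hs, memory = transformer(src, mask, query_embed, pos_embed)
print(hs.shape)      # (num_decoder_layers, bs, num_queries, d_model) = (2, 2, 100, 256)
print(memory.shape)  # (bs, c, h, w) = (2, 256, 16, 16)
```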
三、Soft attention
Soft attention is a continuous distribution over [0, 1] and focuses on regions or channels. It is deterministic attention: once trained it can be generated directly by the network, and it is differentiable, so the attention weights can be learned through ordinary forward propagation and back-propagation.
1、Spatial attention (Spatial Transformer Network)
Paper: http://papers.nips.cc/paper/5854-spatial-transformer-networks
GitHub: https://github.com/fxia22/stn.pytorch
Spatial attention can be understood as teaching the network where to look. Through the attention mechanism, the spatial information of the original image is transformed into another space while the key information is preserved. Many existing methods use this kind of module; one I have worked with is AlphaPose. The spatial transformer is essentially an implementation of attention: once trained it can find the regions of the image that need attention, and since the transformer can also rotate and scale, the important local information can be extracted by a transformed box.
The key point is learning the spatial transformation matrix.
```python
class STN(Module):
    def __init__(self, layout='BHWD'):
        super(STN, self).__init__()
        if layout == 'BHWD':
            self.f = STNFunction()
        else:
            self.f = STNFunctionBCHW()

    def forward(self, input1, input2):
        return self.f(input1, input2)


class STNFunction(Function):
    def forward(self, input1, input2):
        self.input1 = input1
        self.input2 = input2
        self.device_c = ffi.new("int *")
        output = torch.zeros(input1.size()[0], input2.size()[1], input2.size()[2], input1.size()[3])
        # print('device %d' % torch.cuda.current_device())
        if input1.is_cuda:
            self.device = torch.cuda.current_device()
        else:
            self.device = -1
        self.device_c[0] = self.device
        if not input1.is_cuda:
            my_lib.BilinearSamplerBHWD_updateOutput(input1, input2, output)
        else:
            output = output.cuda(self.device)
            my_lib.BilinearSamplerBHWD_updateOutput_cuda(input1, input2, output, self.device_c)
        return output

    def backward(self, grad_output):
        grad_input1 = torch.zeros(self.input1.size())
        grad_input2 = torch.zeros(self.input2.size())
        # print('backward device %d' % self.device)
        if not grad_output.is_cuda:
            my_lib.BilinearSamplerBHWD_updateGradInput(self.input1, self.input2, grad_input1,
                                                       grad_input2, grad_output)
        else:
            grad_input1 = grad_input1.cuda(self.device)
            grad_input2 = grad_input2.cuda(self.device)
            my_lib.BilinearSamplerBHWD_updateGradInput_cuda(self.input1, self.input2, grad_input1,
                                                            grad_input2, grad_output, self.device_c)
        return grad_input1, grad_input2
```
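The stn.pytorch code above calls into the repo's compiled CUDA/FFI bilinear sampler (ffi, my_lib), so it does not run on its own. For reference, here is a minimal modern sketch of the same idea built on PyTorch's built-in F.affine_grid / F.grid_sample; the localization-network architecture and sizes are illustrative assumptions, not taken from the original repo.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleSTN(nn.Module):
    """Learn a 2x3 affine matrix per image, then resample the input with it."""
    def __init__(self, in_channels=1):
        super().__init__()
        # Localization network: predicts the 6 affine parameters
        self.loc = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=7), nn.MaxPool2d(2), nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(True),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(10, 6),
        )
        # Initialize to the identity transform so training starts from "no warp"
        self.loc[-1].weight.data.zero_()
        self.loc[-1].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, x):
        theta = self.loc(x).view(-1, 2, 3)                      # (B, 2, 3) affine matrices
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)      # bilinear sampling

x = torch.randn(4, 1, 28, 28)
print(SimpleSTN(1)(x).shape)  # torch.Size([4, 1, 28, 28])
```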
2、Channel attention (Channel Attention, CA)
Channel attention can be understood as teaching the network what to look at; the classic example is SENet. Every convolutional layer has many kernels, and each kernel corresponds to a feature channel. Compared with spatial attention, channel attention allocates resources among the convolution channels, one level of granularity coarser.
Paper: Squeeze-and-Excitation Networks
https://arxiv.org/abs/1709.01507
GitHub: https://github.com/moskomule/senet.pytorch
Squeeze: take each channel's global spatial features as that channel's representation, using global average pooling to produce per-channel statistics.
Excitation: learn how strongly each channel depends on the others and rescale the different feature maps accordingly to produce the final output; this requires examining the dependencies between channels.
The overall structure is shown in the figure:
A convolutional layer's output does not account for the dependencies between channels. The goal of the SE block is to let the network selectively strengthen the most informative features, so that later processing can make full use of them, while suppressing the useless ones.
SE-Inception Module
SE-ResNet Module
- Apply global average pooling to the input features, giving a 1×1×Channel tensor;
- Let the channels interact through a bottleneck: first compress the channel count, then expand it back;
- Finish with a sigmoid to produce per-channel attention weights between 0 and 1, and scale the original input features by them.
The SE block in SE-ResNet:
<span class="hljs-keyword">class</span> SEBasicBlock(nn.Module): expansion = <span class="hljs-number">1</span> def __init__(<span class="hljs-keyword">self</span>, inplanes, planes, stride=<span class="hljs-number">1</span>, downsample=None, groups=<span class="hljs-number">1</span>, base_width=<span class="hljs-number">64</span>, dilation=<span class="hljs-number">1</span>, norm_layer=None, *, reduction=<span class="hljs-number">16</span>): <span class="hljs-keyword">super</span>(SEBasicBlock, <span class="hljs-keyword">self</span>).__init__() <span class="hljs-keyword">self</span>.conv1 = conv3x3(inplanes, planes, stride) <span class="hljs-keyword">self</span>.bn1 = nn.BatchNorm2d(planes) <span class="hljs-keyword">self</span>.relu = nn.ReLU(inplace=True) <span class="hljs-keyword">self</span>.conv2 = conv3x3(planes, planes, <span class="hljs-number">1</span>) <span class="hljs-keyword">self</span>.bn2 = nn.BatchNorm2d(planes) <span class="hljs-keyword">self</span>.se = SELayer(planes, reduction) <span class="hljs-keyword">self</span>.downsample = downsample <span class="hljs-keyword">self</span>.stride = stride def forward(<span class="hljs-keyword">self</span>, x): residual = x <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.conv1(x) <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.bn1(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.relu(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.conv2(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.bn2(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.se(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">if</span> <span class="hljs-keyword">self</span>.downsample is not None: residual = <span class="hljs-keyword">self</span>.downsample(x) <span class="hljs-keyword">out</span> += residual <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.relu(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">return</span> out class SELayer(nn.Module): def __init__(<span class="hljs-keyword">self</span>, channel, reduction=<span class="hljs-number">16</span>): <span class="hljs-keyword">super</span>(SELayer, <span class="hljs-keyword">self</span>).__init__() <span class="hljs-keyword">self</span>.avg_pool = nn.AdaptiveAvgPool2d(<span class="hljs-number">1</span>) <span class="hljs-keyword">self</span>.fc = nn.Sequential( nn.Linear(channel, channel <span class="hljs-comment">// reduction, bias=False),</span> <span class="hljs-comment"> nn.ReLU(inplace=True),</span> <span class="hljs-comment"> nn.Linear(channel // reduction, channel, bias=False),</span> <span class="hljs-comment"> nn.Sigmoid()</span> <span class="hljs-comment"> )</span> <span class="hljs-comment"> def forward(self, x):</span> <span class="hljs-comment"> b, c, _, _ = x.size()</span> <span class="hljs-comment"> y = self.avg_pool(x).view(b, c)</span> <span class="hljs-comment"> y = self.fc(y).view(b, c, 1, 1)</span> <span class="hljs-comment"> return x * y.expand_as(x)</span> |
For comparison, the plain basic block of ResNet:
<span class="hljs-keyword">class</span> BasicBlock(nn.Module): def __init__(<span class="hljs-keyword">self</span>, inplanes, planes, stride=<span class="hljs-number">1</span>): <span class="hljs-keyword">super</span>(BasicBlock, <span class="hljs-keyword">self</span>).__init__() <span class="hljs-keyword">self</span>.conv1 = conv3x3(inplanes, planes, stride) <span class="hljs-keyword">self</span>.bn1 = nn.BatchNorm2d(planes) <span class="hljs-keyword">self</span>.relu = nn.ReLU(inplace=True) <span class="hljs-keyword">self</span>.conv2 = conv3x3(planes, planes) <span class="hljs-keyword">self</span>.bn2 = nn.BatchNorm2d(planes) <span class="hljs-keyword">if</span> inplanes != planes: <span class="hljs-keyword">self</span>.downsample = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=<span class="hljs-number">1</span>, stride=stride, bias=False), nn.BatchNorm2d(planes)) <span class="hljs-keyword">else</span>: <span class="hljs-keyword">self</span>.downsample = lambda x: x <span class="hljs-keyword">self</span>.stride = stride def forward(<span class="hljs-keyword">self</span>, x): residual = <span class="hljs-keyword">self</span>.downsample(x) <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.conv1(x) <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.bn1(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.relu(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.conv2(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.bn2(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">out</span> += residual <span class="hljs-keyword">out</span> = <span class="hljs-keyword">self</span>.relu(<span class="hljs-keyword">out</span>) <span class="hljs-keyword">return</span> <span class="hljs-keyword">out</span> |
The difference between the two is mainly the extra SELayer; see the source code for details.
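A quick illustration of what SELayer does on its own (sizes are illustrative and assume the SELayer definition above):

```python
import torch

se = SELayer(channel=64, reduction=16)
feat = torch.randn(8, 64, 32, 32)   # (batch, channels, H, W)
out = se(feat)                      # every channel is rescaled by a learned weight in (0, 1)
print(out.shape)                    # torch.Size([8, 64, 32, 32])
```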
3、Mixed-domain models (combining spatial and channel attention)
(1) Paper: Residual Attention Network for Image Classification (CVPR 2017 Open Access Repository)
http://openaccess.thecvf.com/content_cvpr_2017/html/Wang_Residual_Attention_Network_CVPR_2017_paper.html
The attention in this paper is the basic soft-attention masking mechanism, with one difference: the mask borrows the idea of residual networks. The mask is not computed from the current layer's information alone; the previous layer's information is passed down as well, which avoids the problem that masking leaves too little information and prevents the network from being stacked very deep.
The paper's novelty is residual attention learning: the input to the next layer is not only the masked feature tensor but also the pre-mask feature tensor, so richer features are available and the key features can be attended to better. The overall attention is built from a three-stage stack of attention modules.
(2) Dual Attention Network for Scene Segmentation (CVPR 2019 Open Access Repository)
http://openaccess.thecvf.com/content_CVPR_2019/html/Fu_Dual_Attention_Network_for_Scene_Segmentation_CVPR_2019_paper.html
4、Non-Local
Paper: Non-local Neural Networks (CVPR 2018 Open Access Repository)
http://openaccess.thecvf.com/content_cvpr_2018/html/Wang_Non-Local_Neural_Networks_CVPR_2018_paper.html
GitHub: https://github.com/AlexHex7/Non-local_pytorch
"Local" here is relative to the receptive field. Take a single convolution: its receptive field is the kernel size, and since we usually use 3×3 or 5×5 kernels that only look at a local neighborhood, convolution is a local operation; so is pooling. "Non-local", in contrast, means the receptive field can be very large rather than a local neighborhood. A fully connected layer is non-local, indeed global, but it brings a huge number of parameters and makes optimization difficult. Stacking convolutional layers enlarges the receptive field, yet the receptive field of any given layer's kernels on the original image is still limited; this is unavoidable for local operations. Some tasks, however, need more information from the whole image, attention being one of them. If global information can be injected at certain layers, it compensates for the inability of local operations to see the whole picture and gives later layers richer information.
The paper defines the generic non-local operation for neural networks as follows:
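The original figure with the formula is not reproduced here; in the paper the generic non-local operation is

    y_i = (1 / C(x)) · Σ_j f(x_i, x_j) · g(x_j)

where f computes the pairwise similarity between positions i and j (for example the embedded Gaussian f(x_i, x_j) = exp(θ(x_i)ᵀ φ(x_j))), g computes a representation of the input at position j, and C(x) is a normalization factor.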
Implemented literally with for-loops, this formula would be very slow, and applying a non-local layer to a large input is also computationally expensive. The usual remedies are to insert non-local layers only on high-level semantic feature maps, and to further cut the cost by adding pooling after the embeddings θ, φ, g.
- First apply linear maps to the input feature map X (1×1 convolutions that compress the channel count) to obtain the θ, φ and g features;
- Reshape to merge every dimension except the channels, then matrix-multiply θ and φ to obtain something like a covariance matrix (this step is important: it computes the self-correlation of the features, i.e. the relationship of each pixel in each frame to all pixels in all other frames);
- Apply softmax over the rows or columns of this self-correlation matrix (depending on how g is laid out) to obtain weights between 0 and 1; these are the self-attention coefficients;
- Finally multiply the attention coefficients back into the feature matrix g, expand the channels back up, and add the result residually to the original input feature map X, as in the sketch below.
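A minimal sketch of a 2D non-local block in the embedded Gaussian form, following the steps above; the layer names and the omission of the BatchNorm used in some reference implementations are simplifications, not the exact code of the linked repo.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NonLocalBlock2D(nn.Module):
    def __init__(self, in_channels, reduction=2):
        super().__init__()
        inter = in_channels // reduction
        # 1x1 convs: theta, phi, g compress channels; W_z restores them
        self.theta = nn.Conv2d(in_channels, inter, 1)
        self.phi = nn.Conv2d(in_channels, inter, 1)
        self.g = nn.Conv2d(in_channels, inter, 1)
        self.W_z = nn.Conv2d(inter, in_channels, 1)
        nn.init.zeros_(self.W_z.weight)   # block starts as an identity mapping (residual = 0)
        nn.init.zeros_(self.W_z.bias)

    def forward(self, x):
        b, c, h, w = x.shape
        theta = self.theta(x).flatten(2).permute(0, 2, 1)   # (B, HW, C')
        phi = self.phi(x).flatten(2)                        # (B, C', HW)
        g = self.g(x).flatten(2).permute(0, 2, 1)           # (B, HW, C')

        attn = F.softmax(theta @ phi, dim=-1)               # (B, HW, HW) pairwise similarities
        y = (attn @ g).permute(0, 2, 1).view(b, -1, h, w)   # aggregate, back to (B, C', H, W)
        return x + self.W_z(y)                              # residual connection

x = torch.randn(2, 64, 20, 20)
print(NonLocalBlock2D(64)(x).shape)  # torch.Size([2, 64, 20, 20])
```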
5、Position attention (position-wise attention)
Paper: CCNet: Criss-Cross Attention for Semantic Segmentation (ICCV 2019 Open Access Repository)
http://openaccess.thecvf.com/content_ICCV_2019/html/Huang_CCNet_Criss-Cross_Attention_for_Semantic_Segmentation_ICCV_2019_paper.html
GitHub: https://github.com/speedinghzl/CCNet
The highlight of this paper is a clever trick that cuts the cost. In DANet above, the attention map measures the similarity between every pixel and every other pixel, so the space complexity is (H×W)×(H×W). CCNet instead uses a criss-cross idea: each pixel is compared only with the pixels in its own row and column, and by repeating the same operation twice, the similarity between every pair of pixels is obtained indirectly, bringing the space complexity down to (H×W)×(H+W−1).
When computing the matrix products, each pixel only gathers the pixels on its criss-cross path in the feature map for the dot products. Compared with the non-local method this greatly reduces the computation, while the two-pass attention still lets every pixel collect context from the whole image, producing new feature maps with dense, rich contextual information.
```python
def _check_contiguous(*args):
    if not all([mod is None or mod.is_contiguous() for mod in args]):
        raise ValueError("Non-contiguous input")


class CA_Weight(autograd.Function):
    @staticmethod
    def forward(ctx, t, f):
        # Save context
        n, c, h, w = t.size()
        size = (n, h + w - 1, h, w)
        weight = torch.zeros(size, dtype=t.dtype, layout=t.layout, device=t.device)

        _ext.ca_forward_cuda(t, f, weight)

        # Output
        ctx.save_for_backward(t, f)

        return weight

    @staticmethod
    @once_differentiable
    def backward(ctx, dw):
        t, f = ctx.saved_tensors

        dt = torch.zeros_like(t)
        df = torch.zeros_like(f)

        _ext.ca_backward_cuda(dw.contiguous(), t, f, dt, df)

        _check_contiguous(dt, df)

        return dt, df


class CA_Map(autograd.Function):
    @staticmethod
    def forward(ctx, weight, g):
        # Save context
        out = torch.zeros_like(g)
        _ext.ca_map_forward_cuda(weight, g, out)

        # Output
        ctx.save_for_backward(weight, g)

        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, dout):
        weight, g = ctx.saved_tensors

        dw = torch.zeros_like(weight)
        dg = torch.zeros_like(g)

        _ext.ca_map_backward_cuda(dout.contiguous(), weight, g, dw, dg)

        _check_contiguous(dw, dg)

        return dw, dg


ca_weight = CA_Weight.apply
ca_map = CA_Map.apply


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        self.chanel_in = in_dim

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        proj_query = self.query_conv(x)
        proj_key = self.key_conv(x)
        proj_value = self.value_conv(x)

        energy = ca_weight(proj_query, proj_key)
        attention = F.softmax(energy, 1)
        out = ca_map(attention, proj_value)
        out = self.gamma * out + x

        return out


__all__ = ["CrissCrossAttention", "ca_weight", "ca_map"]
```
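The CA_Weight / CA_Map functions above rely on the repo's compiled CUDA extension (_ext), so the following is only a hypothetical usage sketch; the two-pass loop is the recurrence described above.

```python
import torch

# Assumes the CUDA extension is built and a GPU is available
cca = CrissCrossAttention(in_dim=64).cuda()
x = torch.randn(2, 64, 32, 32).cuda()
for _ in range(2):   # R = 2: two criss-cross passes give every pixel full-image context
    x = cca(x)
print(x.shape)       # torch.Size([2, 64, 32, 32])
```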
四、Hard attention
Hard attention is a 0/1 problem: which elements are attended to and which are not. It focuses on points rather than regions, and every point in an image may give rise to attention. Hard attention is a stochastic prediction process that emphasizes dynamic change; it is non-differentiable, so it is usually trained with reinforcement learning.
References
https://blog.csdn.net/xys430381_1/article/details/89323444
https://zhuanlan.zhihu.com/p/33345791
https://zhuanlan.zhihu.com/p/54150694
https://m.toutiaocdn.com/i6835863778312061452/?app=news_article×tamp=1591829585&use_new_style=1&req_id=202006110653050100140470252D91C959&group_id=6835863778312061452&tt_from=android_share&utm_medium=toutiao_android&utm_campaign=client_share