This article documents a reproduction of BIT-CD, a model that introduces the Transformer into remote-sensing image change detection. A Siamese network built on a modified ResNet18 serves as the backbone for feature extraction; the extracted features are processed by the Bitemporal Image Transformer, and a Prediction Head then generates the change prediction. The article also covers environment and data preparation, model training and validation, and TIPC testing. The reproduction meets the target F1-Score of 89.31% on the LEVIR-CD test set.

Paper Overview
The overall pipeline of BIT-CD works as follows. The authors use a CNN backbone (ResNet) to extract high-level semantic features from the input image pair, and apply spatial attention to convert each temporal feature map into a compact set of semantic tokens. A Transformer encoder then models context within the two token sets; the resulting context-rich tokens are re-projected into pixel space by a Siamese Transformer decoder to enhance the original pixel-level features. Finally, feature difference images (FDIs) are computed from the two refined feature maps and fed into a shallow CNN to produce pixel-level change predictions.
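To make the data flow concrete, below is a minimal sketch of one forward pass, written as a method of a hypothetical BIT layer. The names follow the code snippets later in this article (the Backbone, the _get_semantic_tokens tokenizer, the Transformer encoder/decoder, and the prediction-head fragment); self.encoder and self.decode are assumed to wrap the TransformerEncoder and TransformerDecoder defined below, and positional embeddings plus the token/feature reshaping done by the real implementation are omitted. Treat this as illustrative wiring, not the repository's exact code.

import paddle

def bit_forward(self, t1, t2):
    # 1. Weight-shared (Siamese) CNN backbone extracts features from both dates
    x1 = self.backbone(t1)
    x2 = self.backbone(t2)
    # 2. Spatial attention compresses each feature map into a few semantic tokens
    token1 = self._get_semantic_tokens(x1)          # (b, token_len, c)
    token2 = self._get_semantic_tokens(x2)
    # 3. Transformer encoder models context over the concatenated token sets
    tokens = self.encoder(paddle.concat([token1, token2], axis=1))
    token1, token2 = paddle.split(tokens, 2, axis=1)
    # 4. Siamese Transformer decoder projects the context-rich tokens back to pixel space
    y1 = self.decode(x1, token1)
    y2 = self.decode(x2, token2)
    # 5. Feature differencing + shallow CNN prediction head
    y = self.upsample(paddle.abs(y1 - y2))
    pred = self.conv_out(y)                         # (b, 2, H, W)
    return pred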
Note: the following analysis of the BIT model draws on the CSDN blog post "VisionTransformer(三)BIT—— 基于孪生网络的变化检测结构分析".
The concept of a Siamese network was already explained in the SNUNet-CD article from the 6th Paper Reproduction Challenge (change detection track), so it is not repeated here.
Unlike SNUNet-CD, which extracts change features with a UNet++ backbone, BIT uses a modified ResNet18. The two main modifications, both visible in the code below, are: (1) the downsampling strides of the later residual stages are reduced (strides = (2, 1, 2, 1, 1) instead of the defaults), and (2) trailing stages can be trimmed away, with the classification head replaced by a 2x upsampling layer and a 3x3 output convolution.
This reduces the loss of spatial detail and, to some extent, enlarges the receptive field. The best-performing BIT variant uses only the first four stages followed directly by the final upsampling and output convolution; the paper names this backbone ResNet18_S4.
The backbone code is as follows:
class Backbone(nn.Layer, KaimingInitMixin):
    # Conv3x3, Identity, get_norm_layer, KaimingInitMixin and the custom
    # `resnet` module are helper utilities provided by the repository.
    def __init__(self,
                 in_ch,
                 out_ch=32,
                 arch='resnet18',
                 pretrained=True,
                 n_stages=5):
        super(Backbone, self).__init__()

        expand = 1
        strides = (2, 1, 2, 1, 1)
        if arch == 'resnet18':
            self.resnet = resnet.resnet18(
                pretrained=pretrained,
                strides=strides,
                norm_layer=get_norm_layer())
        elif arch == 'resnet34':
            self.resnet = resnet.resnet34(
                pretrained=pretrained,
                strides=strides,
                norm_layer=get_norm_layer())
        else:
            raise ValueError

        self.n_stages = n_stages
        if self.n_stages == 5:
            itm_ch = 512 * expand
        elif self.n_stages == 4:
            itm_ch = 256 * expand
        elif self.n_stages == 3:
            itm_ch = 128 * expand
        else:
            raise ValueError

        self.upsample = nn.Upsample(scale_factor=2)
        self.conv_out = Conv3x3(itm_ch, out_ch)

        self._trim_resnet()

        if in_ch != 3:
            self.resnet.conv1 = nn.Conv2D(
                in_ch, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)

        if not pretrained:
            self.init_weight()

    def forward(self, x):
        y = self.resnet.conv1(x)
        y = self.resnet.bn1(y)
        y = self.resnet.relu(y)
        y = self.resnet.maxpool(y)

        y = self.resnet.layer1(y)
        y = self.resnet.layer2(y)
        y = self.resnet.layer3(y)
        y = self.resnet.layer4(y)

        y = self.upsample(y)
        return self.conv_out(y)

    def _trim_resnet(self):
        if self.n_stages > 5:
            raise ValueError
        if self.n_stages < 5:
            self.resnet.layer4 = Identity()
        if self.n_stages <= 3:
            self.resnet.layer3 = Identity()
        self.resnet.avgpool = Identity()
        self.resnet.fc = Identity()
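As a usage note, the ResNet18_S4 variant described above roughly corresponds to instantiating the backbone with the first four stages. This is an illustrative call, not the repository's model-construction code:

import paddle

# ResNet18_S4: keep the first four stages (layer4 becomes Identity),
# then apply the 2x upsample and the 3x3 output convolution
backbone = Backbone(in_ch=3, out_ch=32, arch='resnet18', pretrained=True, n_stages=4)
feat = backbone(paddle.randn((1, 3, 256, 256)))  # 32-channel pixel-level features,
                                                 # at higher resolution than vanilla ResNet18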
Spatial attention (conv_att) then converts each temporal feature map into a compact set of token_len semantic tokens. The code is as follows:

def _get_semantic_tokens(self, x):
    b, c = x.shape[:2]
    att_map = self.conv_att(x)
    att_map = att_map.reshape((b, self.token_len, 1, -1))
    att_map = F.softmax(att_map, axis=-1)
    x = x.reshape((b, 1, c, -1))
    tokens = (x * att_map).sum(-1)
    return tokens
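To make the shapes concrete, here is a self-contained, simplified version of this tokenizer. The Tokenizer class below is hypothetical and assumes conv_att is a 1x1 convolution producing token_len attention maps; it only illustrates the shape transformation, not the repository's exact module:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class Tokenizer(nn.Layer):
    def __init__(self, in_ch=32, token_len=4):
        super().__init__()
        self.token_len = token_len
        # one attention map per token, computed from the input features
        self.conv_att = nn.Conv2D(in_ch, token_len, kernel_size=1)

    def forward(self, x):
        b, c = x.shape[:2]
        att_map = self.conv_att(x)                        # (b, token_len, h, w)
        att_map = att_map.reshape((b, self.token_len, 1, -1))
        att_map = F.softmax(att_map, axis=-1)             # attention over the h*w positions
        x = x.reshape((b, 1, c, -1))
        return (x * att_map).sum(-1)                      # (b, token_len, c)

tokens = Tokenizer()(paddle.randn((2, 32, 64, 64)))
print(tokens.shape)  # [2, 4, 32]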
The Transformer encoder that models context within the token sets is built from pre-norm residual self-attention and feed-forward blocks. The code is as follows:

class Residual(nn.Layer):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Layer):
    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Sequential):
    def __init__(self, dim, hidden_dim, dropout_rate=0.):
        super(FeedForward, self).__init__(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout_rate))


class CrossAttention(nn.Layer):
    def __init__(self,
                 dim,
                 n_heads=8,
                 head_dim=64,
                 dropout_rate=0.,
                 apply_softmax=True):
        super(CrossAttention, self).__init__()

        inner_dim = head_dim * n_heads
        self.n_heads = n_heads
        self.scale = dim**-0.5
        self.apply_softmax = apply_softmax

        self.fc_q = nn.Linear(dim, inner_dim, bias_attr=False)
        self.fc_k = nn.Linear(dim, inner_dim, bias_attr=False)
        self.fc_v = nn.Linear(dim, inner_dim, bias_attr=False)

        self.fc_out = nn.Sequential(
            nn.Linear(inner_dim, dim), nn.Dropout(dropout_rate))

    def forward(self, x, ref):
        b, n = x.numpy().shape[:2]
        h = self.n_heads

        # queries come from x, keys/values from the reference sequence
        q = self.fc_q(x)
        k = self.fc_k(ref)
        v = self.fc_v(ref)

        q = q.reshape((b, n, h, -1)).transpose((0, 2, 1, 3))
        k = k.reshape((b, paddle.shape(ref)[1], h, -1)).transpose((0, 2, 1, 3))
        v = v.reshape((b, paddle.shape(ref)[1], h, -1)).transpose((0, 2, 1, 3))

        mult = paddle.matmul(q, k, transpose_y=True) * self.scale
        if self.apply_softmax:
            mult = F.softmax(mult, axis=-1)

        out = paddle.matmul(mult, v)
        out = out.transpose((0, 2, 1, 3)).flatten(2)
        return self.fc_out(out)


class SelfAttention(CrossAttention):
    def forward(self, x):
        return super(SelfAttention, self).forward(x, x)


class TransformerEncoder(nn.Layer):
    def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.LayerList([])
        for _ in range(depth):
            self.layers.append(
                nn.LayerList([
                    Residual(
                        PreNorm(dim,
                                SelfAttention(dim, n_heads, head_dim,
                                              dropout_rate))),
                    Residual(
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))
                ]))

    def forward(self, x):
        for att, ff in self.layers:
            x = att(x)
            x = ff(x)
        return x
The Siamese Transformer decoder re-projects the context-rich tokens onto the pixel-level features via cross-attention. The code is as follows:

class Residual2(nn.Layer):
    def __init__(self, fn):
        super(Residual2, self).__init__()
        self.fn = fn

    def forward(self, x1, x2, **kwargs):
        return self.fn(x1, x2, **kwargs) + x1


class PreNorm2(nn.Layer):
    def __init__(self, dim, fn):
        super(PreNorm2, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x1, x2, **kwargs):
        return self.fn(self.norm(x1), self.norm(x2), **kwargs)


class TransformerDecoder(nn.Layer):
    def __init__(self,
                 dim,
                 depth,
                 n_heads,
                 head_dim,
                 mlp_dim,
                 dropout_rate,
                 apply_softmax=True):
        super(TransformerDecoder, self).__init__()
        self.layers = nn.LayerList([])
        for _ in range(depth):
            self.layers.append(
                nn.LayerList([
                    Residual2(
                        PreNorm2(dim,
                                 CrossAttention(dim, n_heads, head_dim,
                                                dropout_rate, apply_softmax))),
                    Residual(
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))
                ]))

    def forward(self, x, m):
        for att, ff in self.layers:
            x = att(x, m)
            x = ff(x)
        return x
Finally, the two refined feature maps are differenced and passed through the shallow prediction head. The relevant fragment of the forward pass is as follows:

y1 = self.decode(x1, token1)
y2 = self.decode(x2, token2)
# Feature differencing
y = paddle.abs(y1 - y2)
y = self.upsample(y)
# Classifier forward
pred = self.conv_out(y)
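At inference time the two-channel pred is converted into a binary change map; the repository's bit_predict.py produces and saves this map. A single illustrative line, assuming class index 1 is the change class:

change_map = paddle.argmax(pred, axis=1)  # (b, H, W); 1 = changed, 0 = unchanged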
The results on the LEVIR-CD test set are shown in the table below; the reproduction meets the acceptance criterion of F1-Score = 89.31%.
| Network | Optimizer | Epochs | Batch size | Dataset | F1-Score |
|---|---|---|---|---|---|
| BIT | SGD | 200 | 8 | LEVIR-CD | 89.32% |
!git clone https://github.com/kongdebug/BIT-CD-Paddle.git
Cloning into 'BIT-CD-Paddle'... remote: Enumerating objects: 1004, done. remote: Counting objects: 100% (1004/1004), done. remote: Compressing objects: 100% (779/779), done. remote: Total 1004 (delta 207), reused 976 (delta 194), pack-reused 0 Receiving objects: 100% (1004/1004), 28.86 MiB | 6.93 MiB/s, done. Resolving deltas: 100% (207/207), done. Checking connectivity... done.
# Unzip the dataset
!unzip -qo data/data136610/LEVIR-CD.zip -d data/data136610/
# Install the required dependencies
%cd BIT-CD-Paddle/
!pip install -r requirements.txt
# Tile the data into patches. Note that the input and output folders must not be the same; this step takes a while.
!python data/spliter-cd.py --image_folder ../data/data136610/LEVIR-CD --block_size 256 --save_folder ../LEVIR-CD
# Generate the .txt file lists needed for training
!python data/process_levir_data.py --data_dir ../LEVIR-CD
Dataset split completed.
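For reference, the tiling step conceptually does something like the sketch below. This is a simplified illustration with PIL, not the repository's spliter-cd.py, and the patch-naming scheme is an assumption. The process_levir_data.py step then writes the .txt file lists (the "dataset split" reported above) that the training script reads.

import os
from PIL import Image

def tile_image(src_path, dst_dir, block_size=256):
    # Cut one large image into non-overlapping block_size x block_size patches.
    os.makedirs(dst_dir, exist_ok=True)
    img = Image.open(src_path)
    w, h = img.size
    stem = os.path.splitext(os.path.basename(src_path))[0]
    for top in range(0, h - block_size + 1, block_size):
        for left in range(0, w - block_size + 1, block_size):
            patch = img.crop((left, top, left + block_size, top + block_size))
            # hypothetical naming: <original-name>_<row>_<col>.png
            patch.save(os.path.join(dst_dir, f"{stem}_{top}_{left}.png"))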
!python tutorials/train/change_detection/bit_train.py --data_dir=../LEVIR-CD --out_dir=./output/BIT/
!python tutorials/eval/change_detection/bit_eval.py --data_dir=../LEVIR-CD/ --weight_path=../work/BIT/best_model/model.pdparams
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
[04-26 00:13:58 MainThread @logger.py:242] Argv: tutorials/eval/change_detection/bit_eval.py --data_dir=../LEVIR-CD/ --weight_path=../work/BIT/best_model/model.pdparams
[04-26 00:13:58 MainThread @utils.py:79] WRN paddlepaddle version: 2.2.2. The dynamic graph version of PARL is under development, not fully tested and supported
2022-04-26 00:13:59 [INFO] 2048 samples in file ../LEVIR-CD/test.txt
W0426 00:13:59.148581 9007 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0426 00:13:59.154322 9007 device_context.cc:465] device: 0, cuDNN Version: 7.6.
2022-04-26 00:14:02 [INFO] Loading pretrained model from ../work/BIT/best_model/model.pdparams
2022-04-26 00:14:02 [INFO] There are 203/203 variables loaded into BIT.
2022-04-26 00:14:02 [INFO] Start to evaluate(total_samples=2048, total_steps=2048)...
OrderedDict([('miou', 0.8980501731845956), ('category_iou', array([0.98892947, 0.80717087])), ('oacc', 0.9894197657704353), ('category_acc', array([0.99300849, 0.91857525])), ('kappa', 0.887736151189675), ('category_F1-score', array([0.99443393, 0.89329779]))])
!python tutorials/predict/change_detection/bit_predict.py --weight_path ../work/BIT/best_model/model.pdparams \
    --A ../LEVIR-CD/test/A/test_2_0_0.png --B ../LEVIR-CD/test/B/test_2_0_0.png --pre ../work/pre.png
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
[04-26 00:40:19 MainThread @logger.py:242] Argv: tutorials/predict/change_detection/bit_predict.py --weight_path ../work/BIT/best_model/model.pdparams --A ../LEVIR-CD/test/A/test_2_0_0.png --B ../LEVIR-CD/test/B/test_2_0_0.png --pre ../work/pre.png
[04-26 00:40:19 MainThread @utils.py:79] WRN paddlepaddle version: 2.2.2. The dynamic graph version of PARL is under development, not fully tested and supported
W0426 00:40:19.621413 10885 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0426 00:40:19.627353 10885 device_context.cc:465] device: 0, cuDNN Version: 7.6.
2022-04-26 00:40:23 [INFO] Loading pretrained model from ../work/BIT/best_model/model.pdparams
2022-04-26 00:40:23 [INFO] There are 203/203 variables loaded into BIT.
ok finish!
# Visualize the prediction; the last image is the ground truth
import matplotlib.pyplot as plt
from PIL import Image
T1 = Image.open(r"../LEVIR-CD/test/A/test_2_0_0.png")
T2 = Image.open(r"../LEVIR-CD/test/B/test_2_0_0.png")
GT = Image.open(r"../LEVIR-CD/test/label/test_2_0_0.png")
pred = Image.open(r"../work/pre.png")
plt.figure(figsize=(16, 8))
plt.subplot(1,4,1), plt.title('T1')
plt.imshow(T1), plt.axis('off')
plt.subplot(1,4,2), plt.title('T2')
plt.imshow(T2), plt.axis('off')
plt.subplot(1,4,3), plt.title('pred')
plt.imshow(pred), plt.axis('off')
plt.subplot(1,4,4), plt.title('GT')
plt.imshow(GT), plt.axis('off')
plt.show()
<Figure size 1152x576 with 4 Axes>
!python deploy/export/export_model.py --model_dir=../work/BIT/best_model/ \
    --save_dir=./inference_model/ --fixed_input_shape=[1,3,256,256]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
[04-26 01:00:27 MainThread @logger.py:242] Argv: deploy/export/export_model.py --model_dir=../work/BIT/best_model/ --save_dir=./inference_model/ --fixed_input_shape=[1,3,256,256]
[04-26 01:00:27 MainThread @utils.py:79] WRN paddlepaddle version: 2.2.2. The dynamic graph version of PARL is under development, not fully tested and supported
W0426 01:00:28.290376 12365 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0426 01:00:28.296492 12365 device_context.cc:465] device: 0, cuDNN Version: 7.6.
2022-04-26 01:00:32 [INFO] Model[BIT] loaded.
2022-04-26 01:00:36 [INFO] The model for the inference deployment is saved in ./inference_model/.
The TIPC test depends on auto_log, which needs to be installed as follows:
For a detailed introduction to auto_log, see https://github.com/LDOUBLEV/AutoLog.
!git clone https://github.com/LDOUBLEV/AutoLog
!pip3 install -r requirements.txt
!python3 setup.py bdist_wheel
!pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl
!bash ./test_tipc/prepare.sh test_tipc/configs/BIT/train_infer_python.txt 'lite_train_lite_infer'
!bash test_tipc/test_train_inference_python.sh test_tipc/configs/BIT/train_infer_python.txt 'lite_train_lite_infer'