nnUNetv2 Modification Notes - D1


D1: Starting today I am modifying nnUNetv2; the notes below record what was changed.

Code obtained from 云龙:

plain_conv_encoder_EAB.py
import torch
from torch import nn
import numpy as np
from typing import Union, Type, List, Tuple

from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd
from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op, get_matching_convtransp
from dynamic_network_architectures.building_blocks.all_attention import *


class PlainConvEncoder(nn.Module):
    has_shown_prompt = [False]  # class attribute: records whether the EAB banner has been printed

    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 return_skips: bool = False,
                 nonlin_first: bool = False,
                 pool: str = 'conv'
                 ):

        super().__init__()
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * n_stages
        if isinstance(features_per_stage, int):
            features_per_stage = [features_per_stage] * n_stages
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(strides, int):
            strides = [strides] * n_stages
        assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
                                         "Important: first entry is recommended to be 1, else we run strided conv directly on the input"

        stages = []
        for s in range(n_stages):
            stage_modules = []
            if pool == 'max' or pool == 'avg':
                if (isinstance(strides[s], int) and strides[s] != 1) or \
                        isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
                    stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s]))
                conv_stride = 1
            elif pool == 'conv':
                conv_stride = strides[s]
            else:
                raise RuntimeError()
            stage_modules.append(StackedConvBlocks(
                n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            ))
            stages.append(nn.Sequential(*stage_modules))
            input_channels = features_per_stage[s]

        self.stages = nn.Sequential(*stages)
        self.output_channels = features_per_stage
        self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
        self.return_skips = return_skips

        # we store some things that a potential decoder needs
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.conv_bias = conv_bias
        self.kernel_sizes = kernel_sizes

        ################################################## Edge Aggregation Block (EAB) ##################################################
        # note: kernel_sizes[s] below reuses the last loop index, i.e. the kernel size of the final stage
        transpconv_op = get_matching_convtransp(conv_op=self.conv_op)
        self.downblock_channal = [32, 64, 128, 256, 512, 512]
        self.mattn = Spartial_Attention3d(kernel_size=3)
        # strided conv blocks that pull the high-resolution skips down to the next stage's resolution
        self.mdcat1 = nn.Sequential(
            StackedConvBlocks(
                1, conv_op, self.downblock_channal[0], self.downblock_channal[0], kernel_sizes[s], 2,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first))
        self.mdcat2 = nn.Sequential(
            StackedConvBlocks(
                1, conv_op, self.downblock_channal[0] + self.downblock_channal[1], self.downblock_channal[0] + self.downblock_channal[1], kernel_sizes[s], 2,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first))
        # projections of the aggregated feature back to each skip's resolution, plus one gate per skip
        self.mupcat3 = nn.Sequential(
            StackedConvBlocks(
                1, conv_op, self.downblock_channal[0] + self.downblock_channal[1] + self.downblock_channal[2], self.downblock_channal[2], kernel_sizes[s], 1,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first))
        self.gate3 = Gate(self.downblock_channal[2], self.downblock_channal[2])
        self.mupcat2 = transpconv_op(self.downblock_channal[0] + self.downblock_channal[1] + self.downblock_channal[2],
                                     self.downblock_channal[1], kernel_size=2, stride=2, bias=False)
        self.gate2 = Gate(in_channels=self.downblock_channal[1], out_channels=self.downblock_channal[1])
        self.mupcat1 = transpconv_op(self.downblock_channal[0] + self.downblock_channal[1] + self.downblock_channal[2],
                                     self.downblock_channal[0], kernel_size=4, stride=4, bias=False)
        self.gate1 = Gate(in_channels=self.downblock_channal[0], out_channels=self.downblock_channal[0])
        ################################################## Edge Aggregation Block (EAB) ##################################################

    def forward(self, x):
        ret = []
        for s in self.stages:
            x = s(x)
            ret.append(x)
        if not PlainConvEncoder.has_shown_prompt[0]:  # print the banner only on the first forward pass
            print("################################################## EAB ##################################################")
            PlainConvEncoder.has_shown_prompt[0] = True  # mark the banner as shown
        # middle attention: spatial attention on the three highest-resolution skips
        m1 = self.mattn(ret[0])
        m2 = self.mattn(ret[1])
        m3 = self.mattn(ret[2])

        m1m2 = torch.cat([self.mdcat1(m1), m2], dim=1)         # shape: [B, C=32+64, D/2, H/2, W/2]
        m_feature = torch.cat([self.mdcat2(m1m2), m3], dim=1)  # shape: [B, C=32+64+128, D/4, H/4, W/4]

        # re-weight the first three skips with the aggregated edge feature
        ret[0] = self.gate1(self.mupcat1(m_feature), ret[0])
        ret[1] = self.gate2(self.mupcat2(m_feature), ret[1])
        ret[2] = self.gate3(self.mupcat3(m_feature), ret[2])

        '''
        tensors = {'m1': m1, 'm2': m2, 'm3': m3, 'm1m2': m1m2, 'm_feature': m_feature}
        for name, tensor in tensors.items():
            print(f"Name: {name}, Shape: {tensor.shape}")
        '''

        if self.return_skips:
            return ret
        else:
            return ret[-1]

    def compute_conv_feature_map_size(self, input_size):
        output = np.int64(0)
        for s in range(len(self.stages)):
            if isinstance(self.stages[s], nn.Sequential):
                for sq in self.stages[s]:
                    if hasattr(sq, 'compute_conv_feature_map_size'):
                        output += self.stages[s][-1].compute_conv_feature_map_size(input_size)
            else:
                output += self.stages[s].compute_conv_feature_map_size(input_size)
            input_size = [i // j for i, j in zip(input_size, self.strides[s])]
        return output
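
Before launching a long run, a quick smoke test helps catch channel mismatches. Below is a minimal sketch, assuming the file is importable as dynamic_network_architectures.building_blocks.plain_conv_encoder_EAB; since downblock_channal is hardcoded to start with 32/64/128, features_per_stage must match on the first three stages:

import torch
from torch import nn
from dynamic_network_architectures.building_blocks.plain_conv_encoder_EAB import PlainConvEncoder

# hypothetical configuration mirroring a typical nnUNet 3d_fullres encoder
enc = PlainConvEncoder(
    input_channels=1, n_stages=5,
    features_per_stage=[32, 64, 128, 256, 320],  # first three must equal downblock_channal[:3]
    conv_op=nn.Conv3d, kernel_sizes=3, strides=[1, 2, 2, 2, 2],
    n_conv_per_stage=2, conv_bias=True,
    norm_op=nn.InstanceNorm3d, norm_op_kwargs={'eps': 1e-5, 'affine': True},
    nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True},
    return_skips=True,
)
for skip in enc(torch.rand(1, 1, 64, 64, 64)):
    print(skip.shape)  # the EAB re-weights the first three skips; their shapes stay unchanged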
all_attention.py

import torch
import torch.nn as nn


class Spartial_Attention3d(nn.Module):
    """Spatial attention: a sigmoid mask computed from channel-wise mean and max, applied multiplicatively."""

    def __init__(self, kernel_size):
        super(Spartial_Attention3d, self).__init__()

        assert kernel_size % 2 == 1, "kernel_size = {}".format(kernel_size)
        padding = (kernel_size - 1) // 2

        self.__layer = nn.Sequential(
            nn.Conv3d(2, 1, kernel_size=kernel_size, padding=padding),
            nn.Sigmoid(),
        )

    def forward(self, x):
        avg_mask = torch.mean(x, dim=1, keepdim=True)
        max_mask, _ = torch.max(x, dim=1, keepdim=True)
        mask = torch.cat([avg_mask, max_mask], dim=1)

        mask = self.__layer(mask)
        return x * mask


class Channel_Attention3d(nn.Module):
    """Channel attention with reduction ratio r, using both average and max pooling."""

    def __init__(self, channel, r):
        super(Channel_Attention3d, self).__init__()

        self.__avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.__max_pool = nn.AdaptiveMaxPool3d((1, 1, 1))

        self.__fc = nn.Sequential(
            nn.Conv3d(channel, channel // r, 1, bias=False),
            nn.LeakyReLU(inplace=True),  # inplace=True; a bare positional True would set negative_slope instead
            nn.Conv3d(channel // r, channel, 1, bias=False),
        )
        self.__sigmoid = nn.Sigmoid()

    def forward(self, x):
        y1 = self.__avg_pool(x)
        y1 = self.__fc(y1)

        y2 = self.__max_pool(x)
        y2 = self.__fc(y2)

        y = self.__sigmoid(y1 + y2)
        return x * y


class Gate(nn.Module):
    """Attention gate: x1 is the gating signal, x2 the feature map to be re-weighted."""

    def __init__(self, in_channels, out_channels):
        super(Gate, self).__init__()
        self._w = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm3d(out_channels)
        )
        self.relu = nn.LeakyReLU(inplace=True)
        self.psi = nn.Sequential(
            nn.Conv3d(out_channels, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm3d(1),
            nn.Sigmoid()
        )

    def forward(self, x1, x2):
        w1 = self._w(x1)
        w2 = self._w(x2)
        psi = self.relu(w1 + w2)
        psi = self.psi(psi)
        return x2 * psi


class FocalModulation(nn.Module):
    """ Focal Modulation

    Args:
        dim (int): Number of input channels.
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        focal_level (int): Number of focal levels
        focal_window (int): Focal window size at focal level 1
        focal_factor (int, default=2): Step to increase the focal window
        use_postln (bool, default=False): Whether use post-modulation layernorm
    """

    def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False):
        super().__init__()
        self.dim = dim

        # specific args for focalv3
        self.focal_level = focal_level
        self.focal_window = focal_window
        self.focal_factor = focal_factor
        self.use_postln = use_postln

        self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=True)
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.focal_layers = nn.ModuleList()

        if self.use_postln:
            self.ln = nn.LayerNorm(dim)

        for k in range(self.focal_level):
            kernel_size = self.focal_factor * k + self.focal_window
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim,
                              padding=kernel_size // 2, bias=False),
                    nn.GELU(),
                )
            )

    def forward(self, x):
        """ Forward function.

        Args:
            x: input features with shape of (B, H, W, C)
        """
        B, nH, nW, C = x.shape
        x = self.f(x)                           # (B, H, W, 2C + focal_level + 1)
        x = x.permute(0, 3, 1, 2).contiguous()  # channels-first for the conv layers
        q, ctx, gates = torch.split(x, (C, C, self.focal_level + 1), 1)

        ctx_all = 0
        for l in range(self.focal_level):
            ctx = self.focal_layers[l](ctx)
            ctx_all = ctx_all + ctx * gates[:, l:l + 1]
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]

        x_out = q * self.h(ctx_all)
        x_out = x_out.permute(0, 2, 3, 1).contiguous()
        if self.use_postln:
            x_out = self.ln(x_out)
        x_out = self.proj(x_out)
        x_out = self.proj_drop(x_out)
        return x_out
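
Quick shape checks for the four blocks above, as a sketch (the three 3D blocks preserve the input shape; FocalModulation is a 2D, channels-last block that the EAB does not use yet):

import torch

x = torch.rand(2, 32, 16, 16, 16)
g = torch.rand(2, 32, 16, 16, 16)  # gating signal, e.g. the upsampled m_feature

print(Spartial_Attention3d(kernel_size=3)(x).shape)       # [2, 32, 16, 16, 16]
print(Channel_Attention3d(channel=32, r=8)(x).shape)      # [2, 32, 16, 16, 16]
print(Gate(in_channels=32, out_channels=32)(g, x).shape)  # [2, 32, 16, 16, 16]

fm = FocalModulation(dim=32)                # expects (B, H, W, C)
print(fm(torch.rand(2, 16, 16, 32)).shape)  # [2, 16, 16, 32]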

Those are the two files. A training run is underway to see whether the EAB actually improves results.
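
Training goes through the usual nnUNetv2 entry point; the dataset ID and fold below are placeholders:

nnUNetv2_train 137 3d_fullres 0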

Changing the number of epochs in nnUNet:

nnUNet v1: modify self.max_num_epochs in the .py file at the following path:

anaconda3/envs/wzh/lib/python3.8/site-packages/nnunet/training/network_training/nnUNetTrainerV2.py

nnUNet v2: modify self.num_epochs in the .py file at the following path:

anaconda3/envs/wzh/lib/python3.8/site-packages/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
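
Both defaults are 1000 epochs in the releases I have seen (worth verifying in the installed copy); the change is one line in each trainer's __init__, sketched here:

# nnUNet v2: nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py, inside nnUNetTrainer.__init__
self.num_epochs = 1000        # lower this, e.g. to 250, for quicker experiments

# nnUNet v1: nnunet/training/network_training/nnUNetTrainerV2.py, inside nnUNetTrainerV2.__init__
self.max_num_epochs = 1000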