ELTV
The very beginning of ELTV. The deep learning loop breaks down into four stages: Extract, Load, Train, Validate. This post summarizes the common tricks for each stage as a cheatsheet.
Extract data
For PyTorch, you usually implement a Dataset subclass and override two methods: __len__ and __getitem__, so that every integer index from 0 to len(self) - 1 works. For actual code, see the dataset class written earlier for the VOC dataset:
```python
from random import sample, shuffle

import numpy as np
from torch.utils.data import Dataset


class YoloDataset(Dataset):
    def __init__(self, annotation_lines, input_shape, num_classes, mosaic, train):
        super(YoloDataset, self).__init__()
        self.annotation_lines = annotation_lines
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.length = len(self.annotation_lines)
        self.mosaic = mosaic
        self.train = train

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        index = index % self.length
        #---------------------------------------------------#
        #   Apply random augmentation during training;
        #   no random augmentation during validation.
        #---------------------------------------------------#
        if self.mosaic:
            if self.rand() < 0.5:
                # Mosaic: mix the current image with 3 randomly sampled ones.
                lines = sample(self.annotation_lines, 3)
                lines.append(self.annotation_lines[index])
                shuffle(lines)
                image, box = self.get_random_data_with_Mosaic(lines, self.input_shape)
            else:
                image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random=self.train)
        else:
            image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random=self.train)
        # HWC -> CHW after normalization (preprocess_input, rand, get_random_data
        # and get_random_data_with_Mosaic are helpers defined elsewhere in the repo).
        image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
        box = np.array(box, dtype=np.float32)
        if len(box) != 0:
            # Normalize (x1, y1, x2, y2) pixel coordinates to [0, 1].
            box[:, [0, 2]] = box[:, [0, 2]] / self.input_shape[1]
            box[:, [1, 3]] = box[:, [1, 3]] / self.input_shape[0]
            # Convert corners to (center_x, center_y, width, height).
            box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
            box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
        return image, box
```
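A quick sketch of constructing this dataset, assuming the annotation lines come from a plain-text file with one image per line (the file name, input shape, and class count below are placeholders for a VOC-style setup):

```python
# Hypothetical annotation file: one image per line (path plus boxes).
with open("2007_train.txt", encoding="utf-8") as f:
    train_lines = f.readlines()

train_dataset = YoloDataset(train_lines, input_shape=[416, 416],
                            num_classes=20, mosaic=True, train=True)
```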
Load the data
For loading, the built-in DataLoader is all you need; see the PyTorch docs for the details:
```python
gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size,
                 num_workers=num_workers, pin_memory=True, drop_last=True,
                 collate_fn=yolo_dataset_collate)
```
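A custom collate_fn is needed here because each image can carry a different number of boxes, so the default collate cannot stack the targets into one tensor. The actual yolo_dataset_collate lives in the training repo; a minimal sketch of what it might look like (this version is an assumption):

```python
import numpy as np
import torch

def yolo_dataset_collate(batch):
    # Stack images into one tensor; keep per-image box arrays in a list,
    # since each image may contain a different number of boxes.
    images = torch.from_numpy(np.stack([image for image, _ in batch])).float()
    boxes = [torch.from_numpy(box).float() for _, box in batch]
    return images, boxes
```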
Train the model
A common training recipe is to freeze part of the network and train only the head first, then unfreeze the backbone and fine-tune the whole network together with SGD.
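A minimal sketch of this freeze-then-unfreeze schedule, assuming the model exposes its backbone as model.backbone and that a train_one_epoch helper exists (both names are hypothetical, as are the epoch counts and learning rates):

```python
import torch

def freeze_then_unfreeze(model, gen, train_one_epoch,
                         freeze_epochs=50, total_epochs=100):
    # Phase 1: freeze the backbone so only the head gets gradient updates.
    for p in model.backbone.parameters():
        p.requires_grad = False
    optimizer = torch.optim.SGD((p for p in model.parameters() if p.requires_grad),
                                lr=1e-2, momentum=0.9, weight_decay=5e-4)
    for epoch in range(freeze_epochs):
        train_one_epoch(model, gen, optimizer)

    # Phase 2: unfreeze everything and fine-tune at a lower learning rate.
    for p in model.backbone.parameters():
        p.requires_grad = True
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=1e-3, momentum=0.9, weight_decay=5e-4)
    for epoch in range(freeze_epochs, total_epochs):
        train_one_epoch(model, gen, optimizer)
```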
Validate the model
The exact evaluation metric differs from task to task, so check how evaluation is actually done for your specific benchmark.
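Whatever the metric, the validation loop itself always follows the same pattern: switch to eval mode, disable gradient tracking, and accumulate the metric over the loader. A minimal sketch, where compute_metric(outputs, boxes) is a placeholder standing in for the task-specific evaluation:

```python
import torch

@torch.no_grad()
def validate(model, val_loader, compute_metric, device="cuda"):
    # eval() freezes dropout/batch-norm behavior; no_grad() saves memory.
    model.eval()
    scores = []
    for images, boxes in val_loader:
        outputs = model(images.to(device))
        scores.append(compute_metric(outputs, boxes))
    model.train()  # restore training mode for the next epoch
    return sum(scores) / len(scores)
```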
This post is licensed under CC BY 4.0 by the author.