可以使用PyTorch中的Early Stopping来实现在验证集精确度达到一定阈值时停止训练进入测试阶段。具体实现步骤如下:

  1. 定义EarlyStopping类,包含以下初始化参数:
    • patience:指定在验证集上连续多少个epoch无法提高性能时停止训练。
    • delta:指定当性能提高超过delta时才认为是有显著进步。
    • mode:指定性能衡量标准,如“max”表示最大化,即精度越高越好。
    • verbose:指定是否打印详细信息。
  2. 在每个epoch结束时,计算模型在验证集上的性能,并将性能值与之前的最佳性能进行比较。
  3. 如果当前性能值比最佳性能值高,就更新最佳性能值和模型参数,同时将连续性能下降的次数重置为0。
  4. 如果当前性能值没有提高,则将连续性能下降的次数加1。
  5. 如果连续性能下降的次数超过了patience,则认为模型已经达到最优,停止训练。
  6. 在训练结束后,使用最佳模型对测试集进行评估。

代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
import torch

class EarlyStopping:
def __init__(self, patience=10, delta=0, mode='max', verbose=True):
self.patience = patience
self.delta = delta
self.mode = mode
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
if self.mode == 'min':
self.val_score = np.Inf
else:
self.val_score = -np.Inf

def __call__(self, epoch_score, model):
if self.mode == 'min':
score = -1.0 * epoch_score
else:
score = np.copy(epoch_score)
if self.best_score is None:
self.best_score = score
self.save_checkpoint(epoch_score, model)
elif score < self.best_score + self.delta:
self.counter += 1
if self.verbose:
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(epoch_score, model)
self.counter = 0

def save_checkpoint(self, epoch_score, model):
if self.verbose:
print(f'Validation score improved ({self.val_score:.6f} --> {epoch_score:.6f}). Saving model...')
torch.save(model.state_dict(), 'checkpoint.pt')
self.val_score = epoch_score

使用EarlyStopping的示例代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from tqdm import tqdm

early_stopping = EarlyStopping(patience=10, verbose=True)

for epoch in tqdm(range(num_epochs)):
# 训练模型
train_loss, train_acc = train(model, train_loader, optimizer, criterion)

# 在验证集上评估模型
valid_loss, valid_acc = evaluate(model, valid_loader, criterion)

# 记录性能指标
train_losses.append(train_loss)
train_accs.append(train_acc)
valid_losses.append(valid_loss)
valid_accs.append(valid_acc)

# 检查是否需要提前停止
early_stopping(valid_acc, model)
if early_stopping.early_stop:
print("Early stopping")
break

# 使用最佳模型在测试集上评估
model.load_state_dict(torch.load('checkpoint.pt'))
test_loss, test_acc = evaluate(model, test_loader, criterion)
print(f"Test Loss: {test_loss:.6f}, Test Accuracy: {test_acc:.6f}")