可以使用PyTorch中的Early Stopping来实现在验证集精确度达到一定阈值时停止训练进入测试阶段。具体实现步骤如下:
- 定义EarlyStopping类,包含以下初始化参数:
- patience:指定在验证集上连续多少个epoch无法提高性能时停止训练。
- delta:指定当性能提高超过delta时才认为是有显著进步。
- mode:指定性能衡量标准,如“max”表示最大化,即精度越高越好。
- verbose:指定是否打印详细信息。
- 在每个epoch结束时,计算模型在验证集上的性能,并将性能值与之前的最佳性能进行比较。
- 如果当前性能值比最佳性能值高,就更新最佳性能值和模型参数,同时将连续性能下降的次数重置为0。
- 如果当前性能值没有提高,则将连续性能下降的次数加1。
- 如果连续性能下降的次数超过了patience,则认为模型已经达到最优,停止训练。
- 在训练结束后,使用最佳模型对测试集进行评估。
代码实现如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
| import numpy as np import torch
class EarlyStopping: def __init__(self, patience=10, delta=0, mode='max', verbose=True): self.patience = patience self.delta = delta self.mode = mode self.verbose = verbose self.counter = 0 self.best_score = None self.early_stop = False if self.mode == 'min': self.val_score = np.Inf else: self.val_score = -np.Inf
def __call__(self, epoch_score, model): if self.mode == 'min': score = -1.0 * epoch_score else: score = np.copy(epoch_score) if self.best_score is None: self.best_score = score self.save_checkpoint(epoch_score, model) elif score < self.best_score + self.delta: self.counter += 1 if self.verbose: print(f'EarlyStopping counter: {self.counter} out of {self.patience}') if self.counter >= self.patience: self.early_stop = True else: self.best_score = score self.save_checkpoint(epoch_score, model) self.counter = 0
def save_checkpoint(self, epoch_score, model): if self.verbose: print(f'Validation score improved ({self.val_score:.6f} --> {epoch_score:.6f}). Saving model...') torch.save(model.state_dict(), 'checkpoint.pt') self.val_score = epoch_score
|
使用EarlyStopping的示例代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
| from tqdm import tqdm
early_stopping = EarlyStopping(patience=10, verbose=True)
for epoch in tqdm(range(num_epochs)): train_loss, train_acc = train(model, train_loader, optimizer, criterion)
valid_loss, valid_acc = evaluate(model, valid_loader, criterion)
train_losses.append(train_loss) train_accs.append(train_acc) valid_losses.append(valid_loss) valid_accs.append(valid_acc)
early_stopping(valid_acc, model) if early_stopping.early_stop: print("Early stopping") break
model.load_state_dict(torch.load('checkpoint.pt')) test_loss, test_acc = evaluate(model, test_loader, criterion) print(f"Test Loss: {test_loss:.6f}, Test Accuracy: {test_acc:.6f}")
|