网站首页 > 技术文章 正文
前言
因为最近跑VIT的实验,所以有用到timm的一些配置,在mixup的实现里面发现labelsmooth的实现是按照最基本的方法来的,与很多pytorch的实现略有不同,所以简单做了一个推导。
一、交叉熵损失(CrossEntropyLoss)
先简单讲一下交叉熵损失,也是我们做分类任务里面最常用的一种损失,公式如下:
这里的表示的是模型输出的logits后经过softmax的结果,shape为,表示的是对应的label,经常用onehot来表示,pytorch版本可以使用scalar表示,shape为,这里表示为batchsize,表示为向量长度。
可以简单拆解为如下:
- log_softmax
这个很简单,就是做softmax后取对数,公式如下:
- NLLloss
这个玩意的全程叫做negative log-likelihood(负对数似然损失),简单解释下,假设需要求解一个分布, 由于未知其表达式,所以先定义一个分布,通过来使得靠近的分布。这里采用最大似然估计来进行求解,,不断的更新参数使得 来自的样本在中的概率越来越高。但是有个问题,连乘对于求导不友好,计算也过于复杂,所以可以对其取对数,有
最大化对数似然函数就等效于最小化负对数似然函数,所以加个负号,公式如下:
由于求loss的时候,采用的是onehot形式,除去当前类别为1其余都为0,所以有:
这个形式就和交叉熵形式一致,所以NLLLoss也叫CrossEntropyLoss。
二、LabelSmooth
由于Softmax会存在一个问题,就是Over Confidence,会使得模型对于弱项的照顾很少。LabelSmooth的作用就是为了降低Softmax所带来的的高Confidence的影响,让模型略微关注到低概率分布的权重。这样做也会有点影响,最终预测的时候,模型输出的置信度会稍微低一些,需要比较细致的阈值过滤。
假设,表示对标签进行平滑的数值,那么就有
这里表示类别数量,表示当前的类别,带有labelsmooth的CELoss就变成了:
相比原始的CELoss,LabelSmoothCELoss则是每一项都会参与到loss计算。
三、公式推导
# labelsmooth
import torch
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothingCrossEntropy(nn.Module):
"""
NLL loss with label smoothing.
"""
def __init__(self, smoothing=0.1):
"""
Constructor for the LabelSmoothing module.
:param smoothing: label smoothing factor
"""
super(LabelSmoothingCrossEntropy, self).__init__()
assert smoothing < 1.0
self.smoothing = smoothing
self.confidence = 1. - smoothing
def forward(self, x, target):
logprobs = F.log_softmax(x, dim=-1)
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
可以看到这个code的实现和公式有点出入,第一部分是self.confidence * nll_loss, 第二部分是self.smoothing * smooth_loss。我们将其展开为:
假设为target,那么对于onehot来说除了
以外均为0,所以有:
进一步有组合项:
最后可以写成矩阵点乘的形式:
我们表示为LabelSmooth后的标签,和第二节中的设定对齐,所以得到的Loss就是原本的表达式:
与之对应的timm中的mixup部分的LabelSmoothCELoss代码如下:
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
x = x.long().view(-1, 1)
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
off_value = smoothing / num_classes
on_value = 1. - smoothing + off_value
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
return y1 * lam + y2 * (1. - lam)
四,总结
LabelSmooth可以用来标签平滑,从公示推导方面来讲,也可以充当正则的作用,尤其是针对难分类别的情况下,效果会表现更好一些。
- 上一篇: 观察矩阵的推导(你怎么看矩阵)
- 下一篇: 可以用爱因斯坦求和替代的那些矩阵运算
猜你喜欢
- 2025-07-08 最让人惊艳的并行算法之一——矩阵算法
- 2025-07-08 小红书基于 PPO 的多模态大模型 RLHF 系统的设计与优化
- 2025-07-08 洞察宇宙(十七):线性代数期末复习指南
- 2025-07-08 线性代数:矩阵的定义和简单计算。
- 2025-07-08 OpenAI久违发了篇「正经」论文:线性布局实现高效张量计算
- 2025-07-08 全网最详细!油管1小时视频详解AlphaTensor矩阵乘法算法
- 2025-07-08 初等矩阵简单性质(初等矩阵的性质及例子)
- 2025-07-08 越览(44)——Matlab入门学习(3)之矩阵运算
- 2025-07-08 可以用爱因斯坦求和替代的那些矩阵运算
- 2025-07-08 观察矩阵的推导(你怎么看矩阵)
- 1509℃桌面软件开发新体验!用 Blazor Hybrid 打造简洁高效的视频处理工具
- 532℃Dify工具使用全场景:dify-sandbox沙盒的原理(源码篇·第2期)
- 495℃MySQL service启动脚本浅析(r12笔记第59天)
- 474℃服务器异常重启,导致mysql启动失败,问题解决过程记录
- 471℃启用MySQL查询缓存(mysql8.0查询缓存)
- 452℃「赵强老师」MySQL的闪回(赵强iso是哪个大学毕业的)
- 431℃mysql服务怎么启动和关闭?(mysql服务怎么启动和关闭)
- 428℃MySQL server PID file could not be found!失败
- 最近发表
- 标签列表
-
- c++中::是什么意思 (83)
- 标签用于 (65)
- 主键只能有一个吗 (66)
- c#console.writeline不显示 (75)
- pythoncase语句 (81)
- es6includes (73)
- windowsscripthost (67)
- apt-getinstall-y (86)
- node_modules怎么生成 (76)
- chromepost (65)
- c++int转char (75)
- static函数和普通函数 (76)
- el-date-picker开始日期早于结束日期 (70)
- js判断是否是json字符串 (67)
- checkout-b (67)
- localstorage.removeitem (74)
- vector线程安全吗 (70)
- & (66)
- java (73)
- js数组插入 (83)
- linux删除一个文件夹 (65)
- mac安装java (72)
- eacces (67)
- 查看mysql是否启动 (70)
- 无效的列索引 (74)