本技术涉及计算机计算领域,尤其涉及一种分类模型训练方法、装置、终端设备以及存储介质。
背景技术:
1、在图像分类模型的应用过程中,首先需要将图像分类模型在训练图像数据集上进行训练,通过自定义的损失函数和优化算法学习迭代最优的神经网络模型参数,然后在验证达到指标后,部署到实际应用中进行分类。
2、在许多实际应用中,图像分类模型需要进行分类的图像数据集常常呈现出长尾分布,即某些类别(头部类)的样本数量远大于其他类别(尾部类),而传统的图像分类模型训练在处理这类长尾分布数据时,由于图像训练集中头部类和尾部类样本数目差距很大,导致在长尾分布的图像数据集上训练的图像分类模型偏向头部类,对尾部类样本的识别准确率很低,尾部类别的分类性能不佳。
3、综上,如何在长尾分布的图像训练集中训练出不偏向头部类、识别各类样本的能力相同的模型,俨然已成为本领域亟需解决的技术问题。
技术实现思路
1、本技术的主要目的在于提供一种分类模型训练方法、装置、终端设备以及计算机可读存储介质,在长尾分布的图像训练集中训练出不偏向头部类、识别各类样本的能力相同的模型。
2、为实现上述目的,本技术提供一种分类模型训练方法,所述分类模型训练方法包括:
3、将长尾分布的图像数据集作为图像分类模型的图像训练集,其中,所述图像训练集包含已标注类别的多个图像训练样本;
4、获取所述图像分类模型中各图像分类器各自的历史累积正梯度和历史累积负梯度,其中,针对每一个所述图像分类器,所述历史累积正梯度为与所述图像分类器类别一致的历史图像训练样本在所述图像分类器上产生的梯度的累加值,所述历史累积负梯度为与当前图像训练样本类别一致,且与所述图像分类器类别不一致的历史图像训练样本在所述图像分类器上产生的梯度的累加值;
5、基于各所述图像分类器的模型参数、所述历史累积正梯度和所述历史累积负梯度,计算所述当前图像训练样本在各所述图像分类器上的各预测值;
6、基于各所述预测值与所述当前图像训练样本的真实标签之间的交叉熵损失更新各所述模型参数,以进行图像分类模型的训练。
7、可选地,所述基于各所述图像分类器的模型参数、所述历史累积正梯度和所述历史累积负梯度,计算所述当前图像训练样本在各所述图像分类器上的各预测值的步骤,包括:
8、针对各所述图像分类器,将所述图像分类器的模型参数与所述当前图像训练样本的特征相乘,并加上所述历史累积正梯度,和,减去所述历史累积负梯度,计算得到所述当前图像训练样本在所述图像分类器上的预测值。
9、可选地,所述获取所述图像分类模型中各图像分类器各自的历史累积正梯度和历史累积负梯度的步骤,包括:
10、针对各所述图像分类器,在所述图像分类模型的历史训练轮次中,聚合所述图像分类器在每一所述历史训练轮次中的历史正梯度得到历史累积正梯度,聚合所述图像分类器在每一所述历史训练轮次中的历史负梯度得到历史累积负梯度。
11、可选地,所述基于各所述预测值与所述当前图像训练样本的真实标签之间的交叉熵损失更新各所述模型参数的步骤包括:
12、基于各所述预测值与所述当前图像训练样本的真实标签之间的交叉熵损失建立所述当前图像训练样本的损失函数;
13、依次将各所述模型参数作为求偏导的分母对所述损失函数进行求偏导计算,得到所述当前图像训练样本在各所述图像分类器上产生的各梯度;
14、基于各所述梯度更新各所述模型参数。
15、可选地,所述基于各所述梯度更新各所述模型参数的步骤,包括:
16、针对每个所述图像分类器,将所述梯度、预设的学习率和所述模型参数输入至预设的优化算法中,得到优化后的所述模型参数,以基于优化后的所述模型参数进行下一训练轮次的预测值计算。
17、可选地,在所述依次将各所述模型参数作为求偏导的分母对所述损失函数进行求偏导计算,得到所述当前图像训练样本在各所述图像分类器上产生的各梯度的步骤之后,所述方法还包括:
18、针对每个所述图像分类器,确定所述梯度为正梯度还是负梯度;
19、若所述梯度为正梯度,则将所述梯度累积至所述历史累积正梯度中,得到更新后的所述历史累积正梯度,以基于更新后的所述历史累积正梯度进行下一训练轮次的预测值计算。
20、可选地,在所述针对每个所述图像分类器,确定所述梯度为正梯度还是负梯度的步骤之后,所述方法还包括:
21、若所述梯度为负梯度,则将所述梯度累积至所述历史累积负梯度,得到更新后的所述历史累积负梯度,以基于更新后的所述历史累积负梯度进行下一训练轮次的预测值计算。
22、此外,为实现上述目的,本技术还提供一种分类模型训练装置,所述分类模型训练装置包括:
23、图像训练集确定模块,用于将长尾分布的图像数据集作为图像分类模型的图像训练集,其中,所述图像训练集包含已标注类别的多个图像训练样本;
24、梯度获取模块,用于获取所述图像分类模型中各图像分类器各自的历史累积正梯度和历史累积负梯度,其中,针对每一个所述图像分类器,所述历史累积正梯度为与所述图像分类器类别一致的历史图像训练样本在所述图像分类器上产生的梯度的累加值,所述历史累积负梯度为与当前图像训练样本类别一致,且与所述图像分类器类别不一致的历史图像训练样本在所述图像分类器上产生的梯度的累加值;
25、预测值计算模块,用于基于各所述图像分类器的模型参数、所述历史累积正梯度和所述历史累积负梯度,计算所述当前图像训练样本在各所述图像分类器上的各预测值;
26、训练模块,用于基于各所述预测值与所述当前图像训练样本的真实标签之间的交叉熵损失更新各所述模型参数,以进行图像分类模型的训练。
27、此外,为实现上述目的,本技术还提供一种终端设备,所述终端设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的图像分类模型训练程序,所述图像分类模型训练程序被所述处理器执行时实现如上所述的分类模型训练方法的步骤。
28、此外,为实现上述目的,本技术还提出一种存储介质,所述存储介质为计算机可读存储介质,所述存储介质上存储有图像分类模型训练程序,所述图像分类模型训练程序被处理器执行时实现如上所述的分类模型训练方法的步骤。
29、本技术实施例提出的一种分类模型训练方法、装置、终端设备以及存储介质,该分类模型训练方法包括:将长尾分布的图像数据集作为图像分类模型的图像训练集,其中,所述图像训练集包含已标注类别的多个图像训练样本;获取所述图像分类模型中各图像分类器各自的历史累积正梯度和历史累积负梯度,其中,针对每一个所述图像分类器,所述历史累积正梯度为与所述图像分类器类别一致的历史图像训练样本在所述图像分类器上产生的梯度的累加值,所述历史累积负梯度为与当前图像训练样本类别一致,且与所述图像分类器类别不一致的历史图像训练样本在所述图像分类器上产生的梯度的累加值;基于各所述图像分类器的模型参数、所述历史累积正梯度和所述历史累积负梯度,计算所述当前图像训练样本在各所述图像分类器上的各预测值;基于各所述预测值与所述当前图像训练样本的真实标签之间的交叉熵损失更新各所述模型参数,以进行图像分类模型的训练。
30、相比于传统的分类模型训练方法,本技术在长尾分布的图像数据集上训练模型时通过获取模型中各图像分类器的历史累积正梯度和历史累积负梯度,并结合各图像分类器的模型参数计算预测值,能够通过图像分类器累积的梯度对本次训练轮次中回传的梯度进行平衡,保证了图像分类模型中尾部类图像分类器的正负梯度平衡和各类图像分类器间产生的梯度平衡,从而实现了在长尾分布的图像训练集中训练出不偏向头部类、识别各类样本的能力相同的模型。
1.一种分类模型训练方法,其特征在于,所述分类模型训练方法包括以下步骤:
2.如权利要求1所述的分类模型训练方法,其特征在于,所述基于各所述图像分类器的模型参数、所述历史累积正梯度和所述历史累积负梯度,计算所述当前图像训练样本在各所述图像分类器上的各预测值的步骤,包括:
3.如权利要求1所述的分类模型训练方法,其特征在于,所述获取所述图像分类模型中各图像分类器各自的历史累积正梯度和历史累积负梯度的步骤,包括:
4.如权利要求1所述的分类模型训练方法,其特征在于,所述基于各所述预测值与所述当前图像训练样本的真实标签之间的交叉熵损失更新各所述模型参数的步骤包括:
5.如权利要求4所述的分类模型训练方法,其特征在于,所述基于各所述梯度更新各所述模型参数的步骤,包括:
6.如权利要求4所述的分类模型训练方法,其特征在于,在所述依次将各所述模型参数作为求偏导的分母对所述损失函数进行求偏导计算,得到所述当前图像训练样本在各所述图像分类器上产生的各梯度的步骤之后,所述方法还包括:
7.如权利要求6所述的分类模型训练方法,其特征在于,在所述针对每个所述图像分类器,确定所述梯度为正梯度还是负梯度的步骤之后,所述方法还包括:
8.一种分类模型训练装置,其特征在于,所述分类模型训练装置包括:
9.一种终端设备,其特征在于,所述终端设备包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的图像分类模型训练程序,所述图像分类模型训练程序被所述处理器执行时实现如权利要求1至7中任一项所述的分类模型训练方法的步骤。
10.一种存储介质,其特征在于,所述存储介质为计算机可读存储介质,所述存储介质上存储有图像分类模型训练程序,所述图像分类模型训练程序被处理器执行时实现如权利要求1至7中任一项所述的分类模型训练方法的步骤。