本发明涉及人工智能,其特别涉及一种大模型蒸馏方法、装置和存储介质。
背景技术:
1、随着深度学习技术的不断发展,自然语言处理领域所应用的模型规模也在不断扩大。巨大的参数量为大模型提供了前所未有的性能,使其在少样本的情况下也能提供良好的服务;但同时,模型规模也对计算提出了极大挑战,如今,大规模语言模型往往具有千亿以上的参数量,由于大模型具有计算密集的特性,如何将其部署到真实的应用场景中,仍是不可忽视的难题。
2、当前,为了规避大模型的部署挑战,研究人员和企业通常选择部署较小的专用模型。这些较小规模的模型可来源于蒸馏的训练范式。蒸馏旨在使用较大的语言模型生成的标签来训练表现较差的较小模型,较大语言模型具有较强的知识涌现能力,其可以生成大量高质量标签的数据,以训练小模型。尽管该训练范式可以有效压缩使用模型的规模,但仍然面对很多挑战,一方面需要大量的训练数据来提升小模型的效果,同时需要有效激发和高效迁移大模型知识的方法,以及需要有效的学生模型训练机制,这给现有工作带来了困难。
技术实现思路
1、为了解决上述问题,本发明提供一种大模型蒸馏方法、装置和存储介质。
2、本发明为解决上述技术问题,提供如下的技术方案:一种大模型蒸馏方法,包括以下步骤:
3、步骤s1,构建教师模型和初始学生模型;
4、步骤s2,获取初始指令进行指令复杂化操作,得到复杂化指令;
5、步骤s3,将复杂化指令输入教师模型生成答案,基于复杂化指令和答案构建指令-答案训练数据;
6、步骤s4,使用指令-答案训练数据以数据蒸馏方法训练初始学生模型。
7、优选地,所述步骤s2中的复杂化操作具体包括以下步骤:
8、步骤s21,基于预训练的复杂化模型将初始指令扩充为初级复杂化指令;
9、步骤s22,将初级复杂化指令输入学生模型得到初级答案,基于预训练的反馈模型对初级答案打分;
10、步骤s23,判断打分分数是否低于预设阈值,若是,将对应的初级复杂化指令作为复杂化指令。
11、优选地,所述步骤s21具体包括以下步骤:
12、步骤s211,构造基于few-shot学习的指令模板;
13、步骤s212,预训练的复杂化模型基于few-shot学习的指令模板,将初始指令扩充为初级复杂化指令。
14、优选地,所述步骤s4具体包括以下步骤:
15、步骤s41,将复杂化指令分别输入教师模型和学生模型,得到教师模型输出的第一预测概率和学生模型输出的第二预测概率;
16、步骤s42,根据第一预测概率和第二预测概率基于kl散度得到kl散度损失函数;
17、步骤s43,根据第二预测概率基于对数损失得到一般训练损失函数;
18、步骤s44,基于kl散度损失函数和一般训练损失函数得到总损失函数;
19、步骤s45,基于总损失函数迭代更新当前的学生模型。
20、优选地,所述步骤s42中的kl散度损失函数具体为:
21、
22、其中表示输出的答案,表示输出的答案中第k个单词,表示教师模型输出的的预测概率;表示学生模型输出的的预测概率。
23、优选地,所述步骤s43中的一般训练损失函数具体为:
24、
25、其中表示输出的答案,表示输出的答案中第k个单词,表示学生模型输出的的预测概率,表示概率,表示前面出现的单词,表示依据前面出现的单词预测下一个单词为的概率。
26、优选地,所述步骤s44中基于kl散度损失函数和一般训练损失函数得到总损失函数具体为:
27、步骤s441;赋予kl散度损失函数权重参数α,赋予一般训练损失函数权重参数1-α,其中α取值范围从0到1;
28、步骤s442,将kl散度损失函数和一般训练损失函数进行加权求和得到总损失函数。
29、优选地,所述步骤s22中反馈模型的预训练具体包括以下步骤:
30、步骤s221,构建初始反馈模型;
31、步骤s222,将初始指令输入教师模型得到正例并将初始指令输入初始学生模型得到反例;
32、步骤s223,利用正反例训练初始反馈模型得到反馈模型。
33、本发明为解决上述技术问题,提供又一技术方案如下:一种大模型蒸馏装置,用于实施上述任意一项大模型蒸馏方法,大模型蒸馏装置包括以下模块:
34、初始模块,用于构建教师模型和初始学生模型;
35、指令复杂化模块,用于获取初始指令进行指令复杂化操作,得到复杂化指令;
36、指令-答案训练数据生成模块,用于将复杂化指令输入教师模型生成答案,基于复杂化指令和答案构建指令-答案训练数据;
37、训练模块,用于使用指令-答案训练数据以数据蒸馏方法训练初始学生模型。
38、本发明为解决上述技术问题,提供又一技术方案如下:一种计算机可读存储介质,计算机程序被执行时实现上述任意一项所述的大模型蒸馏方法。
39、与现有技术相比,本发明所提供的一种大模型蒸馏方法、装置和存储介质,具有如下的有益效果:
40、1、本发明实施例中提供一种大模型蒸馏方法,通过构建教师模型和初始学生模型;获取初始指令进行指令复杂化操作,得到复杂化指令;将复杂化指令输入教师模型生成答案,基于复杂化指令和答案构建高质量的指令-答案训练数据;使用高质量的指令-答案训练数据以数据蒸馏方法训练初始学生模型。通过以一种通用的指令复杂化方法获取初始指令进行指令复杂化操作,从而得到复杂化指令以便有效激活教师模型的知识,能够从初始指令逻辑复杂性和知识覆盖面等不同角度对于初始指令进行复杂化处理,提高复杂化指令的复杂性与丰富度,相较于简单指令,复杂化指令更加高质量,包含更复杂的知识,能够有效激活教师模型的知识并有效提升学生模型的知识,学生模型能够学习更加复杂的任务,学习过程更加高效,输入学生模型的复杂化指令超越其能力范围,有效提升训练效率;同时利用面向教师模型的知识蒸馏机制,采用数据蒸馏的方法训练,数据蒸馏是指通过数据作为桥梁将教师模型的知识迁移至学生模型,与已有的数据蒸馏相比,本发明的数据蒸馏通过获取复杂化指令及答案,相较于简单指令,指令-答案训练数据更加高质量,包含更复杂的知识,使用高质量的指令-答案训练数据以数据蒸馏方法训练初始学生模型一方面可以有效激活教师模型的知识,教师模型输出的高知识密度的答案充分包含了教师模型的知识;另一方面可以有效提升学生模型的知识,学习更加复杂的任务。获取高质量带有教师模型知识的训练数据训练学生模型,传递教师模型的知识等关键信息给学生模型,可以高效将大规模参数教师模型的知识蒸馏到小规模参数学生模型,能够减小训练数据规模,降低训练推理计算成本,有效激发和高效迁移教师模型的高水平知识,使得学生模型训练机制更为有效,训练所得到的学生模型具有更优的性能,从而起到事半功倍的训练效果。
41、2、本发明实施例中提供的复杂化操作具体包括基于预训练的复杂化模型将初始指令扩充为初级复杂化指令;将初级复杂化指令输入学生模型得到初级答案,基于预训练的反馈模型对初级答案打分;判断打分分数是否低于预设阈值,若是,将对应的初级复杂化指令作为复杂化指令。在训练过程之中学生模型的能力会随着训练被不断增强,学生模型所接受的训练难度随之不断减小,通过构建一种基于初级复杂化指令复杂度检测和增强学生模型性能的反馈学习方法,反馈模型对学生模型输出的初级答案进行打分,判断打分分数是否低于预设阈值,低于预设阈值分数意味着学生模型还没有很好学习该初级复杂化指令的相关知识,将低于预设阈值分数对应的初级复杂化指令作为复杂化指令,能够动态调整指导复杂化指令的更新,确保复杂化指令的复杂性与丰富度超出学生模型的能力水平。由于仅简单利用训练数据不能够最大化学生模型性能,采用基于初级复杂化指令复杂度检测和增强学生模型性能的反馈学习方法,如果当前学生模型在某些指令或任务上性能不好,后续基于反馈模型打分,选择性增强这部分性能,不断改进学生模型的薄弱能力。由于在复杂问题上教师模型和学生模型的表现差异将达到最大,反馈模型的实时评价反馈能够确保以满足复杂度要求的复杂化指令训练学生模型,以较高效率缩小学生模型与教师模型在复杂问题上的表现差异,使得学生模型可以更好学习教师模型的知识,针对多种指令或任务达到整体性能最优。
42、3、本发明实施例中的步骤s21具体包括以下步骤:构造基于few-shot学习的指令模板;预训练的复杂化模型基于few-shot学习的指令模板,将初始指令扩充为初级复杂化指令。few-shot学习即少样本学习,是一种机器学习范式,旨在使模型能够在少量样本的情况下完成学习任务。通过构造基于few-shot学习的指令模板,预训练的复杂化模型基于few-shot学习的指令模板,将初始指令扩充为初级复杂化指令,从而只需要准备少量高质量的指令复杂化学习示例,便可以引导大模型完成指令复杂化过程。能够进一步减少训练数据的使用数量,通过基于few-shot学习的指令模板,复杂化模型可以获得出色的泛化能力,生成语言清晰简洁、语句合乎礼仪、无违规信息及无歧义的初级复杂化指令,从而避免产生无意义的内容,仅在长度上扩充指令。
43、4、本发明实施例中提供的步骤s4具体包括:将复杂化指令分别输入教师模型和学生模型,得到教师模型输出的第一预测概率和学生模型输出的第二预测概率;根据第一预测概率和第二预测概率基于kl散度得到kl散度损失函数;根据第二预测概率基于对数损失得到一般训练损失函数;基于kl散度损失函数和一般训练损失函数得到总损失函数,基于总损失函数迭代更新当前的学生模型。通过指令-答案训练数据以数据蒸馏方法训练初始学生模型,指令-答案训练数据中通过复杂化指令输入教师模型得到的答案由于是教师模型生成从而包括了教师模型的知识,通过复杂化指令可以更好激发教师模型的能力,获取高质量带有知识的训练数据,从而更有利于提升学生模型的性能。通过将复杂化指令分别输入教师模型和学生模型,基于预测概率分布,通过基于kl散度和对数损失构建总损失函数对学生模型进行优化,提升学生模型训练效率。
44、5、本发明实施例中界定具体的kl散度损失函数,采用基于kl散度的训练目标,用来量化教师模型与学生模型在输出概率分布上的差异,用以增强知识蒸馏过程。通过根据kl散度损失函数来最小化kl散度,学生模型将不断学习教师模型掌握的知识与能力,缩小与教师模型之间的差距,实现针对教师模型的知识蒸馏。
45、6、本发明实施例中界定具体的一般训练损失函数,基于最大似然估计的思想,训练目标是最大化相应的似然函数,具有较好的概率分布表征性能,促使学生模型更加关注分类正确的复杂化指令训练样本,提升学生模型的训练效率。
46、7、本发明实施例中提供的基于kl散度损失函数和一般训练损失函数得到总损失函数具体为:赋予kl散度损失函数权重参数α,赋予一般训练损失函数权重参数1-α,其中α取值范围从0到1;将kl散度损失函数和一般训练损失函数进行加权求和得到总损失函数。通过赋予kl散度损失函数和一般训练损失函数权重参数,并进行加权求和得到总损失函数,能够根据实际情况调整权重参数对于总损失函数进行调整。
47、8、本发明实施例中提供的反馈模型的预训练具体包括以下步骤:构建初始反馈模型;将初始指令输入教师模型得到正例并将初始指令输入初始学生模型得到反例;利用正反例训练初始反馈模型得到反馈模型。通过将初始指令分别输入教师模型和初始学生模型构建正反例,利用正反例训练反馈模型,使得训练得到的反馈模型可以有效指导指令复杂化操作对当前初级复杂化指令进行更新,确保以满足复杂度要求的复杂化指令训练学生模型。
48、9、本发明实施例还提供一种大模型蒸馏装置和计算机可读存储介质,具有与上述一种大模型蒸馏方法相同的有益效果,在此不做赘述。
1.一种大模型蒸馏方法,其特征在于,包括以下步骤:
2.如权利要求1所述的大模型蒸馏方法,其特征在于:所述步骤s2中的复杂化操作具体包括以下步骤:
3.如权利要求2所述的大模型蒸馏方法,其特征在于:所述步骤s21具体包括以下步骤:
4.如权利要求1所述的大模型蒸馏方法,其特征在于:所述步骤s4具体包括以下步骤:
5.如权利要求4所述的大模型蒸馏方法,其特征在于:所述步骤s42中的kl散度损失函数具体为:
6.如权利要求4所述的大模型蒸馏方法,其特征在于:所述步骤s43中的一般训练损失函数具体为:
7.如权利要求4所述的大模型蒸馏方法,其特征在于:所述步骤s44中基于kl散度损失函数和一般训练损失函数得到总损失函数具体为:
8.如权利要求2所述的大模型蒸馏方法,其特征在于:所述步骤s22中反馈模型的预训练具体包括以下步骤:
9.一种大模型蒸馏装置,用于实施权利要求1~8任意一项所述的大模型蒸馏方法,其特征在于,所述大模型蒸馏装置包括以下模块:
10.一种计算机可读存储介质,计算机可读存储介质存储有计算机程序,其特征在于:所述计算机程序被处理器执行时实现如权利要求1-8任一项所述的大模型蒸馏方法。
