0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

NLP類別不均衡問題之loss合集

jf_78858299 ? 來源:PaperWeekly ? 作者:眼睛里進(jìn)磚頭了 ? 2023-02-23 14:10 ? 次閱讀

NLP 任務(wù)中,數(shù)據(jù)類別不均衡問題應(yīng)該是一個(gè)極常見又頭疼的的問題了。最近在工作中也是碰到這個(gè)問題,花了些時(shí)間梳理并實(shí)踐了下類別不均衡問題的解決方式,主要實(shí)踐了下“魔改”loss(focal loss, GHM loss, dice loss 等),整理了下。所有的 Loss 實(shí)踐代碼在這里:

https://github.com/shuxinyin/NLP-Loss-Pytorch

數(shù)據(jù)不均衡問題也可以說是一個(gè)長(zhǎng)尾問題,但長(zhǎng)尾那部分?jǐn)?shù)據(jù)往往是重要且不能被忽略的,它不僅僅是分類標(biāo)簽下樣本數(shù)量的不平衡,實(shí)質(zhì)上也是 難易樣本的不平衡

解決不均衡問題一般從兩方面入手:

  1. 數(shù)據(jù)層面:重采樣,使得參與迭代計(jì)算的數(shù)據(jù)是均衡的;
  2. 模型層面:重加權(quán),修改模型的 loss,在 loss 計(jì)算上,加大對(duì)少樣本的 loss 獎(jiǎng)勵(lì)。

數(shù)據(jù)層面的重采樣

關(guān)于數(shù)據(jù)層面的重采樣,方式都是通過采樣,重新構(gòu)造數(shù)據(jù)分布,使得數(shù)據(jù)平衡。一般常用的有三種:1)欠采樣;2)過采樣;3)SMOTE。

  1. 欠采樣:指某類別下數(shù)據(jù)較多,則只采取部分?jǐn)?shù)據(jù),直接拋棄一些數(shù)據(jù),這種方式太簡(jiǎn)單粗暴,擬合出來的模型的偏差大,泛化性能較差;
  2. 過采樣:這種方式與欠采樣相反,某類別下數(shù)據(jù)較少,進(jìn)行重復(fù)采樣,達(dá)到數(shù)據(jù)平衡。因?yàn)檫@些少的數(shù)據(jù)反復(fù)迭代計(jì)算,會(huì)使得模型產(chǎn)生過擬合的現(xiàn)象。
  3. SMOTE:一種近鄰插值,可以降低過擬合風(fēng)險(xiǎn),但它是適用于回歸預(yù)測(cè)場(chǎng)景下,而 NLP 任務(wù)一般是離散的情況。

這幾種方法單獨(dú)使用會(huì)或多或少造成數(shù)據(jù)的浪費(fèi)或重,一般會(huì)與 ensemble 方式結(jié)合使用,sample 多份數(shù)據(jù),訓(xùn)練出多個(gè)模型,最后綜合。

但以上幾種方式在工程實(shí)踐中往往是少用的,一是因?yàn)閿?shù)真實(shí)據(jù)珍貴,二也是 ensemble 的方式部署中資源消耗大,沒法接受。因此,就集中看下重加權(quán) loss 改進(jìn)的部分。

模型層面的重加權(quán)

重加權(quán)主要指的是在 loss 計(jì)算階段,通過設(shè)計(jì) loss,調(diào)整類別的權(quán)值對(duì) loss 的貢獻(xiàn)。比較經(jīng)典的 loss 改進(jìn)應(yīng)該是 Focal Loss, GHM Loss, Dice Loss。

2.1 Focal Loss

Focal Loss 是一種解決不平衡問題的經(jīng)典 loss,基本思想就是把注意力集中于那些預(yù)測(cè)不準(zhǔn)的樣本上。

何為預(yù)測(cè)不準(zhǔn)的樣本?比如正樣本的預(yù)測(cè)值小于 0.5 的,或者負(fù)樣本的預(yù)測(cè)值大于 0.5 的樣本。再簡(jiǎn)單點(diǎn),就是當(dāng)正樣本預(yù)測(cè)值>0.5 時(shí),在計(jì)算該樣本的 loss 時(shí),給它一個(gè)小的權(quán)值,反之,正樣本預(yù)測(cè)值<0.5 時(shí),給它一個(gè)大的權(quán)值。同理,對(duì)負(fù)樣本時(shí)也是如此。

以二分類為例,一般采用交叉熵作為模型損失。

圖片

其中 是真實(shí)標(biāo)簽, 是預(yù)測(cè)值,在此基礎(chǔ)又出來了一個(gè)權(quán)重交叉熵,即用一個(gè)超參去緩解上述這種影響,也就是下式。

圖片

接下來,看下 Focal Loss 是怎么做到集中關(guān)注預(yù)測(cè)不準(zhǔn)的樣本?

在交叉熵 loss 基礎(chǔ)上,當(dāng)正樣本預(yù)測(cè)值 大于 0.5 時(shí),需要給它的 loss 一個(gè)小的權(quán)重值 ,使其對(duì)總 loss 影響小,反之正樣本預(yù)測(cè)值 小于 0.5,給它的 loss 一個(gè)大的權(quán)重值。為滿足以上要求,則 增大時(shí), 應(yīng)減小,故剛好 可滿足上述要求。

應(yīng)此加上注意參數(shù) ,得到 Focal Loss 的二分類情況:

圖片

加上調(diào)節(jié)系數(shù) ,F(xiàn)ocal Loss 推廣到多分類的情況:

圖片

其中 為第 t 類預(yù)測(cè)值,,試驗(yàn)中效果最佳時(shí),。

代碼的實(shí)現(xiàn)也是比較簡(jiǎn)潔的。

def __init__(self, num_class, alpha=None, gamma=2, reduction='mean'):
    super(MultiFocalLoss, self).__init__()
    self.gamma = gamma
    ......

    def forward(self, logit, target):
        alpha = self.alpha.to(logit.device)
        prob = F.softmax(logit, dim=1)

        ori_shp = target.shape
        target = target.view(-1, 1)

        prob = prob.gather(1, target).view(-1) + self.smooth  # avoid nan
        logpt = torch.log(prob)

        alpha_weight = alpha[target.squeeze().long()]
        loss = -alpha_weight * torch.pow(torch.sub(1.0, prob), self.gamma) * logpt

        if self.reduction == 'mean':
            loss = loss.mean()

        return loss

2.2 GHM Loss

上面的 Focal Loss 注重了對(duì) hard example 的學(xué)習(xí),但不是所有的 hard example 都值得關(guān)注,有一些 hard example 很可能是離群點(diǎn),這種離群點(diǎn)當(dāng)然是不應(yīng)該讓模型關(guān)注的。

GHM (gradient harmonizing mechanism) 是一種梯度調(diào)和機(jī)制,GHM Loss 的改進(jìn)思想有兩點(diǎn):1)就是在使模型繼續(xù)保持對(duì) hard example 關(guān)注的基礎(chǔ)上,使模型不去關(guān)注這些離群樣本;2)另外 Focal Loss 中, 的值分別由實(shí)驗(yàn)經(jīng)驗(yàn)得出,而一般情況下超參 是互相影響的,應(yīng)當(dāng)共同進(jìn)行實(shí)驗(yàn)得到。

Focal Loss 中通過調(diào)節(jié)置信度 ,當(dāng)正樣本中模型的預(yù)測(cè)值 較小時(shí),則乘上(1-p),給一個(gè)大的 loss 值使得模型關(guān)注這種樣本。 **于是 GHM Loss 在此基礎(chǔ)上,規(guī)定了一個(gè)置信度范圍 ** **,具體一點(diǎn),就是當(dāng)正樣本中模型的預(yù)測(cè)值為 **** 較小時(shí),要看這個(gè) ** ** 多小,若是 ** ,這種樣本可能就是離群點(diǎn),就不注意它了。

于是 GHM Loss 首先規(guī)定了一個(gè)梯度模長(zhǎng)

圖片

其中, 是模型預(yù)測(cè)概率值, 是 ground-truth 的標(biāo)簽值,這里以二分類為例,取值為 0 或 1??砂l(fā)現(xiàn),** 表示檢測(cè)的難易程度, 越大則檢測(cè)難度越大。**

GHM Loss 的思想是,不要關(guān)注那些容易學(xué)的樣本,也不要關(guān)注那些離群點(diǎn)特別難分的樣本。所以問題就轉(zhuǎn)為我們需要尋找一個(gè)變量去衡量這個(gè)樣本是不是這兩種, 這個(gè)變量需滿足當(dāng) 值大時(shí),它要小,從而進(jìn)行抑制,當(dāng) 值小時(shí),它也要小,進(jìn)行抑制。 于是文中就引入了梯度密度:

表明了樣本 1~N 中,梯度模長(zhǎng)分布在 范圍內(nèi)的樣本個(gè)數(shù), 代表了 區(qū)間的長(zhǎng)度,因此梯度密度 GD(g) 的物理含義是:?jiǎn)挝惶荻饶iL(zhǎng) 部分的樣本個(gè)數(shù)。

在此基礎(chǔ)上,還需要一個(gè)前提,那就是處于 值小與大的樣本(也就是易分樣本與難分樣本)的數(shù)量遠(yuǎn)多于中間值樣本,此時(shí) GD 才可以滿足上述變量的要求。

圖片

此時(shí),對(duì)于每個(gè)樣本,把交叉熵 CE×該樣本梯度密度的倒數(shù),就得到 GHM Loss。

圖片

這里附上邏輯的代碼,完整的可以上文章首尾倉(cāng)庫(kù)查看。

class GHM_Loss(nn.Module):
    def __init__(self, bins, alpha):
        super(GHM_Loss, self).__init__()
        self._bins = bins
        self._alpha = alpha
        self._last_bin_count = None

    def _g2bin(self, g):
        # split to n bins
        return torch.floor(g * (self._bins - 0.0001)).long()


    def forward(self, x, target):
        # compute value g
        g = torch.abs(self._custom_loss_grad(x, target)).detach()

        bin_idx = self._g2bin(g)

        bin_count = torch.zeros((self._bins))
        for i in range(self._bins):
            # 計(jì)算落入bins的梯度模長(zhǎng)數(shù)量
            bin_count[i] = (bin_idx == i).sum().item()

        N = (x.size(0) * x.size(1))

        if self._last_bin_count is None:
            self._last_bin_count = bin_count
        else:
            bin_count = self._alpha * self._last_bin_count + (1 - self._alpha) * bin_count
            self._last_bin_count = bin_count

        nonempty_bins = (bin_count > 0).sum().item()

        gd = bin_count * nonempty_bins
        gd = torch.clamp(gd, min=0.0001)
        beta = N / gd  # 計(jì)算好樣本的gd值

        # 借由binary_cross_entropy_with_logits,gd值當(dāng)作參數(shù)傳入
        return F.binary_cross_entropy_with_logits(x, target, weight=beta[bin_idx])

2.3 Dice Loss & DSC Loss

Dice Loss 是來自文章 V-Net 提出的,DSC Loss 是香儂科技的 Dice Loss for Data-imbalanced NLP Tasks。

按照上面的邏輯,看一下 Dice Loss 是怎么演變過來的。Dice Loss 主要來自于 dice coefficient,dice coefficient 是一種用于評(píng)估兩個(gè)樣本的相似性的度量函數(shù)。

定義是這樣的:取值范圍在 0 到 1 之間,值越大表示越相似。若令 X 是所有模型預(yù)測(cè)為正的樣本的集合,Y 為所有實(shí)際上為正類的樣本集合,dice coefficient 可重寫為:

圖片

同時(shí),結(jié)合 F1 的指標(biāo)計(jì)算公式推一下,可得:

圖片

可以動(dòng)手推一下,就能得到 dice coefficient 是等同 F1 score 的,**因此本質(zhì)上 dice loss 是直接優(yōu)化 F1 指標(biāo)的。 **

上述表達(dá)式是離散的,需要把上述 DSC 表達(dá)式轉(zhuǎn)化為連續(xù)的版本,需要進(jìn)行軟化處理。對(duì)單個(gè)樣本 x,可以直接定義它的 DSC:

圖片

但是當(dāng)樣本為負(fù)樣本時(shí),y1=0,loss 就為 0 了,需要加一個(gè)平滑項(xiàng)。

圖片

上面有說到 dice coefficient 是一種兩個(gè)樣本的相似性的度量函數(shù),上式中,假設(shè)正樣本 p 越大,dice 值越大,說明模型預(yù)測(cè)的越準(zhǔn),則應(yīng)該 loss 值越小,因此 dice loss 的就變成了下式這也就是最終 dice loss 的樣子。

圖片

為了能得到 focal loss 同樣的功能,讓 dice loss 集中關(guān)注預(yù)測(cè)不準(zhǔn)的樣本,可以與 focal loss 一樣加上一個(gè)調(diào)節(jié)系數(shù) ,就得到了香儂提出的適用于 NLP 任務(wù)的自調(diào)節(jié) DSC-Loss。

圖片

弄明白了原理,看下代碼的實(shí)現(xiàn)。

class DSCLoss(torch.nn.Module):

    def __init__(self, alpha: float = 1.0, smooth: float = 1.0, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth
        self.reduction = reduction

    def forward(self, logits, targets):
        probs = torch.softmax(logits, dim=1)
        probs = torch.gather(probs, dim=1, index=targets.unsqueeze(1))

        probs_with_factor = ((1 - probs) ** self.alpha) * probs
        loss = 1 - (2 * probs_with_factor + self.smooth) / (probs_with_factor + 1 + self.smooth)

        if self.reduction == "mean":
            return loss.mean()

總結(jié)

本文主要討論了類別不均衡問題的解決辦法,可分為數(shù)據(jù)層面的重采樣及模型 loss 方面的改進(jìn),如 focal loss, dice loss 等。最后說一下實(shí)踐下來的經(jīng)驗(yàn),由于不同數(shù)據(jù)集的數(shù)據(jù)分布特點(diǎn)各有不同,dice loss 以及 GHM loss 會(huì)出現(xiàn)些抖動(dòng)、不穩(wěn)定的情況。當(dāng)不想挨個(gè)實(shí)踐的時(shí)候,首推 focal loss,dice loss。

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 數(shù)據(jù)
    +關(guān)注

    關(guān)注

    8

    文章

    6759

    瀏覽量

    88616
  • 代碼
    +關(guān)注

    關(guān)注

    30

    文章

    4695

    瀏覽量

    68081
  • nlp
    nlp
    +關(guān)注

    關(guān)注

    1

    文章

    483

    瀏覽量

    21971
收藏 人收藏

    評(píng)論

    相關(guān)推薦

    詳細(xì)介紹數(shù)據(jù)均衡的方法以及運(yùn)用的不同場(chǎng)景

    對(duì)于整個(gè)數(shù)據(jù)建模來看,數(shù)據(jù)均衡算法屬于數(shù)據(jù)預(yù)處理一環(huán)。當(dāng)整個(gè)數(shù)據(jù)集從調(diào)出數(shù)據(jù)庫(kù)到拿到手的時(shí)候,對(duì)于分類數(shù)據(jù)集來說類別一般都是不均衡的,整個(gè)數(shù)據(jù)集合也是較為離散的。因此不可能一拿到數(shù)據(jù)集就可進(jìn)行建模,
    的頭像 發(fā)表于 07-20 09:44 ?5593次閱讀

    NLP的tfidf作詞向量

    NLPtfidf作詞向量
    發(fā)表于 06-01 17:28

    Return Loss Headromm

    Return Loss Headromm return loss is related to the impedance of a cable.To understand return loss we must first d
    發(fā)表于 03-31 09:57 ?15次下載

    調(diào)音臺(tái)信號(hào)處理設(shè)備均衡器和激勵(lì)器詳解

    調(diào)音臺(tái)信號(hào)處理設(shè)備均衡器和激勵(lì)器詳解 均衡器和
    發(fā)表于 04-19 15:07 ?4901次閱讀

    基于差異度的不均衡電信客戶數(shù)據(jù)分類方法

    針對(duì)傳統(tǒng)分類技術(shù)對(duì)不均衡電信客戶數(shù)據(jù)集中流失客戶識(shí)別能力不足的問題,提出一種基于差異度的改進(jìn)型不均衡數(shù)據(jù)分類(IDBC)算法。該算法在基于差異度分類(DBC)算法的基礎(chǔ)上改進(jìn)了原型選擇策略。在原型
    發(fā)表于 12-04 16:36 ?0次下載

    不均衡數(shù)據(jù)集上基于子域?qū)W習(xí)的復(fù)合分類模型

    為進(jìn)一步弱化數(shù)據(jù)不均衡對(duì)分類算法的束縛,從數(shù)據(jù)集區(qū)域分布特性著手,提出了不均衡數(shù)據(jù)集上基于子域?qū)W習(xí)的復(fù)合分類模型。子域劃分階段,擴(kuò)展支持向量數(shù)據(jù)描述( SVDD)算法給出類的最小界定域,劃分
    發(fā)表于 12-12 15:28 ?0次下載

    基于不均衡的加權(quán)在線貫序極限學(xué)習(xí)機(jī)

    針對(duì)現(xiàn)有學(xué)習(xí)算法難以有效提高不均衡在線貫序數(shù)據(jù)中少類樣本分類精度的問題,提出一種基于不均衡樣本重構(gòu)的加權(quán)在線貫序極限學(xué)習(xí)機(jī)。該算法從提取在線貫序數(shù)據(jù)的分布特性入手,主要包括離線和在線兩個(gè)階段:離線
    發(fā)表于 01-09 16:44 ?0次下載

    比亞迪在新能源市場(chǎng)不斷領(lǐng)先,但不同能源車型發(fā)展不均衡問題不可忽視

    毋庸置疑,目前比亞迪在新能源市場(chǎng)獲得領(lǐng)先以及突破,但是如何改變新能源產(chǎn)銷的過度依賴,解決不同能源車型不均衡性的發(fā)展也考驗(yàn)著比亞迪。
    發(fā)表于 09-02 09:33 ?1530次閱讀

    基于目前TDD網(wǎng)絡(luò)高負(fù)荷及FD不均衡現(xiàn)狀分析

    FD功率功率設(shè)置不一致,導(dǎo)致不同小區(qū)覆蓋差異,功率大,覆蓋遠(yuǎn)的小區(qū)吸收用戶多,導(dǎo)致FD同覆蓋小去不均衡。在處理這類問題中首要是功率拉齊。一般同覆蓋情況下要求D比F功率可適當(dāng)高1-3DB。同時(shí)對(duì)新擴(kuò)容共設(shè)備小區(qū)需功率拉齊。
    發(fā)表于 12-25 10:13 ?4082次閱讀
    基于目前TDD網(wǎng)絡(luò)高負(fù)荷及FD<b class='flag-5'>不均衡</b>現(xiàn)狀分析

    基于不均衡醫(yī)學(xué)數(shù)據(jù)集的疾病預(yù)測(cè)模型

    基于不均衡醫(yī)學(xué)數(shù)據(jù)集的疾病預(yù)測(cè)模型
    發(fā)表于 06-15 14:15 ?9次下載

    一種新的不均衡關(guān)聯(lián)分類算法ACI

    基于規(guī)則的分類算法具有分類性能妤、可解釋性強(qiáng)的優(yōu)點(diǎn),得到了廣泛的應(yīng)用。然而已有的基于規(guī)則的分類算法沒有考慮不均衡數(shù)據(jù)的情況,從而影響了其對(duì)不均衡數(shù)據(jù)的分類效果。文中提出了一種新的不均衡關(guān)聯(lián)分類算法
    發(fā)表于 06-17 15:27 ?16次下載

    數(shù)據(jù)類別不均衡問題的分類及解決方式

      數(shù)據(jù)類別不均衡問題應(yīng)該是一個(gè)極常見又頭疼的的問題了。最近在工作中也是碰到這個(gè)問題,花了些時(shí)間梳理并實(shí)踐了類別不均衡問題的解決方式,主要實(shí)踐了“魔改”
    的頭像 發(fā)表于 07-08 14:51 ?3576次閱讀

    Loss計(jì)算詳細(xì)解析

    分類損失(cls_loss):該損失用于判斷模型是否能夠準(zhǔn)確地識(shí)別出圖像中的對(duì)象,并將其分類到正確的類別中。
    的頭像 發(fā)表于 01-13 14:38 ?3173次閱讀
    <b class='flag-5'>Loss</b>計(jì)算詳細(xì)解析

    NLP類別不均衡問題loss大集合

      NLP 任務(wù)中,數(shù)據(jù)類別不均衡問題應(yīng)該是一個(gè)極常見又頭疼的的問題了。最近在工作中也是碰到這個(gè)問題,花了些時(shí)間梳理并實(shí)踐了下類別不均衡問題
    的頭像 發(fā)表于 01-31 16:52 ?756次閱讀

    讓充電更高效:便攜式鋰電池均衡維護(hù)儀的智能

    隨著電動(dòng)汽車和可穿戴設(shè)備的普及,鋰電池成為了現(xiàn)代生活的重要組成部分。但隨之而來的問題是,鋰電池在使用過程中會(huì)出現(xiàn)充放電不均衡的現(xiàn)象,影響其性能和使用壽命。要解決這個(gè)問題,便攜式鋰電池均衡維護(hù)儀
    的頭像 發(fā)表于 07-06 09:58 ?7292次閱讀
    讓充電更高效:便攜式鋰電池<b class='flag-5'>均衡</b>維護(hù)儀的智能<b class='flag-5'>之</b>選