本發明屬于圖像處理,涉及一種基于對比學習的圖像分類方法、設備及存儲介質。
背景技術:
1、圖像數據作為一種豐富的信息載體,在人們的工作、學習和生活中發揮著越來越重要的作用,逐漸成為人們生活的必需品。隨著科學技術的發展,圖像分類技術已經被應用到諸如自動化監控、軍事情報、地理信息處理、生物學歸類等場景中,此外,它還能幫助人們快速準確地理解大量數據,并有效進行決策;
2、但是如何對海量的圖像數據進行分類是目前的技術難題,作為計算機視覺領域的重要研究任務。隨著圖像數據的爆發增長,產生了越來越多的無標記數據,這些數據并不能很好地被應用到需要提前標注好圖像類型和物體位置的圖像分類模型中進行訓練。
3、對比學習,作為無監督學習的一個分支,能夠為下游的圖像分類和目標檢測任務提供良好的圖像特征提取網絡,模型訓練過程中無需使用帶標記的圖像數據。通過樣本或分布間的對比進行學習,其主要目標是學習到一個低維且具有判別性的特征表示。通過樣本間對比,正樣本對相互靠近,負樣本對相互遠離,網絡因此可以學習到更具判別性和代表性的通用特征。
4、雖然對比學習模型能夠使用大量無標記數據訓練通用的特征提取網絡,但其仍然存在一定的問題。當前大多數對比學習模型不加區分地為下游任務提供通用的特征提取網絡,然而不同的下游任務諸如圖像分類和目標檢測任務側重點不同,圖像分類任務僅需對整張圖片進行分類,更關注圖像整體特征。
5、moco?v3是一種近幾年發展迅速的對比學習模型,廣泛用于目標分類任務預訓練階段,相比于傳統的vit模型,大大提高了模型的穩定性。但moco?v3使用的visiontransformer使得網絡整體訓練過程變難,參數量變大,不充分等不利于該模型在下游任務中的應用,還會導致在訓練過程中穩定性不足等問題,原有通過凍結patch?projectionlayer層來解決此問題的方法并不完善。
技術實現思路
1、本發明的目的在于解決現有技術中對比學習在圖像分類任務中訓練時間久、分類準確率低的不足的問題,提供一種基于對比學習的圖像分類方法、設備及存儲介質。
2、為達到上述目的,本發明采用以下技術方案予以實現:
3、本發明公開了一種基于對比學習的圖像分類方法,具體包括如下步驟:
4、s1、對同一數據進行兩種不同的數據增強處理后,通過stride小于kernel?size的卷積函數劃分圖像樣本,得到兩個包含圖像相鄰信息的特征序列;
5、s2、將步驟s1得到的兩個特征序列作為不同的分支,輸入到sswt(scalable?shiftwindow?transformer)雙分支編碼器中進行局部特征和全局特征的提取,得到sswt雙分支編碼器的輸出特征;
6、s3、將sswt雙分支編碼器的輸出特征通過基于全連接層的predictor網絡結構,得到結果;
7、s4、基于多特征構建目標識別對比損失函數對兩個分支得到的一維特征計算損失;
8、作為優選,步驟s1具體包括如下子步驟:
9、對于h×w×c的輸入數據x,分別對其進行兩種不同的數據增強處理得到數據x1和x2,分別將x1和x2輸入stride小于kernel?size的卷積函數劃分成2個大小為4×4的特征序列,數據特征變成
10、作為優選,所述數據增強處理包括翻轉、旋轉、縮放、隨機裁剪中的一種或多種。
11、作為優選,步驟s2中所述4stage的sswt特征提取網絡由stage1、stage2、stage3和stage4,所述stage1由linear?embedding模塊、2層sswt?block特征提取模塊組成;所述stage2由down?sample模塊、linear?embedding模塊、2層sswt?block特征提取模塊、基于attention的特征融合模塊confuse組成;所述stage3由down?sample模塊、linearembedding模塊、6層sswt?block特征提取模塊、基于attention的特征融合模塊confuse組成;所述stage4由down?sample模塊、linear?embedding模塊、2層sswt?block特征提取模塊、基于attention的特征融合模塊confuse組成;
12、其中,linear?embedding模塊用于改變輸入特征的通道數;所述sswt?block特征提取模塊用于提取局部特征和全局特征;所述down?sample模塊用于將相鄰的四個特征塊進行拼接;所述attention的特征融合模塊confuse用于將特征進行融合;
13、步驟s2中,特征序列進入4stage的sswt特征提取網絡后,依次通過stage1、stage2、stage3和stage4進行處理。
14、作為優選,所述sswt?block特征提取模塊由左注意力模塊和右注意力模塊構成;所述左注意力模塊由固定窗口的可擴展多頭自注意力網絡和通道組自注意力網絡組成;所述右注意力模塊由滑動窗口的可擴展多頭自注意力網絡和通道組自注意力網絡組成;
15、所述sswt?block特征提取模塊包括如下操作:
16、s21:輸入特征首先進入左注意力模塊,經過layer?norm層進行歸一化后,輸入固定窗口的可擴展多頭自注意力網絡,得到特征z′;
17、s22、將特征zl-1與經過layer?norm層和固定窗口的可擴展多頭自注意力網絡后的特征z′進行殘差求和得到特征作為通道組自注意力網絡的輸入特征;
18、s23、將經過通道組自注意力網絡、layer?norm層和mlp模塊后的特征與特征進行殘差求和得到特征zl并作為右注意力模塊的輸入特征;
19、s24、輸入特征zl進入右注意力模塊,經過layer?norm層進行歸一化后,輸入滑動窗口的可擴展多頭自注意力網絡,得到特征z′+1;
20、s25、將特征zl與經過layer?norm和固定窗口的可擴展多頭自注意力網絡后的特征z′進行殘差求和得到特征作為通道組自注意力網絡的輸入特征;
21、s26、將經過通道組自注意力網絡、layer?norm層和mlp模塊后的特征與特征進行殘差求和得到特征zl+1。
22、作為優選,步驟s21中,固定窗口的可擴展多頭自注意力網絡的處理如下:將輸入特征的空間維度和通道維度根據縮放因子進行縮放,得到若干個窗口,在每個窗口內計算query、key和value矩陣并得到同一窗口內特征塊之間的注意力矩陣。
23、作為優選,所述通道組自注意力網絡的處理如下:將輸入特征首先輸入到layernorm層進行歸一化,然后通過projection層將其按通道維度劃分成若干個通道組,并在每個通道組中計算注意力關系。
24、作為優選,特征序列進入stage2、stage3、stage4的處理過程如下:
25、a21、首先通過down?sample層合并特征序列中相鄰的四個特征,得到合并后的特征序列;
26、a22、再步驟a21中將合并后的特征序列通過linear?embedding層,將通道數減少為原來的一半;
27、a23、將步驟a22中linear?embedding層輸出的特征序列作為sswt?block模塊的輸入計算特征序列在空間維度和通道維度的注意力關系,得到輸出特征;
28、a24、將步驟a23中得到的輸出特征與步驟a21中將合并后的特征序列作為attention的特征融合模塊confuse的輸入特征;
29、a25、attention的特征融合模塊confuse中,首先通過3個3x3的卷積網絡求出輸入特征對應的query、key和value,然后將query和key的轉置矩陣相乘并進行softmax運算,得到相應的注意力權重矩陣attention?map,最后將求得的attention?map和value相乘并進行3x3的卷積運算得到最終輸出的特征。
30、作為優選,所述步驟s4包括以下步驟:
31、在預訓練階段,首先為來自于同一張圖像的兩個隨機crop進行不同的數據增強,將它們分別輸入編碼器為fθ和fξ的上下分支中,得到輸出向量q和k,學習的目的是將向量q當作query查詢同一批次的數據中對應的正樣本key,為了最小化對比損失,本發明使用info?nce損失函數:
32、
33、其中k+是編碼器fξ輸出的與q來自同一張圖像的crop,作為q的正樣本向量,而k-是同一個批次中來自于其它圖像的crop經過編碼器fξ輸出的特征向量,作為q的負樣本,而τ是為了對特征向量進行歸一化而設定的一個超參數。
34、
35、
36、在測試階段的圖像分類任務中,分類損失使用交叉熵損失函數cross?entropyloss,其中c代表數據集的類別數,p(·)、q(·)函數分別指真實標簽和預測標簽。
37、本發明還公開了一種基于對比學習的圖像分類方法,包括圖像劃分模塊、sswt編碼器、基于attention的特征融合模塊以及對比損失模塊;
38、圖像劃分模塊,用于圖片數據增強和patches劃分和獲取;
39、sswt編碼器,包含多階段過程,其過程包括sswt?block特征提取模塊和基于attention的特征融合模塊。該編碼器作用在于使得模型提取的特征更具多樣性,使計算量開銷更小,提取的特征更加豐富。
40、對比損失模塊,該模塊包括線性投影層和損失函數層,該模塊作用在于構建分類任務損失函數,計算損失。
41、本發明還公開了一種終端設備,包括存儲器、處理器以及存儲在所述存儲器中并可在所述處理器上運行的計算機程序,所述處理器執行所述計算機程序時實現本發明上述述方法的步驟。
42、本發明還公開了一種計算機可讀存儲介質,所述計算機可讀存儲介質存儲有計算機程序,所述計算機程序被處理器執行時實現上述述方法的步驟。
43、與現有技術相比,本發明具有以下有益效果:
44、本發明公開了一種基于對比學習的圖像分類方法,首先將同一圖片經過不同的數據增強處理劃分為兩個類別分別送入上下兩個分支。在每個分支經過image?split層劃分成大小為4×4的patches序列,以充分挖掘圖像特征。
45、然后經過四個stage,在第一個stage中首先通過linear?embedding改變輸入特征塊的通道維度,然后將patches作為sswt?block的輸入。除sswt?block模塊的堆疊次數之外,之后的三個stage網絡結構完全相同,每個stage中首先通過down?sample層將patches序列中相鄰的四個特征塊拼接起來,然后將數量縮減為原來四分之一的patches依次經過linear?embedding層和sswt?block層。該模塊通過兩個連續的注意力模塊分別提取了圖像的局部特征和全局特征。然后經過下采樣后的特征和sswt?block輸出的特征作為基于attention的confuse模塊進行特征融合,該模塊同時保留了圖像的深層特征和淺層特征。最后經由全連接輸出,基于上下兩分支設計了面向圖像分類的對比損失函數,優化了分類標簽的預測輸出結果。
46、進一步的,本發明中,sswt?block特征提取模塊sswt?block特征提取模塊共包括左右兩個連續的注意力模塊,其中第一個模塊由固定窗口的可擴展多頭自注意力網絡(w-smsa)和通道組自注意力網絡(cgsa)組成,類似地,第二個模塊由滑動窗口的可擴展多頭自注意力網絡(sw-smsa)和cgsa組成。w-smsa計算每個窗口內patches間的注意力關系,sw-smsa通過滑動窗口計算不同窗口間patches間的注意力關系。w-smsa模塊通過使用基于縮放因子rn和rc的可擴展多頭自注意力機制,通過fq(·)、fk(·)和fv(·)將空間維度n和通道維度c分別縮放為n×rn和c×rc,通過將rn和rc引入到空間維度和通道維度,使產生的attention矩陣與輸入patches的維度解綁以解決同一張圖片中有許多非必需的同源信息以及傳統注意力機制中矩陣維度與輸入特征維度高度耦合的問題,提升模型的學習能力。cgsa模塊按通道維度劃分通道組,并在每個組中計算注意力關系,以此捕捉全局的信息交互和特征表示。sswt?block模塊通過兩個連續的注意力模塊分別提取了圖像的局部特征和全局特征,網絡的多次殘差求和結構同時保留了圖像的深層特征和淺層特征。
47、進一步的,本發明中,基于attention的特征融合模塊首先將輸入的特征經過downsample層合并特征圖中相鄰的四個patches得到特征向量y1,為了減少后續sswt?block模塊的計算量,通過linear?embedding層將y1通道數減少為原來的一半得到特征向量y2,然后將其作為sswt?block模塊的輸入計算patches序列在空間維度和通道維度的注意力關系得到輸出特征y3,隨后將得到的y1和y3作為特征融合層confuse的輸入,confuse層的輸出特征融合了down?sample和sswt?block兩個階段的特征。最后將結果通過基于全連接層的predictor網絡結構得到最終結果,此網絡結構僅包含linear層和relu激活函數。該特征融合模塊有助于提升模型在圖像分類上的性能。