圖神經(jīng)網(wǎng)絡的知識提取與超越:一個有效的知識蒸餾框架
2021-05-10 10:52:03AI云資訊850
下文將針對這一框架進行詳細說明。
論文鏈接: https://arxiv.org/pdf/2103.02885.pdf 論文代碼: https://github.com/BUPT-GAMMA/CPF一、引言
隨著深度學習的成功,基于圖神經(jīng)網(wǎng)絡(GNN)的方法[8,12,30]已經(jīng)證明了它們在分類節(jié)點標簽方面的有效性。大多數(shù)GNN模型采用消息傳遞策略[7]:每個節(jié)點從其鄰域聚合特征,然后將具有非線性激活的分層映射函數(shù)應用于聚合信息。這樣,GNN可以在其模型中利用圖結(jié)構(gòu)和節(jié)點特征信息。
然而,這些神經(jīng)模型的預測缺乏透明性,人們難以理解[36],而這對于與安全和道德相關的關鍵決策應用至關重要[5]。此外,圖拓撲、節(jié)點特征和映射矩陣的耦合導致復雜的預測機制,無法充分利用數(shù)據(jù)中的先驗知識。例如,已有研究表明,標簽傳播法采用上述同質(zhì)性假設來表示的基于結(jié)構(gòu)的先驗,在圖卷積網(wǎng)絡(GCN)[12]中沒有充分使用[15,31]。
作為證據(jù),最近的研究提出通過添加正則化[31]或操縱圖過濾器[15,25]將標簽傳播機制納入GCN。他們的實驗結(jié)果表明,通過強調(diào)這種基于結(jié)構(gòu)的先驗知識可以改善GCN。然而,這些方法具有三個主要缺點:
1. 其模型的主體仍然是GNN,并阻止它們進行更可解釋的預測;
2. 它們是單一模型而不是框架,因此與其他高級GNN架構(gòu)不兼容;
3. 他們忽略了另一個重要的先驗知識,即基于特征的先驗知識,這意味著節(jié)點的標簽完全由其自身的特征確定。
為了解決這些問題,我們提出了一個有效的知識蒸餾框架,以將任意預訓練的GNN教師模型的知識注入精心設計的學生模型中。學生模型是通過兩個簡單的預測機制構(gòu)建的,即標簽傳播和特征轉(zhuǎn)換,它們自然分別保留了基于結(jié)構(gòu)和基于特征的先驗知識。
具體來說,我們將學生模型設計為參數(shù)化標簽傳播和基于特征的2層感知機(MLP)的可訓練組合。另一方面,已有研究表明,教師模型的知識在于其軟預測[9]。通過模擬教師模型預測的軟標簽,我們的學生模型能夠進一步利用預訓練的GNN中的知識。因此,學習的學生模型具有更可解釋的預測過程,并且可以利用GNN和基于結(jié)構(gòu)/特征的先驗知識。我們的框架概述如圖1所示。
圖1:我們的知識蒸餾框架的示意圖。學生模型的兩種簡單預測機制可確保充分利用基于結(jié)構(gòu)/功能的先驗知識。在知識蒸餾過程中,將提取GNN教師中的知識并將其注入學生。因此,學生可以超越其相應的老師,得到更有效和可解釋的預測。
我們在五個公共基準數(shù)據(jù)集上進行了實驗,并采用了幾種流行的GNN模型,包括GCN[12]、GAT[30]、SAGE[8]、APPNP[13]、SGC[33]和最新的深層GCN模型GCNII[4]作為教師模型。
實驗結(jié)果 表明,就分類精度而言,學生模型的表現(xiàn)優(yōu)于其相應的教師模型1.4%-4.7%。值得注意的是,我們也將框架應用于GLP[15],它通過操縱圖過濾器來統(tǒng)一GCN和標簽傳播。結(jié)果,我們?nèi)匀豢梢垣@得1.5%-2.3%的相對改進,這 表明了我們框架的潛在兼容性。此外,我們通過探究參數(shù)化標簽傳播與特征轉(zhuǎn)換之間的可學習平衡參數(shù)以及標簽傳播中每個節(jié)點的可學習置信度得分,來研究學生模型的可解釋性??偠灾?,改進是一致,并且更重要的是,它具有更好的可解釋性。
本文的貢獻總結(jié)如下:
我們提出了一個有效的知識蒸餾框架,以提取任意預訓練的GNN模型的知識,并將其注入學生模型,以實現(xiàn)更有效和可解釋的預測。
我們將學生模型設計為參數(shù)化標簽傳播和基于特征的兩層MLP的可訓練組合。因此,學生模型有一個更可解釋的預測過程,并自然地保留了基于結(jié)構(gòu)/特征的先驗。因此,學習的學生模型可以同時利用GNN和先驗知識。
五個基準數(shù)據(jù)集和七個GNN教師模型上的實驗結(jié)果表明了我們的框架有效性。對學生模型中學習權(quán)重的廣泛研究也說明了我們方法的可解釋性。
二、方法
在本節(jié)中,我們 將從形式化半監(jiān)督節(jié)點分類問題開始,并介 紹符號。然后,我們將展示我們的知識蒸餾框架,以提取GNN的知識。然后, 我們將提出學生模型的體系結(jié)構(gòu),該模型是參數(shù)化標簽傳播和基于特征的兩層MLP的可訓練組合 。最后,我們將討論學生模型的可解釋性和框架的計算復雜性。
1.半監(jiān)督節(jié)點分類:
我 們首先概述節(jié)點分類問題。給定 一個連通圖 和一個標記點集 ,其中 師節(jié)點集, 是邊集,節(jié)點分類的目標是為每個節(jié)點無標記點集 中的節(jié)點 預測標簽。每個節(jié)點 擁有標簽 ,其中 是所有可能的標簽集合。此外,圖數(shù)據(jù)通常擁有節(jié)點特征 ,并且可以利用特征來提升分類準確率。每行矩陣 的每行 表示節(jié)點 的 維特征向量。
2.知識蒸餾框架:
基于GNN的節(jié)點分類方法往往是一個黑盒,輸入圖結(jié)構(gòu) 、標記點集 和節(jié)點特征 ,輸出分類器 。分類器 將預測無標記點 的標簽為 的概率 ,其中 。對于標記節(jié)點 ,如果 的標簽為 ,那么 ,其余標簽 。簡化起見,我們使用 表示所有標簽的概率分布。
在本文中,我們框架里的教師模型可以使用任意GNN,例如GCN[12]或GAT[30]。我們稱教師模型里的預訓練分類器為 。另一方面,我們使用 表示學生模型, 是參數(shù), 表示學生模型對節(jié)點v的預測概率分布。
在知識蒸餾[9]的框架中,訓練學生模型使其最小化與預訓練教師模型的軟標簽預測,使得教師模型里的潛在知識被提取并注入學生模 型中。因此,優(yōu)化目標是對齊學生模型和與訓練教師模型的輸出,可以形式化為:
其中 度量兩個預測概率分布之間的距離。特別地,本文使用歐氏距離。(注:我們還嘗試最小化KL散度或最大化交叉熵。但是我們發(fā)現(xiàn)歐幾里得距離的效果最好,并且在數(shù)值上更穩(wěn)定。)
3.學生模型架構(gòu):
我 們假設節(jié)點的標簽預測遵循兩種簡單的機制:
1.從其相鄰節(jié)點傳播標簽;
2.從其自身特征進行轉(zhuǎn)換。
因此,如圖2所示,我們將學生模型設計為這兩種機制的組合,即參數(shù)化標簽傳播(PLP)模塊和特征轉(zhuǎn)換(FT)模塊,它們可以自然地分別保留基于結(jié)構(gòu)的先驗知識和基于特征的先驗知識。蒸餾后,學生將通過更易于解釋的預測機制從GNN和先驗知識中受益。
圖2:我們建議的學生模型的架構(gòu)圖。 以中心節(jié)點 為例,學生模型從節(jié)點 的原始特征和統(tǒng)一的標簽分布作為軟標簽開始,然后在每一層,將 的軟標簽預測更新為來自 的鄰居的參數(shù)化標簽傳播(PLP)和 的特征變換(FT)的可訓練組合。最終,將使學生與經(jīng)過訓練的教師的軟標簽預測之間的距離最小化。
在本小節(jié)中,我們將首先簡要回顧傳統(tǒng)的標簽傳播算法。然后,我們將介紹我們的PLP和FT模塊及其可訓練的組合。
3.1 標簽傳播:
標簽傳播(LP)[40]是基于圖的經(jīng)典半監(jiān)督學習模型。該模型僅遵循以下假設:由邊連接(或占據(jù)相同流形)的節(jié)點極有可能共享相同的標簽。基于此假設,標簽將從標記的節(jié)點傳播到未標記的節(jié)點以進行預測。
正式地,我們使用 表示LP的最終預測,使用 表示k輪迭代后的LP預測。在這個工作中,如果 是標記節(jié)點,我們將對節(jié)點 的預測初始化為一個獨熱編碼向量。否則,我們將為每個未標記的節(jié)點 設置均勻分布,這表明所有類的概率在開始時都是相同的。初始化可以形式化為:
其中, 是節(jié)點 在第 次迭代中的預測概率分布。在第k+1次迭代時,LP將按照如下方式更新無標記節(jié)點 的預測:
其中, 時節(jié)點 的鄰居集合, 是控制節(jié)點更新平滑度的超參。
注意LP沒有需要訓練的參數(shù),因此以端到端的方式不能擬合教師模型的輸出。因此 ,我們通過引入更多參數(shù)來提升LP的表達能力。
3.2 參數(shù)化標簽傳播模塊:
現(xiàn)在,我們將通過在LP中進一步參數(shù)化邊緣權(quán)重來介紹我們的參數(shù)化標簽傳播(PLP)模塊。如等式3所示,LP模型在傳播過程中平等對待節(jié)點的所有鄰居。但是,我們假設不同鄰居對一個節(jié)點的重要性應該不同,這決定了節(jié)點之間的傳播強度。更具體地說,我們假設某些節(jié)點的標簽預測比其他節(jié)點更"自信"。例如,一個節(jié)點的預測標簽與其大多數(shù)鄰居相似。這樣的節(jié)點將更有可能將其標簽傳播給鄰居,并使它們保持不變。
形式化來說,我們將給每個節(jié)點v設置一個置信度分數(shù) 。在傳播過程中,所有節(jié)點 的鄰居和 自身將把他們的標簽傳播給 。基于置信值越大,邊緣權(quán)值越大的直覺,我們?yōu)?重寫了等式3中的預測更新函數(shù)如下:
其中 是節(jié)點 和節(jié)點 的邊權(quán),通過下面的 函數(shù)計算:
與LP相似, 按照等式2初始化,在傳播過程中,每個標記點 的 仍然保持獨熱真實編碼向量。
注意,作為可選項,我們可以進一步參數(shù)化置信度分數(shù) 用于歸納設置:
其中, 是一個可學習參數(shù),將節(jié)點 的特征映射為置信度分數(shù)。
3.3 特征轉(zhuǎn)換模塊:
注意 ,通過邊緣傳播標簽的PLP模塊強調(diào)了基于結(jié)構(gòu)的先驗知識。因此,我們還引入了特征變換(FT)模塊作為補充預測機制。FT模塊僅通過查看節(jié)點的原始特征來 預測標簽。形式化來說,用 表示FT模塊的預測,我們使用兩層MLP后接一個softmax函數(shù)來將特征轉(zhuǎn)換為軟標簽預測:
注:雖然單層邏輯回歸更具可解釋性,但我們發(fā)現(xiàn)兩層邏輯回歸對于提高學生的模型能力是必要的。
3.4 可訓練組合:
現(xiàn)在我們 將結(jié)合PLP和FT模塊作為我們的完整學生模型。細節(jié)上,我們 將為每個節(jié)點 學習一個可訓練參數(shù) ,來平衡PLP和FT之間的預測。換句話說,F(xiàn)T和PLP的預測將在每個傳播步驟合并。我們將合并后的完整模型命名為CPF,等式4中的每個無標記節(jié)點 的預測更新公式可以重新寫做:
其中邊權(quán) 和初始化 與PLP模塊一致。根據(jù)是否按照等式6參數(shù)化置信度分數(shù) ,模型有兩個變體,分別是歸納模型CPF-ind和轉(zhuǎn)導模型CPF-tra。
4.整體算法與細節(jié)
假設我們的學生模型一共有K層,等式1中的蒸餾目標可以進一步寫為:
其中, 是 范數(shù),參數(shù)集合 包括PLP和FT之間的平衡參數(shù) ,PLP模塊內(nèi)部的置信度參數(shù) (或歸納設置下的參數(shù) ),以及FT模塊中MLP的參數(shù) 。還有一個重要的超參數(shù):傳播層數(shù) 。
5.對模型可解釋性與計算復雜性的討論
在本小節(jié)中, 我們將討論學習的學生模型的可解釋性和算法的復雜性。
經(jīng)過知識蒸餾后,我們的學生模型CPF會將特定節(jié)點的標簽作為標簽傳播和基于特征的MLP的預測之間的加權(quán)平均值進行預測。平衡參數(shù)指示基于結(jié)構(gòu)的LP還是基于特征的MLP對于節(jié)點 的預測更重要。LP機制幾乎是透明的,我們可以輕松地找出節(jié)點 在每個迭代中受哪個鄰居影響的程度。另一方面,對基于特征的MLP的理解可以通過現(xiàn)有工作[21]或直接查看不同特征的梯度來獲得。因此,學習過的學生模型比GNN教師具有更好的解釋性。
算法每次迭代(算法1的第3行到第13行)的時間復雜度和空間復雜度都是 ,這和數(shù)據(jù)集的規(guī)模線性相關。事實上,操作可以簡單寫成矩陣形式,對于真實數(shù)據(jù)集的訓練過程,使用單GPU可以在幾秒內(nèi)完成。因此,我們提出的知識蒸餾框架的時間、空間效率都很高。
三、實驗
在本節(jié)中,我們 將從介紹實驗中使用的數(shù)據(jù)集和教師模型開始。然后,我們將詳細介紹教師模型和學生變體的實驗設置。之后,我們將給出評估半監(jiān)督 節(jié)點分類的定量結(jié)果。我們還在不同數(shù)量的傳播層和訓練比率下進行實驗,以說明算法的魯棒性。最后,我們將提 供定性案例研究和可視化效果,以更好地理 解我們的學生模型CPF中的學習參數(shù)。
1.數(shù)據(jù)集
表1:數(shù)據(jù)集統(tǒng)計信息
我們使用五個公共基準數(shù)據(jù)集進行實驗,數(shù)據(jù)集的統(tǒng)計數(shù)據(jù)如表1所示。如以前的文獻[14,24,27]所做的那樣,我們僅考慮最大的連通分量,并將邊視為無向邊。
根據(jù)先前工作[24]中的實驗設置,我們從每個類別中隨機抽取20個節(jié)點作為標記節(jié)點,30個用于驗證節(jié)點,所有其他節(jié)點用于測試。
2.教師模型及其設置
為了進行全面比較,我們在我們的知識蒸餾框架中考慮了七個GNN模型作為教師模型;對于每個數(shù)據(jù)集和教師模型,我們測試下列學生變體:
-
PLP: 只考慮參數(shù)化標簽傳播機制的學生變體;
-
FT:只考慮特征轉(zhuǎn)換機制的學生變體;
-
CPF-ind:歸納設置下的完整模型;
-
CPF-tra:轉(zhuǎn)導設置下的完整模型。
表2:GCN[12]和GAT[30]作為教師模型的分類準確率
表3:APPNP[13]和SGAE[8]作為教師模型的分類準確率
表4:SGC[33]和GCNII[4]作為教師模型的分類準確率
表5:GLP[15]作為教師模型的分類準確率
五個數(shù)據(jù)集、七個GNN教師模型、四個學生變體模型上的實驗結(jié)果在表格2,3,4,5中展示。
4.不同傳播層數(shù)的分析
在本小節(jié)中,我們將研究關鍵超參數(shù)對學生模型CPF的體系結(jié)構(gòu)(即傳播層數(shù))的影響。實際上,流行的GNN模型(例如GCN和GAT)對層數(shù)非常敏感。較大數(shù)量的層將導致過平滑的問題,并嚴重損害模型性能。因此,我們在Cora數(shù)據(jù)集上進行了實驗,以進一步分析該超參數(shù)。
圖3:Cora數(shù)據(jù)集上具有不同數(shù)量傳播層的CPF-ind和CPF-tra的分類精度。圖例表示指導學生的老師模式。
5.不同訓練比例的分析
為了進一步證明該框架的有效性,我們在不同的訓練比例下進行了額外的實驗。具體來說,我們以Cora數(shù)據(jù)集為例,將每個類的標記節(jié)點數(shù)量從5個變化到50個。實驗結(jié)果如圖4所示。
圖4:Cora數(shù)據(jù)集上不同數(shù)量的標記節(jié)點下的分類精度。子標題指示相應的教師模型。
6.可解釋性分析
現(xiàn)在,我們 將分析學習的學生模型CPF的可解釋性。具體來說,我們將探究PLP和FT之間的學習平衡參數(shù) 以及每個節(jié)點的置信度得分 。我 們的目標是找出哪種節(jié)點具有最大或最小的 和 。在本小節(jié)中,我們將使用由GCN和GAT教師模型指導的CPF-ind學生模型在Cora數(shù)據(jù)集上進行展示。
圖5:用于可解釋性分析的平衡參數(shù) 案例研究。此處的子標題表示該節(jié)點是按GCN/GAT作為教師模型,按大或小值選擇的。
圖6:用于可解釋性分析的置信度得分 案例研究。此處的子標題表示該節(jié)點是按GCN/GAT作為教師模型,按大或小值選擇的。
四、結(jié)論
在本文中,我們提出了一種有效的知識蒸餾框架,可以提取任意預訓練的GNN (教師模型) 的知識并將其注入精心設計的學生模型中。
學生模型CPF被建立為兩個簡單預測機制的可訓練組合:標簽傳播和特征轉(zhuǎn)換,二者分別強調(diào)基于結(jié)構(gòu)的先驗知識和基于特征的先驗知識。蒸餾后,學習的學生可以利用先驗知識和GNN知識,從而超越GNN老師。
在五個基準數(shù)據(jù)集上的實驗結(jié)果表明,我們的框架可以通過更可解釋的預測過程來一致,顯著地改善所有七個GNN教師模型的分類精度。在不同數(shù)量的訓練比率和傳播層數(shù)上進行的附加實驗證明了我們算法的魯棒性。我們還提供了案例研究,以了解學生架構(gòu)中學習到的平衡參數(shù)和置信度得分。
在未來的工作中,除了半監(jiān)督節(jié)點分類之外,我們還將探索將我們的框架用于其他基于圖的應用。例如,無監(jiān)督節(jié)點聚類任務會很有趣,因為標簽傳播模式在沒有標簽的情況下不能應用。另一個方向是改進我們的框架,鼓勵教師和學生模型互相學習,以取得更好的成績。
相關文章
- Unity著手推進神經(jīng)網(wǎng)絡渲染技術(shù)應用,顛覆呈現(xiàn)虛擬 3D 世界的方式
- 科研成果發(fā)布│基于超圖神經(jīng)網(wǎng)絡的推薦系統(tǒng)論文
- 打破神經(jīng)網(wǎng)絡技術(shù)應用局限性,度小滿博士后論文入選國際頂級會議
- 本科生新算法打敗NeRF,不用神經(jīng)網(wǎng)絡照片也能動起來,提速100倍
- 特斯拉AI DAY:AI神經(jīng)網(wǎng)絡解讀 Dojo超算信息/AI機器人發(fā)布
- 2021世界人工智能大會AI Debate:圖神經(jīng)網(wǎng)絡是否是實現(xiàn)認知智能的關鍵?
- 新專利顯示蘋果VR頭顯可能利用神經(jīng)網(wǎng)絡監(jiān)測用戶的姿勢
- 圖神經(jīng)網(wǎng)絡的知識提取與超越:一個有效的知識蒸餾框架
- 人工神經(jīng)網(wǎng)絡秒變脈沖神經(jīng)網(wǎng)絡,新技術(shù)有望開啟邊緣AI計算新時代
- 深度神經(jīng)網(wǎng)絡是為人工智能的重要基石
- Imagination推出新神經(jīng)網(wǎng)絡加速器 可用于ADAS和自動駕駛
- 深度學習與神經(jīng)網(wǎng)絡推動AI芯片市場以約40%的年成長率持續(xù)擴張
- 百度飛槳PGL-UniMP刷新3項任務記錄 登頂圖神經(jīng)網(wǎng)絡權(quán)威榜單OGB
- Helm.ai宣布了一種新的深度教學方法來訓練神經(jīng)網(wǎng)絡
- 科學家們致力于利用神經(jīng)網(wǎng)絡改變神經(jīng)成像研究
- 開發(fā)AI神經(jīng)網(wǎng)絡用于打假 阿里安全獲計算機視覺頂會ECCV2020競賽冠軍
人工智能企業(yè)
更多>>人工智能產(chǎn)業(yè)
更多>>- AIDC產(chǎn)業(yè)發(fā)展大會隆重召開,開啟AIDC新紀元
- 絢星破局AI落地困境,四大業(yè)務重構(gòu)企業(yè)智能生產(chǎn)力新范式
- 騰訊啟動AI應用繁榮計劃,新一期AI共創(chuàng)營報名企業(yè)超300家
- 首都機場“AI繪空港”大賽完美收官,卓特視覺以技術(shù)賦能創(chuàng)意未來
- 打造張江人工智能創(chuàng)新小鎮(zhèn),全國首個人工智能創(chuàng)新應用先導區(qū)再添發(fā)展新引擎
- 人機共生 · 智啟未來——2025高交會亞洲人工智能與機器人產(chǎn)業(yè)鏈展主題發(fā)布
- 北京數(shù)基建發(fā)布“知行IntAct”混合智能體產(chǎn)品,以AI定義城市治理新范式
- 新時達“精耕小腦”,與大腦協(xié)同,加速具身智能垂直落地
人工智能技術(shù)
更多>>- 騰訊開源框架 Kuikly 再升級!率先適配 “液態(tài)玻璃”,原生體驗更極致
- 外灘大會首發(fā)! 螞蟻密算推出AI密態(tài)升級卡 實現(xiàn)零改動“即插即用”
- 騰訊優(yōu)圖攜Youtu-Agent開源項目亮相上海創(chuàng)智學院首屆TechFest大會
- 2025外灘大會:王堅暢談AI變革,普天科技錨定空天算力新賽道
- 騰訊正式開源Youtu-GraphRAG,圖檢索增強技術(shù)迎來落地新突破
- 聲網(wǎng)兄弟公司 Agora與OpenAI 攜手 助力多模態(tài) AI 智能體實現(xiàn)實時交互
- Qwen-Image-Edit 模型上線基石智算,圖像編輯更精準
- 火山引擎多模態(tài)數(shù)據(jù)湖落地深勢科技,提升科研數(shù)據(jù)處理效能