Skip to content

Latest commit

 

History

History
209 lines (143 loc) · 17.1 KB

File metadata and controls

209 lines (143 loc) · 17.1 KB

Графовые нейроные сети

1. Введение

Графовые нейронные сети (GNN) — это класс моделей глубокого обучения, предназначенный для работы с данными в виде графов. Они применяются для решения задач:

  • Классификации узлов (node classification). Например, тема научной статьи в графе цитирования.
  • Классификации графов (graph classification). Предсказание/выявление метки всего графа. Например, токсичность молекулы.
  • Предсказания связей (link prediction). Например, захочет ли пользователь $X$ купить товар $Y$.

GNN обучаются извлекать векторные представления (эмбеддинги) узлов и графов, используя структурную информацию.

2. Основные архитектуры и этапы работы

Входные данные:

  • Граф $G=(V,E)$ , где $V$ — вершины графа (узлы), $E$ — рёбра. $N$ – число вершин, $n = |V|$
  • Матрица смежности $A \in {0,1}^{N \tiems N}$
  • Матрица исходных признаков узлов $X \in \mathbb{R}^{N \times F}$, где $F$ – число признаков.

Процесс обработки: Пусть $H^{(l)} \in R^{N \times D_l}$ — матрица внутренних векторных представлений узлов на слое $l$, при этом $H^{(0)} = X$, где $D_l$ – размерность внутреннего представления на слое $l$

Общая схема

  • Входной слой: признаки узлов $X$ подаются на вход.
  • Графовая нейроная сеть: последовательное применение слоёв GNN для получения скрытых представлений улов.
  • Голова предсказания (prediction head): финальный слой (например, полносвязный), преобразующий векторные представлеения в предсказания.
  • Функция потерь: вычисление ошибки между предсказаниями и истинными метками.

3. Простая свёрточная графовая сеть (GCN – Graph Convolutional Network)

GCN реализует аппроксимацию спектральной свертки на графах. Основная операция для одного слоя выглядит следующим образом:

$H^{(l+1)} = \sum( Â H^{(l)} Θ^{(l)} )$

где:

  • $Â = D̃^{-1/2} Ã D̃^{-1/2}$ — нормализованная матрица смежности с добавленными self-loop'ами.
  • $Ã = A + I_N$ — матрица смежности с добавленными self-loop'ами (петлями, учитывающими собственный узел). $I_N$ — единичная матрица.
  • $D̃$ — диагональная матрица степеней узлов для $Ã$, то есть $D̃_{ii} = \sum_j Ã_{ij}$.
  • $Θ^{(l)} ∈ R^{D_l × D_{l+1}}$ — обучаемая матрица весов на слое $l$.
  • $\sigma(·)$ — нелинейная функция активации (например, ReLU).

Интуиция: Данная форма может быть интерпретирована как замена простого усреднения представлений соседей (message passing) с последующим линейным преобразованием. Нормализация с помощью помогает избежать проблем с градиентами и учитывает разную степень узлов.


Пример реализации простой GCN на PyTorch

import torch
from torch import nn

class GCN(nn.Module):
    def __init__(self, *sizes):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(x, y) for x, y in zip(sizes[:-1], sizes[1:])
        ])
    
    def forward(self, vertices, edges):
        adj = torch.eye(len(vertices))
        adj[edges[:,0], edges[:,1]] = 1
        adj[edges[:,1], edges[:,0]] = 1
        for layer in self.layers:
            vertices = torch.sigmoid(layer(adj @ vertices))
        return vertices

Популярность GCN: простота, эффективность, хорошая производительность (наиболее цитируемая работа в области GNN) (Kipf & Welling, ICLR 2017).


4. GraphSAGE (SAmple and aggreGatE)

GraphSAGE обобщает подход GCN, вводя обучаемые функции агрегации. Алгоритм для одного узла $v$ на слое $l$ состоит из трех шагов:

  1. Агрегация сообщений от соседей: $m_{N(v)}^{(l)} = \text{AGGREGATE}^{(l)}( { h_u^{(l)} : u ∈ N(v) } )$

    Здесь N(v) — множество соседей узла $v$, а AGGREGATE — некоторая дифференцируемая функция. Примеры:

    • Среднее: $$\text{AGGREGATE} = \sum_{u ∈ N(v)} \frac{h_u^{(l)}}{|N(v)|}$$
    • Пулинг (Max-Pooling): $$\text{AGGREGATE} = γ( { MLP( h_u^{(l)} ) : u \in N(v) } )$$ где $γ$ — поэлементная функция максимума, а $MLP$ — многослойный перцептрон.
  2. Объединение с собственным представлением: $$h_v^{(l+1)} = \sum( Θ^{(l)} · \text{CONCAT}( h_v^{(l)}, m_{N(v)}^{(l)} ) ),$$

    где $Θ^{(l)}$ — обучаемая матрица весов, а CONCAT — операция конкатенации.

  3. Нормализация (опционально, но часто используется):

    $$ h_v^{(l+1)} = \frac{h_v^{(l+1)}}{\Vert h_v^{(l+1)}\Vert_2 } $$

5. Graph Attention Networks (GAT)

GAT вводит механизм внимания, который назначает веса соседям узла в процессе агрегации. Для слоя $l$ обновление представлений происходит следующим образом.

  1. Вычисление необработанных коэффициентов внимания (Energy): Для каждой пары соседних узлов $v$ и $u$ (где $u ∈ N(v) ∪ {v}$, часто включая сам узел $v$) вычисляется необработанный коэффициент внимания $e_{vu}$. Это скаляр, который отражает важность узла $u$ для узла $v$.

    $e_{vu} = LeakyReLU ( a^T · [ Θ h_v^{(l)} ‖ Θ h_u^{(l)} ] )$

    где:

    • $ Θ \in R^{D' \times D_l}$ — общая для всех узлов обучаемая матрица весов линейного преобразования. $D'$ — это размерность представления внутри механизма внимания, гиперпараметр модели. Она задаёт размерность пространства, в котором будет вычисляться сходство.
    • $\Vert$ обозначает операцию конкатенации векторов.
    • $a \in R^{2D'}$ — обучаемый вектор-параметр функции внимания.
    • LeakyReLU — функция активации (обычно с небольшим отрицательным наклоном, например, 0.2).
  2. Нормализация коэффициентов с помощью Softmax: Чтобы сделать коэффициенты сравниммими и интерпретируемыми как вероятности важности, по всем соседям u узла v применяется функция softmax. Нормализованный коэффициент внимания обозначим как $α_{vu}$.

    $$ α_{vu} = \text{softmax}u (e{vu}) = \frac{exp(e_{vu})} {\sum_{k ∈ N(v) \cup {v}} exp(e_{vk})}$$

    Важно: Суммирование происходит по всем узлам $k$ из множества соседей $v$, включая его самого (если используются self-loops).

  3. Агрегация с взвешенной суммой: Итоговое представление узла $v$ на следующем слое получается как взвешенная (по нормализованным коэффициентам внимания) сумма преобразованных представлений его соседей.

    $$h_v^{(l+1)} = σ ( \sum_{u ∈ N(v) ∪ {v}} α_{vu} Θ h_u^{(l)} )$$

Многоголовое внимание (Multi-head attention): Для повышения выразительности и стабильности обучения механизм внимания повторяют $K$ раз (с независимыми параметрами $Θ^k$ и $a^k$). Выходы разных "голов" обычно конкатенируются (для скрытых слоев) или усредняются (для выходного слоя).

$$ h_v^{(l+1)} = \Vert_{k=1}^K σ ( \sum_{u ∈ N(v) \cup {v}} α_{vu}^k Θ^k h_u^{(l)} ) $$

где $\Vert$ обозначает конкатенацию, а $α_{vu}^k$ — коэффициент внимания для $k$-ой "головы".

6. Message Passing Neural Networks (MPNN): Обобщающий фреймворк**

Фреймворк Message Passing Neural Networks (MPNN) предоставляет общую математическую формулировку, которая инкапсулирует многие популярные модели GNN, включая GCN, GraphSAGE и GAT. Он формализует идею передачи сообщений через граф в виде двух четко определенных шагов.

Пусть:

  • $h_v^{(l)}$ — векторное представление узла $v$ на слое $l$.
  • $e_{vw}$ — возможные атрибуты ребра между узлами $v$ и $w$ (может быть опущено, если атрибутов нет).
  • $N(v)$ — множество соседей узла $v$.

Процесс MPNN для одного слоя состоит из двух фаз:

  1. Фаза передачи сообщений (Message Passing Phase): Для каждого узла $v$ и каждого его соседа $w \in N(v)$ вычисляется "сообщение" $m_{vw}^{(l)}$. Это сообщение является функцией от представлений узла-отправителя, узла-получателя и атрибутов ребра. $$ m_{vw}^{(l)} = M_l ( h_v^{(l)}, h_w^{(l)}, e_{vw} ) $$ где $M_l(\cdot)$ — это дифференцируемая функция сообщения (message function) на слое $l$ (например, небольшая нейросеть или линейное преобразование).

  2. Фаза обновления узла (Node Update Phase): Каждый узел $v$ агрегирует все входящие сообщения от своих соседей и обновляет собственное представление, комбинируя агрегированные сообщения со своим предыдущим состоянием. $$ h_v^{(l+1)} = U_l ( h_v^{(l)}, m_v^{(l)} ) $$ где:

    • $m_v^{(l)}$ — агрегированное сообщение для узла $v$. Оно получается с помощью перестановочно-инвариантной функции агрегации $AGG_l$ (например, суммирование, усреднение, максимум): $$ m_v^{(l)} = AGG_l ( { m_{vw}^{(l)} \mid w \in N(v) } ) $$
    • $U_l(\cdot)$ — это дифференцируемая функция обновления (update function) на слое $l$ (например, RNN или полносвязный слой).

Связь MPNN с другими архитектурами:

  • GCN: Может быть выражена в рамках MPNN следующим образом:

    • $M_l ( h_v^{(l)}, h_w^{(l)} ) = (1 / \sqrt{\deg(v)\deg(w)}) \cdot h_w^{(l)}$ (сообщение — это нормализованное представление соседа).
    • $AGG_l = \mathrm{SUM}$ (агрегация — суммирование).
    • $U_l ( h_v^{(l)}, m_v^{(l)} ) = \sigma( \Theta^{(l)} m_v^{(l)} )$ (обновление — линейное преобразование и нелинейность).
  • GraphSAGE:

    • $M_l ( h_v^{(l)}, h_w^{(l)} ) = h_w^{(l)}$ (сообщение — представление соседа).
    • $AGG_l$ — обучаемая функция агрегации (например, усреднение или LSTM).
    • $U_l ( h_v^{(l)}, m_v^{(l)} ) = \sigma ( \Theta^{(l)} \cdot \mathrm{CONCAT}( h_v^{(l)}, m_v^{(l)} ) )$ (обновление через конкатенацию и преобразование).
  • GAT:

    • $M_l ( h_v^{(l)}, h_w^{(l)} ) = \alpha_{vw} \Theta h_w^{(l)}$ (сообщение — это взвешенное по вниманию преобразованное представление соседа).
    • $AGG_l = \mathrm{SUM}$ (агрегация — взвешенная сумма).
    • $U_l ( h_v^{(l)}, m_v^{(l)} ) = \sigma( m_v^{(l)} )$ (обновление — применение нелинейности к агрегированному сообщению).

Этот фреймворк является теоретической основой, показывающей, что, несмотря на внешние различия, большинство GNN следуют одной и той же фундаментальной парадигме — передаче и агрегации сообщений.

7. Постановка эксперимента. Типы обучения.

7.1 Трансдуктивное обучение

Transductive setup

  • В обучении, тесте и валидации используется 1 граф
  • Все узлы и все рёбра присутствуют во время обучения
  • На этапе обучения модель видит как помеченные, так и непомеченные узлы
  • Примеры: Cora, Citeseer, Pubmed.

7.2 Индуктивное обучение

Inductive setup

  • Модель обучается и тестируется на разных графах. То есть на одном наборе графов обучается и применяется к новым, ранее невиданным.
  • Пример: PPI (Protein-Protein Interaction).

7.3 Пример преимущество трансдукции на общем случае


General transductive setup

"Преимущество трансдукции заключается в возможности учитывать все точки, а не только помеченные, при выполнении задачи классификации. В этом случае трансдуктивные алгоритмы будут классифицировать непомеченные точки в соответствии с кластерами, к которым они естественным образом принадлежат. Таким образом, точки в середине, скорее всего, будут помечены как «B», поскольку они расположены очень близко к этому кластеру" (цитата из Википедии )

8. Инструменты и библиотеки