Sorry, your browser cannot access this site
This page requires browser support (enable) JavaScript
Learn more >

data封装了一个类似图的数据结构

1
2
3
4
5
6
7
8
9
import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
>>> Data(edge_index=[2, 4], x=[3, 1])
1
2
3
4
5
6
7
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
print("打印数据集的图表数量,类别数量,节点特征数量")
print(len(dataset), dataset.num_classes, dataset.num_node_features)
data = dataset[0] # Get the first graph object.
print(data.num_nodes)
>>> 600 6 3
>>> 37

加入dataloader

1
2
3
4
5
6
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 每个batch 32个图表
for batch in loader:
print(batch)

1696958380568

代表这个节点在哪张图里

1
2
3
4
5
6
7
8
9
10
for data in loader:
data
>>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

data.num_graphs
>>> 32

x = scatter(data.x, data.batch, dim=0, reduce='mean')
x.size()
>>> torch.Size([32, 21])

分析scatter函数的汇总过程

  1. 输入数据:

    • data.x: 这是节点特征矩阵,其形状为[1082, 21]。这意味着在这个批处理中,总共有1082个节点,每个节点有21个特征。
    • data.batch: 这是一个批次向量,其长度为1082(与 data.x中的节点数相同)。它指定了每个节点属于哪个图。
    • data.num_graphs: 这表示在批处理中有32个图。
  2. 使用scatter进行汇总:

    • 你使用 scatter函数并指定 reduce='mean',这意味着你想要计算每个图中的节点特征的平均值。
    • 所以,对于每个图,它会考虑所有属于该图的节点(基于 data.batch),然后计算这些节点特征的平均值。结果是每个图都有一个平均的特征向量。
  3. 输出:

    • x: 经过 scatter函数处理后,你得到一个新的张量 x,其形状为[32, 21]。这意味着现在你有32个平均特征向量,每个向量有21个特征(与输入 data.x中的特征数相同)。
    • 每个平均特征向量对应于批处理中的一个图。这样,你从批处理中的1082个节点减少到了32个平均特征向量,每个都代表一个图。

结论:
scatter的汇总工作是将这批处理中的每个图中的所有节点特征取平均,从而为每个图得到一个代表性的特征向量。这对于那些需要从整个图中提取特征的任务(例如,整图分类)非常有用。

数据转换

1
2
3
4
5
6
7
8
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
pre_transform=T.KNNGraph(k=6))

dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

该代码把点云图,找到每个点最邻近的6个点,加上边连接,所以15108 = 2518 * 6,数据集专注于飞机类别。而y属性中的值将表示飞机的各个部分(如机翼、尾翼、机身等)

pre_transform vs transform

pre_transform:

  • 定义:
    • pre_transform 是在保存数据集到磁盘之前应用的变换。
  • 应用时机:
    • 它只会执行一次,即在第一次下载和处理原始数据时。之后,即使你多次加载数据集,这个变换也不会再被应用。
  • 示例解释:
    • 在你的例子中,pre_transform=T.KNNGraph(k=6) 的意思是在数据被保存之前,对每个图形数据计算其k-最近邻图。具体来说,它会为每个点找到其6个最近的邻居,并创建一个边连接索引,表示这些近邻关系。

transform:

  • 定义:
    • transform 是每次从数据集中获取数据时都会应用的变换。
  • 应用时机:
    • 它是动态的,意味着每次你从数据集中提取一个数据项时,都会实时地应用这个变换。
  • 示例解释:
    • 在你的例子中,transform=T.RandomJitter(0.01) 的意思是每次提取数据时,都会对每个点的位置随机添加一个在 [-0.01, 0.01]范围内的扰动。这种随机扰动有助于数据增强,使模型更加鲁棒。

结论:
pre_transform 是一个预处理步骤,只应用一次,而 transform 是一个每次都会应用的数据增强步骤。

图的学习方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
print("打印数据集的图表数量,类别数量,节点特征数量")
print(len(dataset), dataset.num_classes, dataset.num_node_features)
print("打印节点标签")
print(dataset[0].keys)
data = dataset[0] # Get the first graph object.
print(data.num_nodes)

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes)

def forward(self, data):
x, edge_index = data.x, data.edge_index

x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)

return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()

model.eval()
pred = model(data)
print(pred.shape)
pred = pred.argmax(dim=1)
print (pred)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')

可以看到,图学习方法将边的信息融合了进去

现在让我们注意dim参数,dim的意思是,在某个维度上处理数据,argmax在7的概率分布中选择最大的那个

我们可以观察输出

1
2
3
4
5
6
7
8
打印数据集的图表数量,类别数量,节点特征数量
1 7 1433
打印节点标签
['y', 'x', 'train_mask', 'val_mask', 'test_mask', 'edge_index']
2708
torch.Size([2708, 7])
tensor([3, 4, 4, ..., 5, 3, 3], device='cuda:0')
Accuracy: 0.7880

关于这个mask是什么

1
tensor([ True,  True,  True,  ..., False, False, False])

可以看出,就是为了区别是否加入训练

Exercises

  1. What does edge_index.t().contiguous() do?

  2. Load the "IMDB-BINARY" dataset from the TUDataset benchmark suite and randomly split it into 80%/10%/10% training, validation and test graphs.

  3. What does each number of the following output mean?

    1
    2
    print(batch)
    >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

edge_index.t().contiguous()edge_index 张量转置并确保其在内存中连续存储。

评论