Introducing Google TensorFlow-GNN 1.0: Innovative Dynamic and Interactive Sampling for Building Graph Neural Networks at Scale

Graph neural networks (GNNs), which can represent the relationships between different objects in the world, have been widely used since their introduction. Today, the Google team officially announced the release of TensorFlow-GNN 1.0, a production-tested library for building GNNs at scale.

In 2005, the landmark work “The Graph Neural Network Model” brought graph neural networks to a wide audience. Before then, scientists processed graph data by converting each graph into a set of “vector representations” during data preprocessing, which loses information. GNNs removed this drawback by learning directly from the graph structure. Over the following two decades, generations of models have continued to evolve, driving progress across ML.

Today, Google officially announced the release of TensorFlow GNN 1.0 (TF-GNN) – a production-tested library for building GNNs at scale.

It supports both modeling and training in TensorFlow and the extraction of input graphs from large data stores.

TF-GNN is built from the ground up for heterogeneous graphs, in which different types of objects and relationships are represented by distinct sets of nodes and edges.

Real-world objects and their relationships come in distinct types, and TF-GNN's heterogeneous focus makes it natural to represent them.

Google scientist Anton Tsitsulin says complex heterogeneous modeling is back!

TF-GNN 1.0 launches

Objects and their relationships are everywhere in our world.

In transportation networks, production networks, knowledge graphs, or social networks, an object's relationships are as important to understanding it as the object's own properties viewed in isolation.

Discrete mathematics and computer science have long formalized such networks as graphs, consisting of nodes connected by edges in arbitrary, irregular ways.

However, most machine learning algorithms only allow regular and uniform relationships between input objects, such as grids of pixels, sequences of words, or no relationships at all.

Graph Neural Networks, or GNNs for short, are a powerful technique that exploits both the connectivity of a graph (as earlier algorithms such as DeepWalk and Node2Vec did) and the input features on its various nodes and edges.

GNNs can make predictions about the graph as a whole (does this molecule react in a certain way?), about individual nodes (given its citations, what is the topic of this document?), or about potential edges (is this product likely to be purchased together with that product?).

Beyond making predictions about graphs, GNNs are a powerful tool for bridging the gap to more typical neural network use cases.

They encode the discrete, relational information of a graph in a continuous form, so that it can be incorporated naturally into other deep learning systems.

In TensorFlow, such graphs are represented by objects of type tfgnn.GraphTensor.

This is a composite tensor type (a collection of tensors in one Python class) that is accepted as a first-class citizen in tf.data.Dataset, tf.function, and elsewhere.

It can store not only the graph structure, but also the features of nodes, edges and the entire graph.
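To make this concrete, the sketch below shows, in plain Python, the kind of data such a container holds for a heterogeneous graph: named node sets with per-node features, named edge sets whose endpoints refer to node sets by name, and graph-level (context) features. The classes ToyNodeSet, ToyEdgeSet and ToyGraph and the "papers"/"authors" schema are hypothetical stand-ins for illustration only; they are not the tfgnn.GraphTensor API, which stores the same pieces as tensors.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyNodeSet:
    size: int
    # Feature name -> one value per node in this set.
    features: Dict[str, List[float]] = field(default_factory=dict)


@dataclass
class ToyEdgeSet:
    source_set: str  # name of the node set containing the source endpoints
    target_set: str  # name of the node set containing the target endpoints
    source: List[int]  # source node index of each edge
    target: List[int]  # target node index of each edge
    features: Dict[str, List[float]] = field(default_factory=dict)

    @property
    def size(self) -> int:
        return len(self.source)


@dataclass
class ToyGraph:
    node_sets: Dict[str, ToyNodeSet]
    edge_sets: Dict[str, ToyEdgeSet]
    context: Dict[str, float] = field(default_factory=dict)  # whole-graph features

    def validate(self) -> None:
        """Raise ValueError if any feature length or edge index is inconsistent."""
        for name, ns in self.node_sets.items():
            for feat, values in ns.features.items():
                if len(values) != ns.size:
                    raise ValueError(f"node set {name}: feature {feat} has wrong length")
        for name, es in self.edge_sets.items():
            if len(es.source) != len(es.target):
                raise ValueError(f"edge set {name}: source/target length mismatch")
            for feat, values in es.features.items():
                if len(values) != es.size:
                    raise ValueError(f"edge set {name}: feature {feat} has wrong length")
            src_size = self.node_sets[es.source_set].size
            tgt_size = self.node_sets[es.target_set].size
            if not all(0 <= i < src_size for i in es.source):
                raise ValueError(f"edge set {name}: source index out of range")
            if not all(0 <= i < tgt_size for i in es.target):
                raise ValueError(f"edge set {name}: target index out of range")


# Hypothetical schema: papers cite papers, authors write papers.
graph = ToyGraph(
    node_sets={
        "papers": ToyNodeSet(size=3, features={"year": [2019.0, 2020.0, 2021.0]}),
        "authors": ToyNodeSet(size=2),
    },
    edge_sets={
        "cites": ToyEdgeSet("papers", "papers", source=[1, 2, 2], target=[0, 0, 1]),
        "writes": ToyEdgeSet("authors", "papers", source=[0, 1, 1], target=[0, 1, 2]),
    },
    context={"num_papers": 3.0},
)
graph.validate()
```

Keeping a separate edge set per relationship type, each naming its source and target node sets, is what lets a heterogeneous model apply different transformations to different kinds of edges.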

Trainable transformations of GraphTensors can be defined as Layer objects in the high-level Keras API, or written directly with tfgnn.GraphTensor primitives.

GNN: Making predictions about objects in context

A typical application is predicting properties of a certain type of node in a graph defined by cross-referencing tables of a large database.

For example, given a citation database of computer science (CS) papers on arXiv, with one-to-many “cited” and many-to-one “citing” relationships, one might want to predict the subject area of each paper.

Like most neural networks, GNNs are trained on a dataset of many labeled examples (on the order of millions), but each training step consists of a much smaller batch of training examples (say, hundreds).

To scale to millions of examples, GNNs are trained on a stream of reasonably small subgraphs drawn from the underlying graph. Each subgraph contains enough of the original data to compute the GNN result for the labeled node at its center and to train the model.

This process, often called subgraph sampling, is extremely important for GNN training.

Most existing tools perform sampling in batch mode, generating static subgraphs for training.

TF-GNN provides tools to improve this through dynamic and interactive sampling.

Subgraph sampling process, which extracts small, tractable subgraphs from a larger graph to create input examples for GNN training
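The sampling step can be sketched as a simple fanout-based neighbor sampler: starting from a labeled seed node, it keeps at most fanouts[h] randomly chosen neighbors of each node reached at hop h. This is an illustrative sketch assuming a single edge type and neighbor lists held in memory; the function name, the fanout parameter and the toy graph are hypothetical, and it is not TF-GNN's sampler API.

```python
import random
from typing import Dict, List, Set, Tuple

# Neighbor lists for a single edge type: node id -> ids of adjacent nodes.
Adjacency = Dict[int, List[int]]


def sample_subgraph(
    adjacency: Adjacency,
    seed: int,
    fanouts: List[int],
    rng: random.Random,
) -> Tuple[Set[int], List[Tuple[int, int]]]:
    """Sample up to fanouts[h] neighbors of every node reached at hop h.

    Returns the sampled node ids and the sampled edges as (neighbor, node)
    pairs, pointing from each sampled neighbor to the node that sampled it.
    """
    nodes = {seed}
    edges: List[Tuple[int, int]] = []
    frontier = [seed]
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            neighbors = adjacency.get(node, [])
            chosen = rng.sample(neighbors, min(fanout, len(neighbors)))
            for nbr in chosen:
                edges.append((nbr, node))
                if nbr not in nodes:  # expand each node at most once
                    nodes.add(nbr)
                    next_frontier.append(nbr)
        frontier = next_frontier
    return nodes, edges


# Toy undirected path graph 0 - 1 - 2 - 3 - 4, stored as neighbor lists.
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
nodes, edges = sample_subgraph(path, seed=0, fanouts=[5, 5], rng=random.Random(0))
```

With two hops of sampling, the subgraph covers the seed's 2-hop neighborhood ({0, 1, 2} here) and nothing farther out, which matches two rounds of message passing run on it later.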

TF-GNN 1.0 introduces a flexible Python API for configuring dynamic or batch subgraph sampling at all relevant scales: interactively in a Colab notebook; efficiently, for small datasets stored in the main memory of a single training host; or distributed with Apache Beam, for huge datasets (up to hundreds of millions of nodes and billions of edges) stored on a network file system.

On these sampled subgraphs, the GNN's task is to compute a hidden (or latent) state at the root node; this hidden state aggregates and encodes the relevant information about the root node's neighborhood.

A common approach is the message-passing neural network.

In each round of message passing, a node receives messages from its neighbors along incoming edges and uses them to update its own hidden state.

After n rounds, the hidden state of the root node reflects aggregated information from all nodes within n edges (n=2 in the figure below). The messages and the new hidden states are computed by hidden layers of the neural network.

In heterogeneous graphs, it often makes sense to use separately trained hidden layers for different types of nodes and edges.

The figure shows message passing in a simple GNN: at each step, node states are propagated from outer to inner nodes, where they are pooled to compute new node states. Once the root node is reached, the final prediction can be made.
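The propagation in the figure can be sketched with scalar node states and mean pooling. The fixed weights w_self and w_msg below are hypothetical stand-ins for trained hidden layers; real GNNs use vector states and learned transformations, often separate ones per node and edge type.

```python
import math
from typing import Dict, List, Tuple


def message_passing(
    states: Dict[int, float],
    edges: List[Tuple[int, int]],  # (source, target): messages flow source -> target
    rounds: int,
    w_self: float = 0.5,
    w_msg: float = 1.0,
) -> Dict[int, float]:
    """Run `rounds` rounds of mean-aggregation message passing on scalar states."""
    for _ in range(rounds):
        incoming: Dict[int, List[float]] = {n: [] for n in states}
        for src, tgt in edges:
            incoming[tgt].append(states[src])  # message = sender's current state
        new_states = {}
        for node, h in states.items():
            msgs = incoming[node]
            pooled = sum(msgs) / len(msgs) if msgs else 0.0
            # Update: combine the node's own state with the pooled messages.
            new_states[node] = math.tanh(w_self * h + w_msg * pooled)
        states = new_states
    return states


# Chain 3 -> 2 -> 1 -> 0, with node 0 as the root.
chain_edges = [(1, 0), (2, 1), (3, 2)]
initial = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}
perturbed = {**initial, 3: -5.0}  # change only the node three hops from the root
```

Changing node 3 leaves the root's state untouched after two rounds but changes it after three, which is the n-hop receptive field described above.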

Training is set up by placing an output layer on top of the GNN's hidden state for the labeled nodes, computing the loss (to measure the prediction error), and updating the model weights via backpropagation, as is usual in neural network training.
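For a binary label, the output layer and loss can be sketched as logistic regression on the root node's hidden state, trained with binary cross-entropy. In this hedged sketch the hidden state is a made-up vector and only the output layer is updated; in a full GNN the same gradient would also flow back through the message-passing layers.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def forward(h: List[float], w: List[float], b: float) -> float:
    """Output layer: probability of the positive class from a root hidden state."""
    return sigmoid(sum(hi * wi for hi, wi in zip(h, w)) + b)


def bce_loss(p: float, label: int) -> float:
    """Binary cross-entropy between predicted probability p and a 0/1 label."""
    return -(label * math.log(p) + (1 - label) * math.log(1.0 - p))


def sgd_step(
    h: List[float], label: int, w: List[float], b: float, lr: float = 0.1
) -> Tuple[List[float], float]:
    """One gradient step; for sigmoid + BCE, dLoss/dlogit = p - label."""
    g = forward(h, w, b) - label
    return [wi - lr * g * hi for wi, hi in zip(w, h)], b - lr * g


h_root = [0.5, -1.0, 2.0]  # hypothetical hidden state of one labeled root node
label = 1
w, b = [0.0, 0.0, 0.0], 0.0
loss_before = bce_loss(forward(h_root, w, b), label)
w, b = sgd_step(h_root, label, w, b)
loss_after = bce_loss(forward(h_root, w, b), label)
```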

In addition to supervised training (minimizing a loss defined by labels), GNNs can also be trained in an unsupervised manner (without labels). This lets us compute continuous representations (or embeddings) of the discrete graph structure of nodes and their features.

These representations are then commonly used in other ML systems.

In this way, the discrete relational information encoded by the graph can be incorporated into more typical neural network use cases. TF-GNN supports fine-grained specification of unsupervised objectives on heterogeneous graphs.

Building GNN architectures

The TF-GNN library supports the construction and training of GNNs at different levels of abstraction.

At the highest level, users can take any of the predefined models bundled with the library, which are expressed as Keras layers.

In addition to a small set of models from the research literature, TF-GNN comes with a highly configurable model template that provides carefully curated modeling choices.

Google found that these choices provide strong baselines for many of its internal problems. The templates implement the GNN layers; users only need to initialize the Keras layers.

import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import mt_albis

def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
  """Builds a GNN as a Keras model."""
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)

  # Encode input features (the set_initial_node_states callback is omitted for brevity).
  graph = tfgnn.keras.layers.MapFeatures(
      node_sets_fn=set_initial_node_states)(graph)

  # For each round of message passing...
  for _ in range(2):
    # ... create and apply a Keras layer.
    graph = mt_albis.MtAlbisGraphUpdate(
        units=128, message_dim=64,
        attention_type="none", simple_conv_reduce_type="mean",
        normalization_type="layer", next_state_type="residual",
        state_dropout_rate=0.2, l2_regularization=1e-5,
    )(graph)

  return tf.keras.Model(inputs, graph)

At the lowest level, users can write GNN models from scratch using primitives for passing data around the graph, such as broadcasting data from a node to all its outgoing edges, or pooling data into a node from all its incoming edges.
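The semantics of these two primitives can be sketched in plain Python: broadcasting copies a node value onto each edge that has that node as an endpoint, and pooling aggregates edge values into the endpoint node. The function names here are hypothetical; TF-GNN provides its own operations on GraphTensor (for example tfgnn.broadcast_node_to_edges and tfgnn.pool_edges_to_node), whose exact signatures are in the library's API documentation.

```python
from typing import List


def broadcast_to_edges(node_values: List[float], endpoints: List[int]) -> List[float]:
    """Copy each edge's endpoint node value (e.g. its source node) onto the edge."""
    return [node_values[i] for i in endpoints]


def pool_to_nodes(
    edge_values: List[float], endpoints: List[int], num_nodes: int, reduce_type: str = "sum"
) -> List[float]:
    """Aggregate the values of all edges that share an endpoint (e.g. their target node)."""
    totals = [0.0] * num_nodes
    counts = [0] * num_nodes
    for value, node in zip(edge_values, endpoints):
        totals[node] += value
        counts[node] += 1
    if reduce_type == "sum":
        return totals
    if reduce_type == "mean":
        # Nodes without incoming edges get 0.0 rather than dividing by zero.
        return [t / c if c else 0.0 for t, c in zip(totals, counts)]
    raise ValueError(f"unsupported reduce_type: {reduce_type!r}")


# Three nodes; edges 0->2, 1->2 and 2->0.
source, target = [0, 1, 2], [2, 2, 0]
node_states = [1.0, 2.0, 3.0]
on_edges = broadcast_to_edges(node_states, source)  # value of each edge's source
summed = pool_to_nodes(on_edges, target, num_nodes=3)
averaged = pool_to_nodes(on_edges, target, num_nodes=3, reduce_type="mean")
```

One round of message passing is essentially a broadcast from source nodes onto edges, an optional per-edge transformation, and a pool from edges into target nodes.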

TF-GNN’s graph data model treats nodes, edges, and the entire input graph equally when it comes to features or hidden states.

Therefore, it can not only directly represent node-centric models like MPNN, but also more general forms of graph networks.

This can, but need not, be done using Keras as a modeling framework on top of core TensorFlow.

Training orchestration

While advanced users are free to write custom model training, the TF-GNN Runner also provides a concise way to orchestrate the training of Keras models in common cases.

A simple call might look like this:

import tensorflow as tf
from tensorflow_gnn import runner

runner.run(
    task=runner.RootNodeBinaryClassification("papers"),
    model_fn=model_fn,
    trainer=runner.KerasTrainer(tf.distribute.MirroredStrategy(), model_dir="/tmp/model"),
    optimizer_fn=tf.keras.optimizers.Adam,
    epochs=10,
    global_batch_size=128,
    train_ds_provider=runner.TFRecordDatasetProvider("/tmp/train*"),
    valid_ds_provider=runner.TFRecordDatasetProvider("/tmp/validation*"),
    gtspec=...,  # the GraphTensorSpec of the input data (elided)
)

The Runner provides ready-made solutions for common ML pain points, such as distributed training and padding tfgnn.GraphTensor inputs to fixed shapes for training on Cloud TPUs.
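Padding matters because TPU programs are compiled for fixed (static) tensor shapes, whereas sampled subgraphs vary in size. The sketch below shows the general idea on a single node set: pad up to a fixed size budget and carry a mask so that padding does not affect results. The function names and the budget are hypothetical; the Runner handles the real padding of GraphTensor inputs, including edges, which this sketch omits.

```python
from typing import List, Tuple


def pad_to_budget(
    values: List[float], budget: int, pad_value: float = 0.0
) -> Tuple[List[float], List[bool]]:
    """Pad one node set's features to a fixed size and return a validity mask."""
    if len(values) > budget:
        raise ValueError(f"{len(values)} nodes exceed the fixed budget of {budget}")
    padding = budget - len(values)
    return values + [pad_value] * padding, [True] * len(values) + [False] * padding


def masked_mean(values: List[float], mask: List[bool]) -> float:
    """Mean over real (unpadded) entries only."""
    kept = [v for v, keep in zip(values, mask) if keep]
    return sum(kept) / len(kept)


# Two sampled subgraphs of different sizes end up with the same fixed shape.
a, mask_a = pad_to_budget([1.0, 2.0, 3.0], budget=5)
b, mask_b = pad_to_budget([4.0, 6.0], budget=5)
```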

In addition to single-task training (shown above), it also supports joint training of multiple (two or more) tasks.

For example, unsupervised tasks can be mixed with supervised ones to shape a final continuous representation (or embedding) with application-specific inductive biases. The caller simply replaces the task argument with a mapping of tasks:

from tensorflow_gnn import runner
from tensorflow_gnn.models import contrastive_losses

runner.run(
    task={
        "classification": runner.RootNodeBinaryClassification("papers"),
        "dgi": contrastive_losses.DeepGraphInfomaxTask("papers"),
    },
    # ... the remaining arguments are as in the single-task call above.
)
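The objective behind a "dgi" task can be sketched roughly as follows. Deep Graph Infomax trains a discriminator to tell embeddings of the real graph from embeddings of a corrupted graph (for example, one with shuffled node features), both scored against a summary of the real graph. This plain-Python sketch takes both sets of embeddings as given, uses a dot product in place of a learned bilinear discriminator, and is not TF-GNN's contrastive_losses implementation.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def readout(embeddings: List[Vector]) -> List[float]:
    """Graph summary: elementwise sigmoid of the mean node embedding."""
    n = len(embeddings)
    return [sigmoid(sum(h[d] for h in embeddings) / n) for d in range(len(embeddings[0]))]


def discriminator(h: Vector, summary: Vector) -> float:
    """Probability that h comes from the real graph (dot product, no learned weights)."""
    return sigmoid(sum(a * b for a, b in zip(h, summary)))


def dgi_loss(real: List[Vector], corrupted: List[Vector]) -> float:
    """Binary cross-entropy: real embeddings should score high, corrupted ones low."""
    summary = readout(real)
    loss_real = -sum(math.log(discriminator(h, summary)) for h in real) / len(real)
    loss_fake = -sum(math.log(1.0 - discriminator(h, summary)) for h in corrupted) / len(corrupted)
    return loss_real + loss_fake


# Hypothetical embeddings: a well-separated pair, and the same pair with roles swapped.
real = [[2.0, 2.0], [1.0, 3.0]]
corrupted = [[-2.0, -2.0], [-1.0, -3.0]]
good = dgi_loss(real, corrupted)
bad = dgi_loss(corrupted, real)
```

The loss is lower when real and corrupted embeddings are easy to tell apart, which is what pushes the encoder to capture graph-specific structure without labels.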

Additionally, the TF-GNN Runner includes an implementation of integrated gradients for model attribution.

The integrated gradients output is a GraphTensor with the same connectivity as the observed GraphTensor, but with its features replaced by gradient values; larger values contribute more to the GNN prediction than smaller ones.
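Integrated gradients attributes a prediction f(x) to input components by averaging gradients along the straight path from a baseline x' to x: IG_i = (x_i - x'_i) times the integral of df/dx_i over that path. The sketch below approximates the integral with a midpoint Riemann sum for a toy function with a known gradient; for a GNN, the inputs would be the features of a GraphTensor. It illustrates the technique and is not the Runner's implementation.

```python
from typing import Callable, List, Sequence


def integrated_gradients(
    grad_fn: Callable[[List[float]], List[float]],
    x: Sequence[float],
    baseline: Sequence[float],
    steps: int = 50,
) -> List[float]:
    """Approximate integrated gradients with a midpoint Riemann sum over `steps` points."""
    summed = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            summed[i] += g
    # Scale the average gradient along the path by the distance from the baseline.
    return [(xi - b) * s / steps for xi, b, s in zip(x, baseline, summed)]


# Toy model f(x) = x0**2 + 3*x1 with its analytic gradient.
def f(x: Sequence[float]) -> float:
    return x[0] ** 2 + 3 * x[1]


def grad_f(x: Sequence[float]) -> List[float]:
    return [2 * x[0], 3.0]


attributions = integrated_gradients(grad_f, x=[2.0, 1.0], baseline=[0.0, 0.0])
```

The attributions come out at approximately [4, 3] and sum to f(x) - f(x') = 7, the completeness property of integrated gradients.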

In summary, Google hopes that TF-GNN will help promote the large-scale application of GNN in TensorFlow and promote further innovation in this field.

References:

  • https://blog.tensorflow.org/2024/02/graph-neural-networks-in-tensorflow.html?linkId=9418266

  • https://blog.research.google/2024/02/graph-neural-networks-in-tensorflow.html
