General Algorithms
PyG-SSL includes state-of-the-art algorithms for self-supervised learning (SSL) on general graphs. Many of these algorithms are based on contrastive learning, and PyG-SSL also includes two non-contrastive algorithms, BGRL and AFGRL. This page describes each algorithm and how to use it in PyG-SSL.
DGI
Deep Graph Infomax (DGI) is a self-supervised method for learning node-level and graph-level representations. It maximizes the mutual information between local (node) and global (graph) representations.
Key Concepts
- Global Representation:
The global representation is obtained by applying a readout function over the node embeddings produced by a graph neural network (GNN). This readout function aggregates information from all nodes to produce a single vector representing the entire graph.
- Node Representation:
Each node in the graph is encoded using a GNN, resulting in individual embeddings for each node.
- Mutual Information Maximization:
DGI maximizes mutual information between the node-level representations and the global graph-level representation. This is achieved by contrasting the representations of the original graph with those of a corrupted version of the graph. The corrupted graph is generated by randomly shuffling node features, creating a negative sample.
Learning Objective
The learning objective of DGI is to distinguish between:
Positive Pairs: These consist of node and global embeddings from the original graph.
Negative Pairs: These are pairs of node embeddings from the corrupted graph and global embeddings from the original graph.
DGI uses a discriminator function to distinguish between these two types of pairs. By maximizing the mutual information between positive pairs, the model learns better node and graph-level representations.
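The objective above can be illustrated with a minimal NumPy sketch. This is not the PyG-SSL implementation. A single randomly initialized GCN-style layer stands in for the trained encoder. The readout is the sigmoid of the mean node embedding, mirroring readout='avg' with a Sigmoid readout_act. Corruption shuffles the rows of the feature matrix, and the discriminator is a bilinear score trained with binary cross-entropy. The helper names (normalize_adj, encode, dgi_loss) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def normalize_adj(adj):
    """Symmetrically normalize A + I, as in a GCN layer."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def encode(x, a_norm, w):
    """One GCN-style layer with ReLU (a stand-in for the trained GNN encoder)."""
    return np.maximum(a_norm @ x @ w, 0.0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def dgi_loss(x, adj, w, w_disc, rng):
    a_norm = normalize_adj(adj)
    h_pos = encode(x, a_norm, w)                 # node embeddings of the original graph
    x_corrupt = x[rng.permutation(x.shape[0])]   # corruption: shuffle node features
    h_neg = encode(x_corrupt, a_norm, w)         # node embeddings of the corrupted graph
    s = sigmoid(h_pos.mean(axis=0))              # readout: global summary vector
    logits_pos = h_pos @ w_disc @ s              # bilinear discriminator scores
    logits_neg = h_neg @ w_disc @ s
    # Binary cross-entropy: positive pairs labeled 1, negative pairs labeled 0
    eps = 1e-9
    loss_pos = -np.log(sigmoid(logits_pos) + eps).mean()
    loss_neg = -np.log(1.0 - sigmoid(logits_neg) + eps).mean()
    return loss_pos + loss_neg

# Toy graph: a 4-node path, 3 input features, 8 hidden units
adj = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
x = rng.normal(size=(4, 3))
w = rng.normal(size=(3, 8))
w_disc = rng.normal(size=(8, 8))
print(dgi_loss(x, adj, w, w_disc, rng))
```

In actual training, gradients of this loss would update both the encoder weights and the discriminator matrix. The sketch only evaluates the loss once.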
API Reference in PyG-SSL
- class DGIEncoder(in_channels, hidden_channels=512, num_layers=1, act=torch.nn.PReLU(), **kwargs)
A graph encoder that generates representations given a
torch_geometric.data.Dataset.
Parameters:
- in_channels (int):
Number of input features of the input dataset.
- hidden_channels (int, optional):
Number of hidden channels for the encoder. Default is 512.
- num_layers (int, optional):
Number of layers for the encoder. Default is 1.
- act (torch.nn.Module, optional):
Activation function for the encoder. Default is
torch.nn.PReLU().
- kwargs (optional):
Additional arguments for
pygssl.methods.DGIEncoder.
- class DGI(encoder: torch.nn.Module, hidden_channels: int, readout: str = 'avg', readout_act: Callable = torch.nn.Sigmoid())
The Deep Graph Infomax Algorithm.
Parameters:
- encoder (torch.nn.Module):
The encoder to be trained.
- hidden_channels (int):
Output dimension of the encoder.
- readout (str):
'avg' or 'max'; specifies how to generate global embeddings. (default: 'avg')
- readout_act (Callable):
Activation function applied to the readout output. (default: torch.nn.Sigmoid())
References
Veličković, P., et al. "Deep Graph Infomax." International Conference on Learning Representations (ICLR), 2019.
GraphCL
Graph Contrastive Learning (GraphCL) learns representations of graph-structured data with contrastive learning. Differently augmented views of the same graph are pulled together in the embedding space, while views of different graphs are pushed apart.
Key Concepts
Augmentations: Techniques applied to the original graph to generate variations. These can include node dropout, edge perturbation, and subgraph sampling. The idea is to create different views of the same graph to learn robust representations.
Graph Representations: The embeddings or features learned from graphs that capture their structural and semantic information. These representations are used for various downstream tasks such as node classification, link prediction, and graph classification.
Loss Functions: Functions that reward high similarity for positive pairs relative to negative pairs. Common choices in GraphCL include the InfoNCE loss and other contrastive losses tailored for graph data.
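Two of the augmentations listed above can be sketched on an edge-list representation. The NumPy helpers below, drop_nodes and perturb_edges, are hypothetical names and not PyG-SSL functions. They illustrate node dropout and edge perturbation. The drop ratios are illustrative, and PyG-SSL's own augmentations, such as RandomMask, may work differently.

```python
import numpy as np

def drop_nodes(x, edge_index, drop_ratio, rng):
    """Node dropout: remove a random subset of nodes and every edge touching them."""
    num_nodes = x.shape[0]
    keep = rng.random(num_nodes) >= drop_ratio
    # Map old node ids to new contiguous ids; dropped nodes map to -1
    new_id = -np.ones(num_nodes, dtype=int)
    new_id[keep] = np.arange(keep.sum())
    src, dst = edge_index
    edge_mask = keep[src] & keep[dst]
    new_edges = np.stack([new_id[src[edge_mask]], new_id[dst[edge_mask]]])
    return x[keep], new_edges

def perturb_edges(edge_index, num_nodes, ratio, rng):
    """Edge perturbation: drop a fraction of edges and add the same number of random ones.

    Random additions may duplicate existing edges or create self-loops;
    this sketch does not filter them.
    """
    num_edges = edge_index.shape[1]
    num_change = int(ratio * num_edges)
    kept = edge_index[:, rng.permutation(num_edges)[num_change:]]
    added = rng.integers(0, num_nodes, size=(2, num_change))
    return np.concatenate([kept, added], axis=1)

rng = np.random.default_rng(0)
x = np.arange(12, dtype=float).reshape(6, 2)               # 6 nodes, 2 features each
edge_index = np.array([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])  # a path graph
x_view1, e_view1 = drop_nodes(x, edge_index, 0.2, rng)
e_view2 = perturb_edges(edge_index, x.shape[0], 0.2, rng)
```

Applying two different augmentations to the same graph, as here, yields the two views that form a positive pair.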
Learning Objectives
Positive Pairs: Constructed by applying different augmentations to the same original graph. For example, given a graph G, applying node dropout to G to create G' and edge perturbation to G to create G'' yields the positive pair (G', G''), because both are views of the same underlying graph structure.
Negative Pairs: Constructed by contrasting graphs from different samples. For instance, if you have two different graphs G1 and G2, they form a negative pair (G1, G2) because they represent different graph structures. The model is trained to ensure that the similarity between these negative pairs is minimized compared to the positive pairs.
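This pairing scheme is commonly optimized with a normalized-temperature cross-entropy (NT-Xent, an InfoNCE variant) loss over a batch of graphs. For each graph, the embedding of one view must pick out the embedding of its other view among the second-view embeddings of all graphs in the batch. The NumPy function below is a simplified, one-directional sketch of that loss, not PyG-SSL's loss_function. The name nt_xent and the temperature of 0.2 are illustrative assumptions.

```python
import numpy as np

def nt_xent(z1, z2, tau=0.2):
    """One-directional NT-Xent loss between two views of a batch of graph embeddings.

    z1[n] and z2[n] embed two augmented views of graph n (a positive pair);
    z2[m] for m != n act as negatives for z1[n].
    """
    # Cosine similarity: L2-normalize rows, then take inner products
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / tau                               # (N, N) similarity matrix
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    # Row-wise log-softmax; the positive for row n sits on the diagonal
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_prob))

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 16))                               # 8 graphs, 16-dim embeddings
aligned = nt_xent(z, z + 0.01 * rng.normal(size=z.shape))  # the two views agree
unrelated = nt_xent(z, rng.normal(size=z.shape))           # the two views are unrelated
print(aligned, unrelated)
```

When the two views of each graph agree, the diagonal dominates each row and the loss is small. Unrelated views give a loss near log(N).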
API Reference in PyG-SSL
- class GraphCLEncoder(in_channels, hidden_channels=512, num_layers=1, act=torch.nn.PReLU(), **kwargs)
A GraphCL encoder that generates representations given a
torch_geometric.data.Dataset.
Parameters:
- in_channels (int):
Number of input features of the input dataset.
- hidden_channels (int, optional):
Number of hidden channels for the encoder. Default is 512.
- num_layers (int, optional):
Number of layers for the encoder. Default is 1.
- act (torch.nn.Module, optional):
Activation function for the encoder. Default is
torch.nn.PReLU().
- kwargs (optional):
Additional arguments for
pygssl.methods.GraphCLEncoder.
- class GraphCL(encoder: torch.nn.Module, hidden_channels: int, readout: Callable | torch.nn.Module = AvgReadout(), corruption: AugmentType = RandomMask(), loss_function: torch.nn.Module | None = None)
The Graph Contrastive Learning Algorithm.
Parameters:
- encoder (torch.nn.Module):
The encoder to be trained.
- hidden_channels (int):
Output dimension of the encoder.
- readout (Callable | torch.nn.Module):
Readout used to generate global embeddings from node embeddings. (default: AvgReadout())
- corruption (AugmentType):
Augmentation to be used. (default: RandomMask())
- loss_function (Optional[torch.nn.Module]):
The loss function to be used. (default: None)
References
Graph Contrastive Learning with Augmentations. Yuning You, et al. Available at https://arxiv.org/abs/2010.13902
MVGRL
MVGRL, or Contrastive Multi-View Representation Learning on Graphs, is a self-supervised framework that exploits the multiple structural views a graph naturally admits. It learns node and graph representations by contrasting two views of the same graph: one built from the adjacency matrix and one from a graph diffusion such as personalized PageRank (PPR).
The fundamental premise of MVGRL is that each node can be represented in multiple ways, influenced by its neighborhood, structural attributes, and other contextual information. By employing a contrastive learning approach, MVGRL encourages the model to distinguish between similar and dissimilar node representations across these diverse views.
Key Features:
Multi-View Learning: MVGRL captures the rich structural information of graphs by integrating multiple perspectives, allowing for a more comprehensive understanding of node relationships.
Contrastive Objective: The use of a contrastive loss function enables the model to focus on distinguishing representations of similar nodes while pushing apart those that are dissimilar.
Scalability: Because diffusion matrices are dense, MVGRL can train on sampled subgraphs instead of the full graph (see sample_size below). This makes it applicable to larger graphs in domains such as social networks, biological networks, and recommendation systems.
API Reference in PyG-SSL
- class MVGRLBaseEncoder(in_channels: int, hidden_channels: int = 512, act: torch.nn = torch.nn.PReLU(), bias: bool = True)
This class serves as the foundational encoder for the MVGRL framework. It embeds nodes into a latent space by processing the input features together with the graph structure to produce initial representations.
Parameters:
- in_channels (int):
Number of input features of the input dataset.
- hidden_channels (int, optional):
Number of hidden channels for the encoder. Default is 512.
- act (torch.nn.Module, optional):
Activation function for the encoder. Default is torch.nn.PReLU().
- bias (bool, optional):
Whether to include a bias term in the linear transformation. Default is True.
- class MVGRLEncoder(in_channels: int, hidden_channels: int = 512, act: torch.nn = torch.nn.PReLU(), bias: bool = True)
This class extends MVGRLBaseEncoder by combining two instances of the base encoder. Each instance processes a different view of the graph, capturing complementary information about the nodes. Their outputs are then contrasted to learn representations under the multi-view paradigm.
Parameters:
- in_channels (int):
Number of input features of the input dataset.
- hidden_channels (int, optional):
Number of hidden channels for the encoder. Default is 512.
- act (torch.nn.Module, optional):
Activation function for the encoder. Default is torch.nn.PReLU().
- bias (bool, optional):
Whether to include a bias term in the linear transformation. Default is True.
- class MVGRLDiscriminator(hidden_channels: int = 512)
The discriminator plays a central role in the contrastive learning process. It scores the agreement between representations produced by the MVGRLEncoder for the different views. Trained with a contrastive loss, it separates positive pairs from negative pairs. A positive pair is a node's representation in one view paired with the graph-level summary of the other view. A negative pair uses representations of a corrupted graph paired with that summary. This scoring guides the overall training process.
Parameters:
- hidden_channels (int, optional):
Dimension of the representations being scored. Default is 512.
- class MVGRL(encoder: torch.nn.Module, hidden_channels: int, readout: str = 'avg', readout_act: Callable = torch.nn.Sigmoid(), diff: AugmentType = ComputePPR(), is_sparse=False, sample_size: int = 1000)
This is the main class that orchestrates the entire MVGRL framework. It integrates the encoder and discriminator components, manages the training loop, and optimizes the model parameters. This class is responsible for feeding data through the encoder, computing the contrastive loss using the discriminator, and updating the model based on the loss values.
Parameters:
- encoder (torch.nn.Module):
The encoder to be trained.
- hidden_channels (int):
Output dimension of the encoder.
- readout (str):
'avg' or 'max'; specifies how to generate global embeddings. (default: 'avg')
- readout_act (Callable):
Activation function for the readout. (default: torch.nn.Sigmoid())
- diff (AugmentType):
Augmentation used to generate the second (diffusion) view. (default: ComputePPR())
- is_sparse (bool):
Whether the graph is sparse or not. (default: False)
- sample_size (int):
Number of samples to be used. (default: 1000)
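MVGRL's second view is a graph diffusion, and the default diff, ComputePPR, is a personalized PageRank diffusion. In closed form, the PPR diffusion matrix described in the MVGRL paper is S = α (I - (1 - α) D^(-1/2) A D^(-1/2))^(-1), where A is the adjacency matrix, D its degree matrix, and α the teleport probability. The NumPy sketch below computes this dense matrix. The function name compute_ppr, α = 0.2, and adding self-loops first are assumptions. PyG-SSL's ComputePPR may differ in such details.

```python
import numpy as np

def compute_ppr(adj, alpha=0.2, self_loop=True):
    """Dense personalized-PageRank diffusion: alpha * (I - (1 - alpha) * A_norm)^-1."""
    n = adj.shape[0]
    a = adj + np.eye(n) if self_loop else adj.astype(float)
    d_inv_sqrt = 1.0 / np.sqrt(a.sum(axis=1))  # assumes no isolated nodes if self_loop=False
    a_norm = a * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]   # D^-1/2 A D^-1/2
    return alpha * np.linalg.inv(np.eye(n) - (1.0 - alpha) * a_norm)

# Toy graph: a 4-node cycle
adj = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]], dtype=float)
s = compute_ppr(adj)
print(np.round(s, 3))
```

S is a dense n × n matrix, so this exact computation is only practical for small and medium-sized graphs.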
References
Contrastive Multi-View Representation Learning on Graphs, Kaveh Hassani et al. Available at https://arxiv.org/abs/2006.05582.
GCA
Graph Contrastive Learning with Adaptive Augmentation (GCA) is a graph contrastive learning framework that improves representation learning through adaptive data augmentation. Instead of perturbing the graph uniformly at random, GCA adapts its augmentations to the structure and attributes of the graph at hand.
The core idea of GCA is to use contrastive learning to distinguish between similar and dissimilar node representations while employing augmentations tailored to the graph structure. GCA perturbs unimportant edges and features more aggressively than important ones. This generates diverse views that still preserve the graph's key structure, which improves the model's robustness.
Key Features:
Adaptive Augmentation: GCA derives edge-dropping and feature-masking probabilities from node centrality measures (degree, eigenvector, or PageRank centrality). As a result, important edges and features are more likely to be kept, while unimportant ones are perturbed more often.
Contrastive Learning Framework: By focusing on contrasting node representations from different views, GCA encourages the model to learn meaningful and discriminative features that capture the underlying graph structure.
Improved Robustness: The dynamic nature of the augmentations helps the model generalize better to various tasks, making it suitable for applications such as node classification, link prediction, and community detection.
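The adaptive edge dropping described above can be sketched with degree centrality. In this hypothetical NumPy helper, which is not the PyG-SSL or GCATrainer API, each edge's centrality is the log of its endpoints' average degree. Edges that are less central than the most central edge get proportionally higher drop probabilities, scaled by an overall rate p_e and capped at p_tau. The names and the values p_e=0.3 and p_tau=0.7 are assumptions, and GCA's exact normalization may differ.

```python
import numpy as np

def degree_drop_probs(edge_index, num_nodes, p_e=0.3, p_tau=0.7):
    """Degree-based edge-dropping probabilities in the spirit of GCA.

    Edges whose endpoints have low degree centrality are treated as less
    important and receive a higher drop probability, capped at p_tau.
    """
    src, dst = edge_index
    deg = np.bincount(np.concatenate([src, dst]), minlength=num_nodes)
    # Edge centrality: average degree of the two endpoints, on a log scale
    s = np.log((deg[src] + deg[dst]) / 2.0)
    probs = (s.max() - s) / (s.max() - s.mean() + 1e-12) * p_e
    return np.minimum(probs, p_tau)

def drop_edges(edge_index, probs, rng):
    """Keep each edge independently with probability 1 - probs[e]."""
    keep = rng.random(edge_index.shape[1]) >= probs
    return edge_index[:, keep]

# Node 0 is a hub linked to nodes 1-4; nodes 4 and 5 form a pendant edge
edge_index = np.array([[0, 0, 0, 0, 4], [1, 2, 3, 4, 5]])
probs = degree_drop_probs(edge_index, num_nodes=6)
view = drop_edges(edge_index, probs, np.random.default_rng(0))
print(np.round(probs, 3))
```

Here the pendant edge (4, 5) receives the highest drop probability and hits the 0.7 cap. The edge (0, 4) between the two best-connected nodes gets probability zero and is always kept.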
API Reference in PyG-SSL
The API of GCA differs from that of the other algorithms in PyG-SSL because GCA requires a dedicated trainer class, pygssl.methods.GCATrainer.
- class GCA_Encoder(in_channels: int, out_channels: int, activation, base_model=GCNConv, k: int = 2, skip=False)
Encoder used for GCA.
Parameters:
- in_channels (int):
Number of input features of the input dataset.
- out_channels (int):
Number of output channels for the encoder.
- activation (torch.nn.Module):
Activation function for the encoder.
- base_model (torch.nn.Module, optional):
Base GNN model for the encoder. Default is GCNConv.
- k (int, optional):
Number of layers for the encoder. Default is 2.
- skip (bool, optional):
Whether to use skip connections. Default is False.
- class GRACE(encoder: torch.nn.Module, loss_function: None, num_hidden: int, num_proj_hidden: int, tau: float = 0.5)
The GRACE-style contrastive model used by GCA. It combines the encoder, a projection head, and a temperature-scaled contrastive loss.
Parameters:
- encoder (torch.nn.Module):
The encoder to be trained.
- num_hidden (int):
Output dimension of the encoder.
- num_proj_hidden (int):
Number of hidden channels for the projection head.
- tau (float):
Temperature parameter for the contrastive loss. Default is 0.5.
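The tau parameter is the temperature of a GRACE-style loss. Consider node i with embeddings u_i and v_i in the two views. The loss compares the positive pair (u_i, v_i) against inter-view negatives (u_i, v_k) and intra-view negatives (u_i, u_k) for all k ≠ i, and is symmetrized over the two views. Below is a NumPy sketch of that loss applied to already-projected embeddings. The projection head (num_proj_hidden) is omitted, and the function names are hypothetical rather than PyG-SSL's.

```python
import numpy as np

def _cosine_sim(a, b):
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

def _semi_loss(u, v, tau):
    """Per-node loss with u as anchors; positives lie on the diagonal of sim(u, v)."""
    refl = np.exp(_cosine_sim(u, u) / tau)      # intra-view similarities
    between = np.exp(_cosine_sim(u, v) / tau)   # inter-view similarities
    pos = np.diag(between)
    # Denominator: all inter-view pairs plus intra-view pairs, excluding each node with itself
    denom = between.sum(axis=1) + refl.sum(axis=1) - np.diag(refl)
    return -np.log(pos / denom)

def grace_loss(u, v, tau=0.5):
    """Symmetrized GRACE-style loss averaged over nodes."""
    return 0.5 * (_semi_loss(u, v, tau) + _semi_loss(v, u, tau)).mean()

rng = np.random.default_rng(0)
u = rng.normal(size=(10, 32))
v = u + 0.05 * rng.normal(size=u.shape)
print(grace_loss(u, v), grace_loss(u, rng.normal(size=u.shape)))
```

A smaller tau sharpens the softmax, which makes the loss focus more strongly on the hardest negatives.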
References
Graph Contrastive Learning with Adaptive Augmentation. Yanqiao Zhu, et al. Available at https://sxkdz.github.io/files/publications/WWW/GCA/GCA.pdf
SUGRL
Simple Unsupervised Graph Representation Learning (SUGRL) is a streamlined framework designed to facilitate the learning of node representations in graph data without relying on labeled examples. As graph structures become increasingly prevalent in diverse fields such as social networks, biological networks, and recommendation systems, the demand for effective unsupervised learning techniques has surged. SUGRL addresses this need by providing a straightforward yet powerful approach to extract meaningful representations from complex graph data.
Overview
The primary goal of SUGRL is to simplify the process of unsupervised graph representation learning while maintaining effectiveness and efficiency. By leveraging simple yet effective learning strategies, SUGRL enables practitioners to generate high-quality node embeddings that can be applied to a variety of downstream tasks, such as node classification, clustering, and link prediction.
Key Features
Simplicity and Efficiency: SUGRL is designed to be straightforward in its implementation, making it accessible for researchers and practitioners alike. Its efficiency allows for rapid experimentation and integration into existing workflows.
Unsupervised Learning: As an unsupervised method, SUGRL does not require labeled data, making it suitable for scenarios where labeled samples are scarce or unavailable. This characteristic is particularly beneficial in real-world applications where obtaining labels can be challenging and time-consuming.
Effective Representation Learning: SUGRL employs a combination of techniques to learn robust node representations. By utilizing simple training objectives, it captures the essential structural and contextual information present in the graph.
Broad Applicability: The embeddings generated by SUGRL can be utilized across various graph-related tasks, including but not limited to node classification, community detection, and link prediction, enhancing its utility across multiple domains.
Methodology
SUGRL operates through a systematic approach that involves several key steps:
Graph Preprocessing: The initial stage involves preprocessing the graph data to extract relevant features and relationships among nodes. This step may include normalization and feature engineering to enhance the quality of input data.
Representation Learning: SUGRL generates node embeddings with unsupervised techniques based on neighborhood aggregation (see the SugrlGCN encoder below). These techniques allow the model to learn from the local and global structures of the graph without the need for supervision.
Loss Function Optimization: A carefully designed loss function guides the optimization process, enabling the model to refine the learned representations. This process ensures that the embeddings capture meaningful similarities and differences between nodes.
Evaluation and Application: Once the node representations are learned, their quality is evaluated with various metrics and on downstream tasks, demonstrating their effectiveness and utility.
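The loss in step 3 can be illustrated with a margin-based (triplet-style) objective, a common form of the simple training objectives SUGRL relies on. Each node's anchor embedding should be closer to its positive embedding than to a negative embedding by at least a margin. This is a generic sketch, not SUGRL's exact objective, which may include additional terms. Negatives are made by shuffling rows, and the function name and margin value are assumptions.

```python
import numpy as np

def triplet_margin_loss(anchor, positive, margin=0.5, rng=None):
    """Encourage d(anchor, positive) + margin <= d(anchor, negative) for every node.

    Negatives are created by shuffling the rows of `positive`, so each node is
    paired with another node's embedding (occasionally with its own).
    """
    rng = np.random.default_rng() if rng is None else rng
    negative = positive[rng.permutation(positive.shape[0])]
    d_pos = np.sum((anchor - positive) ** 2, axis=1)   # squared Euclidean distances
    d_neg = np.sum((anchor - negative) ** 2, axis=1)
    return np.maximum(0.0, d_pos - d_neg + margin).mean()

rng = np.random.default_rng(0)
h_anchor = rng.normal(size=(6, 4))                                # e.g. a feature-only view
h_pos = h_anchor + 0.01 * rng.normal(size=h_anchor.shape)         # e.g. an aggregated view
print(triplet_margin_loss(h_anchor, h_pos, rng=rng))
```

Because only positive and shuffled embeddings are compared, the loss needs no labels and no pairwise similarity matrix.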
API Reference in PyG-SSL
- class SugrlGCN(in_channels: int, dim_out: int = 128, act: torch.nn = torch.nn.PReLU(), bias: bool = False)
A SUGRL encoder that generates representations given a
torch_geometric.data.Dataset.
Parameters:
- in_channels (int):
Number of input features of the input dataset.
- dim_out (int, optional):
Number of output features of the encoder. Default is 128.
- act (torch.nn.Module, optional):
Activation function for the encoder. Default is
torch.nn.PReLU().
- bias (bool, optional):
Whether to include a bias term in the linear transformation. Default is False.
- class SUGRL(encoder: torch.nn.Module, data=None, lr: float = 0.001, weight_decay: float = 0.0, n_epochs: int = 300, use_cuda: bool = True, is_sparse: bool = False, save_root: str = '', device: str = '', loss_function: torch.nn.Module | None = None, config: dict = {})
The Simple Unsupervised Graph Representation Learning Algorithm.
Parameters:
- encoder (torch.nn.Module):
The encoder to be trained.
- data (Optional[torch_geometric.data.Dataset]):
The dataset to be used for training. Default is None.
- lr (float):
Learning rate for the optimizer. Default is 0.001.
- weight_decay (float):
Weight decay for the optimizer. Default is 0.0.
- n_epochs (int):
Number of epochs for training. Default is 300.
- use_cuda (bool):
Whether to use CUDA for training. Default is True.
- is_sparse (bool):
Whether the graph is sparse. Default is False.
- save_root (str):
Root directory for saving the model. Default is ''.
- device (str):
Device to be used for training. Default is ''.
- loss_function (Optional[torch.nn.Module]):
The loss function to be used. Default is None.
- config (dict):
Additional configuration parameters. Default is {}.
References
Simple Unsupervised Graph Representation Learning, Yujie Mo et al. Available at https://ojs.aaai.org/index.php/AAAI/article/view/20748
BGRL
Overview
The primary aim of BGRL is to provide a scalable and robust framework for learning node embeddings that can handle large graphs efficiently. BGRL uses a bootstrapping approach. A student (online) encoder learns to predict the representations produced by a teacher (target) encoder, and the teacher is a slowly moving average of the student. Because this objective needs no negative pairs, BGRL avoids comparing every node with every other node.
Key Features
Scalability: BGRL is specifically designed to scale to large graphs, making it suitable for real-world applications where data sizes can be substantial. Since no negative pairs are needed, its cost grows linearly with the size of the graph rather than quadratically with the number of nodes.
Bootstrapping Technique: The student encoder is trained to predict the teacher encoder's representation of a differently augmented view of the same graph. The teacher's weights are updated as an exponential moving average of the student's weights, so the model learns from targets it produces itself, without labels or negative samples.
Effective Representation Learning: BGRL employs various techniques to learn robust node representations, capturing the essential structural and relational information inherent in the graph. By iterating on learned representations, it continuously enhances their quality and relevance.
Versatile Applications: The embeddings produced by BGRL can be applied across a wide range of graph-related tasks, such as node classification, link prediction, and community detection, demonstrating its broad utility in various domains.
Methodology
BGRL operates through a series of systematic steps:
Graph Preprocessing: Initially, the graph data is preprocessed to extract relevant features and relationships among nodes. This may involve normalization, feature engineering, and constructing appropriate graph structures for learning.
Initial Representation Learning: BGRL generates two augmented views of the graph, for example by dropping edges and masking node features. The student encoder embeds one view and the teacher encoder embeds the other.
Bootstrapping Process: A predictor on top of the student's output is trained to match the teacher's embedding of the other view. Only the student and the predictor receive gradients. After each step, the teacher's weights are updated as an exponential moving average of the student's weights.
Loss Function Optimization: The model is optimized using a well-defined loss function that guides the learning process, ensuring that the embeddings accurately reflect the relationships between nodes.
Evaluation and Application: After training, the learned node representations are evaluated with various metrics and applied to downstream tasks, demonstrating their effectiveness in practical applications.
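The two core operations of the bootstrapping process can be sketched in NumPy: the exponential-moving-average (EMA) teacher update and the cosine-similarity loss between the student's predictions and the teacher's targets. The loss is written in the BYOL form 2 - 2·cos, which differs from a negated cosine similarity only by a constant. The function names and the decay of 0.99 are illustrative assumptions, not PyG-SSL's API.

```python
import numpy as np

def ema_update(teacher_params, student_params, decay=0.99):
    """Teacher <- decay * teacher + (1 - decay) * student, parameter by parameter."""
    return [decay * t + (1.0 - decay) * s for t, s in zip(teacher_params, student_params)]

def bootstrap_loss(pred, target):
    """Mean of 2 - 2 * cosine similarity between student predictions and teacher targets.

    In training, the target is treated as a constant (no gradient reaches the teacher).
    """
    pred = pred / np.linalg.norm(pred, axis=1, keepdims=True)
    target = target / np.linalg.norm(target, axis=1, keepdims=True)
    return np.mean(2.0 - 2.0 * np.sum(pred * target, axis=1))

rng = np.random.default_rng(0)
student = [rng.normal(size=(4, 4)), rng.normal(size=4)]   # toy weight matrix and bias
teacher = [p.copy() for p in student]                     # the teacher starts as a copy
student = [p + 0.1 for p in student]                      # pretend a gradient step moved the student
teacher = ema_update(teacher, student)                    # the teacher moves 1% of the way
z = rng.normal(size=(5, 8))
print(bootstrap_loss(z, z), bootstrap_loss(z, -z))
```

The loss is 0 when predictions and targets point in the same direction and 4 when they are opposite. With decay close to 1, the teacher changes slowly and provides stable targets.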
API Reference in PyG-SSL
- class BGRLEncoder(in_channel, hidden_channels, **kwargs)
A graph encoder for BGRL that generates representations given a
torch_geometric.data.Dataset.
Parameters:
- in_channel (int):
Number of input features of the input dataset.
- hidden_channels (int):
Number of hidden channels for the encoder.
- kwargs (optional):
Additional arguments for
pygssl.methods.BGRLEncoder.
- class BGRL(student_encoder: torch.nn.Module, teacher_encoder: torch.nn.Module, data_augment=None, pred_dim=None)
The Bootstrapping Graph Representation Learning Algorithm.
Parameters:
- student_encoder (torch.nn.Module):
The student encoder to be trained.
- teacher_encoder (torch.nn.Module):
The teacher (target) encoder.
- data_augment (Optional[torch.nn.Module]):
The augmentation function to be used. Default is None.
- pred_dim (Optional[int]):
Output dimension of the predictor. Default is None.
References
Large-Scale Representation Learning on Graphs via Bootstrapping, Shantanu Thakoor et al. Available at https://arxiv.org/abs/2102.06514
AFGRL
Augmentation-Free Self-Supervised Learning on Graphs (AFGRL) is a framework that departs from conventional graph representation learning by removing the reliance on data augmentation. As the field of graph learning continues to evolve, the need for efficient and effective self-supervised methods has become increasingly evident. AFGRL addresses this need with a self-supervised approach built on the graph's intrinsic structure, avoiding the complexities that augmentations introduce.
Introduction
The central premise of AFGRL is to leverage the rich structural information present within graphs to learn meaningful representations without the necessity for external augmentation. Traditional self-supervised learning methods often depend on various augmentation strategies, which can introduce noise and complicate the learning process. AFGRL simplifies this paradigm by utilizing inherent properties of the graph to guide representation learning, making it a streamlined and effective solution.
Core Concepts
Self-Supervised Learning: AFGRL operates within a self-supervised framework, where the model learns to generate embeddings based on the graph's structure alone. This approach allows for efficient representation learning without labeled data or additional external transformations.
Intrinsic Graph Structures: By focusing on the inherent relationships between nodes and their connectivity, AFGRL captures essential features that characterize the graph. This enables the model to develop robust embeddings that reflect the true nature of the data.
Efficiency and Scalability: The absence of augmentation techniques not only simplifies the training process but also enhances efficiency, allowing AFGRL to scale effectively to large graphs. This is crucial for applications involving extensive datasets where computational resources may be limited.
Methodology
The AFGRL framework follows a systematic methodology:
Graph Representation Learning: The process begins with the model analyzing the graph structure to extract relevant features from the nodes and edges. The focus is on understanding local and global patterns without any augmentative interference.
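Self-Supervised Objectives: AFGRL uses a non-contrastive, bootstrapping objective with a student and a teacher encoder. For each node, positive samples are nodes that are among its k nearest neighbors in the embedding space and are also structurally related to it, for example adjacent in the graph or in the same cluster. The student is trained to make each node's representation similar to those of its positives, and no negative samples are used.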
Optimization and Training: The learning process is guided by an optimization algorithm that refines the embeddings iteratively. The absence of augmentation allows for a more straightforward training loop, reducing the complexity involved in adjusting to varying data conditions.
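The positive-sample discovery step can be sketched as follows. The sketch computes cosine similarities between node embeddings and takes each node's top-k most similar nodes (the role of the topk parameter). It then keeps only the candidates that are also adjacent in the original graph, which corresponds to adj_ori. This NumPy helper, find_positives, is hypothetical. It also omits the clustering-based global signal the AFGRL paper uses alongside adjacency.

```python
import numpy as np

def find_positives(h, adj, topk=2):
    """Boolean (N, N) mask: j is a positive for i if j is among i's top-k
    nearest neighbors by cosine similarity AND i and j are adjacent."""
    h = h / np.linalg.norm(h, axis=1, keepdims=True)
    sim = h @ h.T
    np.fill_diagonal(sim, -np.inf)               # a node is not its own neighbor
    knn = np.argsort(-sim, axis=1)[:, :topk]     # indices of the k most similar nodes
    knn_mask = np.zeros_like(adj, dtype=bool)
    rows = np.repeat(np.arange(h.shape[0]), topk)
    knn_mask[rows, knn.ravel()] = True
    return knn_mask & (adj > 0)

# Two groups of 3 nodes with similar embeddings; edge (2, 3) bridges the groups
h = np.array([[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.0, 1.0], [0.1, 0.9], [0.2, 0.8]])
adj = np.zeros((6, 6))
for i, j in [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]:
    adj[i, j] = adj[j, i] = 1.0
pos = find_positives(h, adj, topk=2)
```

In this toy graph the bridging edge (2, 3) connects nodes with dissimilar embeddings, so it is not selected as a positive pair. Every within-group edge is selected.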
API Reference in PyG-SSL
- class AFGRLEncoder(in_channel, hidden_channels, **kwargs)
A graph encoder that generates representations given a
torch_geometric.data.Dataset.
Parameters:
- in_channel (int):
Number of input features of the input dataset.
- hidden_channels (int):
Number of hidden channels for the encoder.
- kwargs (optional):
Additional arguments for
pygssl.methods.AFGRLEncoder.
- class AFGRL(student_encoder: torch.nn.Module, teacher_encoder: torch.nn.Module, data_augment=None, adj_ori=None, topk=8)
The Augmentation-Free Graph Representation Learning Algorithm.
Parameters:
- student_encoder (torch.nn.Module):
The student encoder to be trained.
- teacher_encoder (torch.nn.Module):
The teacher (target) encoder.
- data_augment (Optional[torch.nn.Module]):
The augmentation function to be used. Default is None.
- adj_ori (Optional[torch.Tensor]):
The original adjacency matrix. Default is None.
- topk (int):
Number of nearest neighbors (k) considered when searching for positive samples. Default is 8.
References
Augmentation-Free Self-Supervised Learning on Graphs, Namkyeong Lee et al. Available at https://arxiv.org/abs/2112.02472
ReGCL
Rethinking Message Passing in Graph Contrastive Learning (ReGCL) is a framework that aims to improve graph contrastive learning by re-examining the message-passing mechanism traditionally used in graph neural networks. As the demand for more capable graph representation learning methods grows, ReGCL offers a fresh perspective on how message passing can be adapted to the contrastive setting, ultimately leading to better node representations.
Introduction
The foundation of many graph neural network (GNN) architectures lies in the message-passing paradigm, which facilitates the aggregation of information from neighboring nodes to inform the representation of each node. While effective, standard message-passing approaches may overlook critical aspects of the graph structure and relationships, particularly in the context of contrastive learning. ReGCL addresses this limitation by rethinking the message-passing process, focusing on enhancing the flow of information and its relevance to the contrastive learning task.
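For reference, the standard message-passing step that ReGCL revisits can be written for a GCN layer as H' = σ(D̂^(-1/2) (A + I) D̂^(-1/2) H W). In words, each node sums degree-normalized messages from its neighbors and itself, after a shared linear map, and applies a nonlinearity. The NumPy sketch below implements this plain GCN propagation on an edge list. It shows only the baseline mechanism, not ReGCL's modified message passing, and the function name gcn_layer is hypothetical.

```python
import numpy as np

def gcn_layer(x, edge_index, weight):
    """One GCN message-passing layer: transform, send normalized messages, sum, ReLU."""
    n = x.shape[0]
    src, dst = edge_index
    # Add self-loops so each node also keeps its own features
    loops = np.arange(n)
    src = np.concatenate([src, loops])
    dst = np.concatenate([dst, loops])
    deg = np.bincount(dst, minlength=n).astype(float)
    norm = 1.0 / np.sqrt(deg[src] * deg[dst])      # symmetric normalization per edge
    messages = norm[:, None] * (x @ weight)[src]    # message sent from src to dst
    out = np.zeros((n, weight.shape[1]))
    np.add.at(out, dst, messages)                   # sum incoming messages per node
    return np.maximum(out, 0.0)                     # ReLU

# Undirected path 0-1-2, stored with both edge directions
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
x = np.eye(3)
w = np.eye(3)
h = gcn_layer(x, edge_index, w)
```

With identity features and weights, the output equals the normalized adjacency matrix itself, which makes the per-edge weights easy to inspect.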
Key Concepts
Enhanced Message Passing: ReGCL introduces a refined message-passing mechanism that optimizes how information is exchanged among nodes. This is achieved by considering not only the immediate neighbors but also the broader structural context, leading to richer and more meaningful node embeddings.
Contrastive Learning Integration: The framework effectively integrates the redefined message-passing strategy into the contrastive learning paradigm. By ensuring that messages reflect the intrinsic relationships between nodes, ReGCL promotes the learning of discriminative features that improve representation quality.
Scalability and Flexibility: Designed to be scalable, ReGCL can efficiently handle large graphs while remaining flexible enough to be adapted to various applications. This is crucial in real-world scenarios where graph data can be extensive and complex.
Methodology
ReGCL follows a systematic approach to improve graph contrastive learning:
Message-Passing Redefinition: The first step involves redefining the message-passing process to enhance the flow of information between nodes. This includes incorporating mechanisms that prioritize relevant structural features.
Self-Supervised Objectives: ReGCL employs self-supervised learning objectives that guide the model in distinguishing between similar and dissimilar node representations. This is achieved through contrastive losses that leverage the enhanced messages.
Training and Optimization: The training process involves optimizing the model using efficient algorithms that capitalize on the improved message-passing framework. This leads to refined embeddings that better capture the underlying graph structure.
API Reference in PyG-SSL
- class ReGCLEncoder(in_channels: int, out_channels: int, activation, mode: int = 1, base_model=GCNConv, k: int = 2, cutrate: float = 0.2, cutway: int = 1, tau: float = 0.5)
A ReGCL encoder that generates representations given a
torch_geometric.data.Dataset.
Parameters:
- in_channels (int):
Number of input features of the input dataset.
- out_channels (int):
Number of output features of the encoder.
- activation (torch.nn.Module):
Activation function for the encoder.
- mode (int, optional):
Mode for the encoder. Default is 1.
- base_model (torch.nn.Module, optional):
Base GNN model for the encoder. Default is GCNConv.
- k (int, optional):
Number of layers for the encoder. Default is 2.
- cutrate (float, optional):
Cut rate for the encoder. Default is 0.2.
- cutway (int, optional):
Cut way for the encoder. Default is 1.
- tau (float, optional):
Temperature parameter. Default is 0.5.
- class ReGCL(config, encoder: ReGCLEncoder, num_hidden: int, num_proj_hidden: int, mode: int = 1, tau: float = 0.5)
The ReGCL Algorithm.
Parameters:
- config (dict):
Configuration parameters for the encoder.
- encoder (ReGCLEncoder):
The encoder to be trained.
- num_hidden (int):
Output dimension of the encoder.
- num_proj_hidden (int):
Number of hidden channels for the projection head.
- mode (int, optional):
Mode for the encoder. Default is 1.
- tau (float):
Temperature parameter for the contrastive loss. Default is 0.5.
References
ReGCL: Rethinking Message Passing in Graph Contrastive Learning, Cheng Ji et al. Available at https://ojs.aaai.org/index.php/AAAI/article/view/28698