
This feature is currently in Preview.

Configure and Train a GNN

The GNN class accepts a number of optional hyperparameters that control training, the network architecture, and other behaviors. Each parameter has a sensible default, so you only need to set the ones relevant to your task. The sections below group them by role.

| Parameter | Type | Description |
| --- | --- | --- |
| device | str | Device to perform training, inference, and feature extraction. One of "cuda" or "cpu". Default: "cuda". If your predictive reasoner instance does not have a GPU, the system automatically falls back to "cpu", so this setting is safe to leave at its default. |
| seed | int | Random seed for reproducibility. Default: 42. |

| Parameter | Type | Description |
| --- | --- | --- |
| n_epochs | int | Number of training epochs. An epoch corresponds to a full pass over the training data. Default: 10. |
| max_iters | int | Maximum number of batch iterations per epoch. If None, all batches are processed; otherwise each epoch stops after max_iters batches. Default: None. |
| train_batch_size | int | Batch size for training. Default: 128. |
| val_batch_size | int | Batch size for validation. Default: 128. |
| eval_every | int | Frequency (in epochs) of evaluation on the validation set. Default: 1. |
| patience | int | Number of epochs without improvement before early stopping. Default: 5. |
| lr | float | Learning rate. Default: 0.001. |
| T_max | int | Maximum number of iterations for the cosine annealing scheduler. If None, n_epochs is used. Default: None. |
| eta_min | float | Minimum learning rate for cosine annealing. Default: 1e-8. |
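
To see how lr, T_max, and eta_min interact, the sketch below evaluates the standard cosine annealing formula, the closed form used by schedulers such as PyTorch's CosineAnnealingLR. It assumes the scheduler is stepped once per epoch (suggested by T_max defaulting to n_epochs) and that the GNN follows this formula; neither is confirmed here, and cosine_lr is a hypothetical helper, not part of the GNN API.

```python
import math

def cosine_lr(step, lr=0.001, t_max=10, eta_min=1e-8):
    """Learning rate after `step` scheduler steps under standard cosine annealing.

    Hypothetical helper for illustration; assumes 0 <= step <= t_max.
    """
    return eta_min + 0.5 * (lr - eta_min) * (1 + math.cos(math.pi * step / t_max))

# With the documented defaults (lr=0.001, T_max=n_epochs=10, eta_min=1e-8),
# the rate starts at lr and decays smoothly towards eta_min.
schedule = [cosine_lr(epoch) for epoch in range(11)]
for epoch, rate in enumerate(schedule):
    print(f"epoch {epoch:2d}: lr = {rate:.2e}")
```

Under these assumptions, a larger eta_min keeps the final epochs from training at a near-zero rate, while a T_max larger than n_epochs ends training before the schedule reaches its minimum.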

| Parameter | Type | Description |
| --- | --- | --- |
| label_smoothing | bool | Whether to apply label smoothing (for classification). Default: False. |
| label_smoothing_alpha | float | Smoothing parameter α ∈ (0, 1). Default: 0.1. |
| clamp_min | int | Specifies the lower bound of the model’s output distribution in percentile terms (0–100). A value of 0 means no lower percentile cutoff is applied, while higher values restrict predictions to exclude the lowest portion of the output distribution. Default: 0. |
| clamp_max | int | Specifies the upper bound of the model’s output distribution in percentile terms (0–100). A value of 100 means no upper percentile cutoff is applied, while lower values restrict predictions to exclude the highest portion of the output distribution. Default: 100. |
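
As a rough illustration of what label_smoothing_alpha controls, the sketch below applies the common uniform label smoothing formula, in which each one-hot target y over K classes is mixed with a uniform distribution: y_smooth = (1 - α) * y + α / K. This is the textbook formulation and is assumed here; the variant the GNN actually uses is not stated, and smooth_label is a hypothetical helper.

```python
def smooth_label(label, num_classes, alpha=0.1):
    """Return the smoothed target distribution for an integer class label.

    Hypothetical helper: textbook uniform label smoothing, assumed rather
    than taken from the GNN implementation.
    """
    if not 0 < alpha < 1:
        raise ValueError("alpha must lie in (0, 1)")
    uniform = alpha / num_classes  # probability mass spread over every class
    return [(1 - alpha) * (1.0 if k == label else 0.0) + uniform
            for k in range(num_classes)]

# Binary task, positive class, default alpha = 0.1:
# the hard target [0, 1] becomes roughly [0.05, 0.95].
print(smooth_label(1, num_classes=2))
```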

| Parameter | Type | Description |
| --- | --- | --- |
| channels | int | Hidden channels for GNN, encoders, and prediction heads. Default: 128. |
| gnn_layers | int | Number of GNN layers. Defaults to len(fanouts) if None. Default: None. |
| fanouts | List[int] | Neighbors to sample per GNN layer. E.g., [128, 64]. Default: [128, 64]. |
| conv_aggregation | str | Aggregation method for convolutions. It can be one of "mean", "max" or "sum". Default: "mean". |
| hetero_conv_aggregation | str | Aggregation across edge types in heterogeneous graphs. It can be one of "mean", "max" or "sum". Default: "sum". |
| gnn_norm | str | Normalization for GNN layers. It can be one of "batch_norm", "layer_norm" or "instance_norm". Default: "layer_norm". |
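
The cost of a fanouts setting grows multiplicatively with depth. The sketch below computes an upper bound on how many nodes a single seed node can pull into its sampled subgraph. It assumes that fanouts[i] caps the neighbors sampled per node at hop i + 1 and that the graph has a single edge type; in a heterogeneous graph the cap may apply per edge type, which would raise the bound. max_sampled_nodes is a hypothetical helper.

```python
def max_sampled_nodes(fanouts):
    """Upper bound on nodes in the sampled subgraph of a single seed node."""
    total = 1      # the seed node itself
    frontier = 1   # nodes whose neighbors are sampled at the current hop
    for fanout in fanouts:
        frontier *= fanout  # each frontier node adds at most `fanout` neighbors
        total += frontier
    return total

default_fanouts = [128, 64]
gnn_layers = len(default_fanouts)  # documented default when gnn_layers is None
print(gnn_layers, max_sampled_nodes(default_fanouts))  # 2 8321
```

Real subgraphs are usually smaller, because many nodes have fewer neighbors than the fanout, but the bound is a useful guide when weighing fanouts against train_batch_size.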

| Parameter | Type | Description |
| --- | --- | --- |
| head_layers | int | Number of Multi-Layer Perceptron (MLP) layers in the prediction head. Default: 1. |
| head_norm | str | Normalization for the MLP prediction head. It can be one of "batch_norm" or "layer_norm". Default: "batch_norm". |

| Parameter | Type | Description |
| --- | --- | --- |
| use_temporal_encoder | bool | Whether to use a temporal encoding model. Default: True. |
| temporal_strategy | str | Strategy for temporal neighbor sampling. "uniform" ignores time; "last" picks most recent. Default: "uniform". |
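
The difference between the two temporal_strategy values can be sketched on a toy neighbor list. The code below follows the descriptions in the table: "uniform" draws neighbors at random without looking at timestamps, and "last" keeps the most recent ones. The (node_id, timestamp) layout and the sample_neighbors helper are assumptions made for illustration, not the GNN's internal sampler.

```python
import random

def sample_neighbors(neighbors, k, strategy="uniform", seed=42):
    """Pick up to k neighbors from a list of (node_id, timestamp) pairs."""
    k = min(k, len(neighbors))
    if strategy == "uniform":
        # Ignore timestamps: every neighbor is equally likely.
        return random.Random(seed).sample(neighbors, k)
    if strategy == "last":
        # Keep the k neighbors with the latest timestamps.
        return sorted(neighbors, key=lambda n: n[1], reverse=True)[:k]
    raise ValueError(f"unknown strategy: {strategy!r}")

events = [("a", 1), ("b", 5), ("c", 3), ("d", 9), ("e", 7)]
print(sample_neighbors(events, 2, strategy="last"))  # [('d', 9), ('e', 7)]
print(sample_neighbors(events, 2, strategy="uniform"))
```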

Configure negative sampling for link prediction

| Parameter | Type | Description |
| --- | --- | --- |
| num_negative | int | Number of negative samples per source node (for link prediction). Default: 10. |
| negative_sampling_strategy | str | Strategy: "random" or "degree_based". "degree_based" favors popular nodes. Default: "random". |
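
A small sketch can show how the two strategies differ in which nodes they favor. It assumes that "degree_based" samples candidates with probability proportional to their degree; the exact weighting the GNN uses is not documented here, and a real sampler would likely also filter out true positive links. sample_negatives and the toy degrees mapping are hypothetical.

```python
import random

def sample_negatives(degrees, num_negative=10, strategy="random", seed=42):
    """Draw negative destination nodes (with replacement) for one source node.

    `degrees` maps each candidate node to its degree. Hypothetical helper.
    """
    rng = random.Random(seed)
    candidates = list(degrees)
    if strategy == "random":
        weights = None  # uniform over all candidates
    elif strategy == "degree_based":
        # Assumption: sampling probability is proportional to node degree.
        weights = [degrees[node] for node in candidates]
    else:
        raise ValueError(f"unknown strategy: {strategy!r}")
    return rng.choices(candidates, weights=weights, k=num_negative)

degrees = {"popular": 90, "average": 9, "niche": 1, "isolated": 0}
print(sample_negatives(degrees, strategy="random"))
print(sample_negatives(degrees, strategy="degree_based"))
```

Under this weighting, a node with degree 0 is never drawn by "degree_based", whereas "random" picks it as often as any other candidate.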

| Parameter | Type | Description |
| --- | --- | --- |
| text_embedder | str | Text embedding model. It can be one of "model2vec-potion-base-4M" or "bert-base-distill". Default: "model2vec-potion-base-4M". |
| id_awareness | bool | Whether to use ID-awareness embeddings. Default: False. |
| shallow_embeddings_list | List[str] | Tables to assign learnable shallow embeddings. Default: []. |

Hyperparameters are passed as keyword arguments when constructing the GNN. A common pattern is to collect them in a dictionary and unpack it with **, which keeps the tuning knobs separate from the structural arguments and makes the configuration easy to tweak between runs:

# Hyperparameters only; anything omitted keeps its documented default.
train_config = {
    "device": "cuda",
    "n_epochs": 10,
    "train_batch_size": 256,
    "lr": 0.001,
    "head_layers": 2,
    "label_smoothing": True,
    "patience": 5,
}

# Structural arguments are passed explicitly; the hyperparameters are unpacked last.
gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="MODEL_REGISTRY",
    graph=gnn_graph,
    property_transformer=pt,
    train=Train,
    validation=Validation,
    task_type="binary_classification",
    eval_metric="roc_auc",
    **train_config,
)
gnn.fit()

You can also pass each hyperparameter directly as a keyword argument; the dictionary is just a convenience.
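
One way to make tweaking between runs concrete is to keep a base dictionary and layer per-run overrides on top of it before unpacking. The sketch below does this with plain dictionary merging and rejects misspelled keys early. build_config and the KNOWN_HYPERPARAMETERS set (a subset of the names documented in the tables above) are hypothetical helpers, not part of the GNN API.

```python
# Subset of the documented hyperparameter names; extend as needed.
KNOWN_HYPERPARAMETERS = {
    "device", "seed", "n_epochs", "train_batch_size", "val_batch_size",
    "lr", "patience", "head_layers", "label_smoothing", "fanouts",
}

def build_config(base, **overrides):
    """Return a new config where overrides win over base; reject unknown keys."""
    config = {**base, **overrides}  # later entries replace earlier ones
    unknown = sorted(set(config) - KNOWN_HYPERPARAMETERS)
    if unknown:
        raise ValueError(f"unknown hyperparameters: {unknown}")
    return config

base_config = {"device": "cuda", "n_epochs": 10, "lr": 0.001, "patience": 5}
run_a = build_config(base_config, train_batch_size=256)
run_b = build_config(base_config, lr=0.0003, head_layers=2)
print(run_b)
# Each result can then be unpacked into the constructor with **run_a or **run_b.
```

Because build_config returns a new dictionary, base_config stays unchanged and can be reused across experiments.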

Once the GNN is configured, call .fit() to start training:

gnn.fit()