This feature is currently in Preview.

Make predictions

After a GNN has been trained or loaded, calling .predictions(domain=Test) generates predictions on the Test domain. The call returns a relationship that you assign to an attribute on the source concept, that is, the concept whose instances receive predictions. The conventional attribute name is predictions (as in Customer.predictions = ...), but any name works:

Customer.predictions = gnn.predictions(domain=Test)

Once attached, the relationship is queryable like any other property on the concept. The set of attributes available on Customer.predictions depends on the task_type the model was trained for:

Task type                                       | Available attributes
------------------------------------------------|------------------------------------
Classification (binary, multiclass, multilabel) | .probs, .predicted_labels
Regression                                      | .predicted_value
Link prediction (link, repeated link)           | .rank, .scores, .predicted_<target>
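
The table can also be read as a lookup from task type to attribute names. The following is a minimal sketch in plain Python, independent of the library. The dictionary keys are illustrative labels rather than confirmed values of task_type, and the lowercasing rule in link_target_attribute is an assumption inferred from the single Product example given for link prediction.

```python
# Illustrative lookup only: the keys are hypothetical task-type labels,
# the values are the attribute names listed in the table.
PREDICTION_ATTRIBUTES = {
    "classification": ["probs", "predicted_labels"],
    "regression": ["predicted_value"],
    "link_prediction": ["rank", "scores", "predicted_<target>"],
}


def link_target_attribute(target_concept_name: str) -> str:
    """Build the predicted_<target> attribute name from a target concept name.

    Assumption: the concept name is simply lowercased (Product -> product).
    """
    return f"predicted_{target_concept_name.lower()}"


print(PREDICTION_ATTRIBUTES["regression"])  # ['predicted_value']
print(link_target_attribute("Product"))     # predicted_product
```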

In every case, adding .where(Customer.predictions) restricts the output to the entities that actually have a prediction attached, typically those in the Test domain. For a simple non-temporal classification task, the query is:

Customer.predictions = gnn.predictions(domain=Test)
select(
    Customer.customer_id,
    Customer.predictions.predicted_labels,
    Customer.predictions.probs,
).where(Customer.predictions).inspect()

Inspect classification predictions with time

For classification tasks, each predicted entity exposes the per-class probability distribution and the argmax label:

Customer.predictions = gnn.predictions(domain=Test)
select(
    Customer.customer_id,
    Customer.predictions.predicted_labels,
    Customer.predictions.probs,
    Customer.predictions["date"],
).where(Customer.predictions).inspect()

.probs is the model’s confidence over each class (a single float for binary classification, an array for multiclass and multilabel). .predicted_labels is the discrete class label the model commits to. If your task is time-sensitive (for example, predicting whether a customer will churn in the next month), you also need the time of each prediction, which is available as Customer.predictions["date"], as shown in the query above.
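
To make the relationship between .probs and .predicted_labels concrete, here is a minimal sketch in plain Python. It is not the library's implementation. It assumes that a binary task reports the probability of the positive class and thresholds it at 0.5, and that a multiclass task takes the index of the largest probability. Multilabel tasks, where several classes can be active at once, are not covered.

```python
from typing import Sequence, Union


def label_from_probs(probs: Union[float, Sequence[float]], threshold: float = 0.5) -> int:
    """Derive a discrete label from a probability output.

    Assumptions (not confirmed by the library): a single float is the
    probability of the positive class and is compared to `threshold`;
    a sequence holds one probability per class and the argmax is taken.
    """
    if isinstance(probs, (int, float)):
        return int(probs >= threshold)
    # Index of the largest probability (the first one wins on ties).
    return max(range(len(probs)), key=lambda i: probs[i])


print(label_from_probs(0.83))             # 1
print(label_from_probs([0.1, 0.7, 0.2]))  # 1
print(label_from_probs([0.6, 0.3, 0.1]))  # 0
```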

Regression tasks expose a single continuous value per entity:

Customer.predictions = gnn.predictions(domain=Test)
select(
    Customer.customer_id,
    Customer.predictions.predicted_value,
).where(Customer.predictions).inspect()

For link prediction and repeated link prediction, each source entity is paired with its top-k candidate destinations. The prediction relationship exposes:

  • .rank — the position of the candidate in the ranked list (1, 2, 3, …).
  • .scores — the relevance score the model assigned.
  • .predicted_<target> — a reference to the predicted target concept instance. The attribute name is derived from the target concept; for example, if the target is Product, the attribute is predicted_product.
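
The relationship between .rank and .scores can be sketched in plain Python. This is an illustration rather than the library's ranking code. It assumes that a higher score means a more relevant candidate, so rank 1 goes to the largest score. The product identifiers and scores are made up.

```python
def top_k(candidate_scores: dict[str, float], k: int) -> list[tuple[int, str, float]]:
    """Return (rank, candidate, score) for the k highest-scoring candidates.

    Assumption: higher score = more relevant, so rank 1 has the largest score.
    """
    ordered = sorted(candidate_scores.items(), key=lambda item: item[1], reverse=True)
    return [(rank, name, score) for rank, (name, score) in enumerate(ordered[:k], start=1)]


# Hypothetical scores for one customer's candidate products.
scores = {"P1": 0.12, "P2": 0.91, "P3": 0.47, "P4": 0.05}
ranked = top_k(scores, k=3)
for row in ranked:
    print(row)
# (1, 'P2', 0.91)
# (2, 'P3', 0.47)
# (3, 'P1', 0.12)
```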

To retrieve attributes of the predicted target, join the predicted_<target> reference to the target concept in .where():

Customer.predictions = gnn.predictions(domain=Test)
select(
    Customer.customer_id,
    Product.product_id,
    Customer.predictions.rank,
    Customer.predictions.scores,
).where(
    Customer.predictions.predicted_product == Product,
).inspect()

The .where(Customer.predictions.predicted_product == Product) clause both filters to entities with predictions and binds each predicted reference to a concrete Product instance, so its columns are available in the select.
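
The filter-and-bind behavior described here resembles an inner join. The sketch below imitates it with plain Python dictionaries and lists. It does not use the library, and every identifier and value in it is invented for illustration.

```python
# Stand-in data; all identifiers and values are made up.
products = {
    "P1": {"product_id": "P1", "name": "Kettle"},
    "P2": {"product_id": "P2", "name": "Toaster"},
}
customers = ["C1", "C2", "C3"]
# Only C1 and C2 have predictions; C3 has none attached.
predictions = [
    {"customer_id": "C1", "predicted_product": "P2", "rank": 1, "scores": 0.91},
    {"customer_id": "C1", "predicted_product": "P1", "rank": 2, "scores": 0.40},
    {"customer_id": "C2", "predicted_product": "P1", "rank": 1, "scores": 0.77},
]

# Analogue of the where-clause: keep a row only when its predicted
# reference resolves to a product, then read that product's columns.
rows = [
    (p["customer_id"], products[p["predicted_product"]]["product_id"], p["rank"], p["scores"])
    for p in predictions
    if p["predicted_product"] in products
]

# Customers with no prediction never appear in the joined output.
customers_in_output = {row[0] for row in rows}
without_predictions = [c for c in customers if c not in customers_in_output]

for row in rows:
    print(row)
print(without_predictions)  # ['C3']
```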

Replace .inspect() with .to_df() to materialize the query result as a pandas DataFrame.

df = (
    select(
        Customer.customer_id,
        Customer.predictions.probs,
        Customer.predictions.predicted_labels,
    )
    .where(Customer.predictions)
    .to_df()
)

The same swap works for any of the queries above.
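
Once the result is a DataFrame, ordinary pandas operations apply. The sketch below builds a stand-in DataFrame by hand, so it runs without a trained model. The column names and values are assumptions for this example, and the actual columns returned by .to_df() may be named differently.

```python
import pandas as pd

# Stand-in for the result of .to_df(); column names and values are
# assumptions made for this example, not the library's guaranteed output.
df = pd.DataFrame(
    {
        "customer_id": [101, 102, 103, 104],
        "probs": [0.78, 0.35, 0.92, 0.10],
        "predicted_labels": [1, 0, 1, 0],
    }
)

# Keep customers predicted positive, most confident first.
at_risk = (
    df[df["predicted_labels"] == 1]
    .sort_values("probs", ascending=False)
    .reset_index(drop=True)
)
print(at_risk)
```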