Make predictions
After a GNN has been trained or loaded, calling `.predictions(domain=Test)` generates predictions for the test domain. The call returns a relationship, which you assign to an attribute on the source concept, that is, the concept whose instances the model makes predictions for. The conventional attribute name is `predictions` (as in `Customer.predictions = ...`), but you can choose any name:
```python
Customer.predictions = gnn.predictions(domain=Test)
```

Once attached, the relationship is queryable like any other property on the concept. The set of attributes available on `Customer.predictions` depends on the `task_type` the model was trained for:
| Task type | Available attributes |
|---|---|
| Classification (binary, multiclass, multilabel) | `.probs`, `.predicted_labels` |
| Regression | `.predicted_value` |
| Link prediction (link, repeated link) | `.rank`, `.scores`, `.predicted_<target>` |
In every case, filtering with `.where(Customer.predictions)` restricts the output to entities that actually have a prediction attached, typically the rows in the `Test` domain. For a simple non-temporal classification task, the query is:
```python
Customer.predictions = gnn.predictions(domain=Test)

select(
    Customer.customer_id,
    Customer.predictions.predicted_labels,
    Customer.predictions.probs,
).where(Customer.predictions).inspect()
```

Inspect classification predictions with time
For classification tasks, each predicted entity exposes the per-class probability distribution and the argmax label:
```python
Customer.predictions = gnn.predictions(domain=Test)

select(
    Customer.customer_id,
    Customer.predictions.predicted_labels,
    Customer.predictions.probs,
    Customer.predictions["date"],
).where(Customer.predictions).inspect()
```

`.probs` is the model's confidence for each class: a single float for binary classification, and an array for multiclass and multilabel. `.predicted_labels` is the discrete class label the model commits to. If your task is time-sensitive (for example, predicting whether a customer will churn in the next month), you also need the time of each prediction, which is accessed via `Customer.predictions["date"]` as in the query above.
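The relationship between `.probs` and `.predicted_labels` can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: it treats labels as class indices, takes the argmax for multiclass, and assumes a 0.5 threshold for binary and multilabel tasks. The helper `label_from_probs` and its `threshold` parameter are hypothetical.

```python
from typing import List, Sequence, Union


def label_from_probs(
    probs: Union[float, Sequence[float]],
    task: str,
    threshold: float = 0.5,  # assumed cutoff; the library's actual rule may differ
) -> Union[int, List[int]]:
    """Hypothetical helper: derive a discrete label from class probabilities."""
    if task == "binary":
        # A single float: the probability of the positive class.
        return int(probs >= threshold)
    if task == "multiclass":
        # Argmax: the index of the most probable class.
        return max(range(len(probs)), key=lambda i: probs[i])
    if task == "multilabel":
        # Every class whose probability clears the threshold.
        return [i for i, p in enumerate(probs) if p >= threshold]
    raise ValueError(f"unknown task type: {task}")


print(label_from_probs(0.73, "binary"))                 # 1
print(label_from_probs([0.1, 0.7, 0.2], "multiclass"))  # 1
print(label_from_probs([0.9, 0.3, 0.6], "multilabel"))  # [0, 2]
```

The key point is that `.probs` carries the full distribution while `.predicted_labels` is a lossy summary of it, so keep `.probs` if you need to apply your own cutoff downstream.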
Inspect regression predictions
Regression tasks expose a single continuous value per entity:
```python
Customer.predictions = gnn.predictions(domain=Test)

select(
    Customer.customer_id,
    Customer.predictions.predicted_value,
).where(Customer.predictions).inspect()
```

Inspect link-prediction predictions
For link prediction and repeated link prediction, each source entity is paired with its top-k candidate destinations. The prediction relationship exposes:
- `.rank`: the position of the candidate in the ranked list (1, 2, 3, …).
- `.scores`: the relevance score the model assigned.
- `.predicted_<target>`: a reference to the predicted target concept instance. The attribute name is derived from the target concept; for example, if the target is `Product`, the attribute is `predicted_product`.
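Assuming that a higher score means a more relevant candidate and that rank 1 goes to the highest score (neither is stated above), the relationship between `.scores` and `.rank` can be sketched in plain Python. The `top_k` helper, the candidate IDs, and the score values are hypothetical.

```python
from typing import Dict, List, Tuple


def top_k(scores: Dict[str, float], k: int) -> List[Tuple[int, str, float]]:
    """Hypothetical helper: (rank, candidate, score) for the k best candidates."""
    # Sort candidates by score, highest first (assumed ordering).
    ordered = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    # Ranks start at 1 for the top-scoring candidate.
    return [(rank, cand, score) for rank, (cand, score) in enumerate(ordered[:k], start=1)]


candidate_scores = {"P1": 0.12, "P2": 0.87, "P3": 0.45, "P4": 0.66}
for rank, product, score in top_k(candidate_scores, k=3):
    print(rank, product, score)
# 1 P2 0.87
# 2 P4 0.66
# 3 P3 0.45
```

Under this reading, `.rank` is derived from `.scores` within each source entity, so ranks are comparable only among one customer's candidates, not across customers.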
To retrieve attributes of the predicted target, join the `predicted_<target>` reference to the target concept in `.where()`:
```python
Customer.predictions = gnn.predictions(domain=Test)

select(
    Customer.customer_id,
    Product.product_id,
    Customer.predictions.rank,
    Customer.predictions.scores,
).where(
    Customer.predictions.predicted_product == Product,
).inspect()
```

The `.where(Customer.predictions.predicted_product == Product)` clause both filters to entities with predictions and binds each predicted reference to a concrete `Product` instance, so that instance's columns are available in the `select`.
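As an analogy only (this is not how the query engine evaluates it), the clause behaves like an inner join between prediction rows and `Product` instances. The plain-Python sketch below uses hypothetical customers, products, and a hypothetical `name` attribute. Customer `C3`, which has no prediction, produces no rows, mirroring the filtering effect of the `.where()` clause.

```python
# Hypothetical data: three customers, but predictions for only two of them.
customers = ["C1", "C2", "C3"]
predictions = [
    {"customer_id": "C1", "predicted_product": "P2", "rank": 1, "scores": 0.87},
    {"customer_id": "C1", "predicted_product": "P4", "rank": 2, "scores": 0.66},
    {"customer_id": "C2", "predicted_product": "P4", "rank": 1, "scores": 0.51},
]
# Hypothetical Product instances keyed by product_id.
products = {
    "P2": {"product_id": "P2", "name": "Kettle"},
    "P4": {"product_id": "P4", "name": "Toaster"},
}

rows = []
for customer_id in customers:
    for pred in predictions:
        # Only customers with at least one prediction produce rows, so C3 drops out.
        if pred["customer_id"] != customer_id:
            continue
        # Bind the predicted reference to a concrete Product instance; any of its
        # columns (product_id here, or the hypothetical name) could be selected.
        product = products[pred["predicted_product"]]
        rows.append((customer_id, product["product_id"], pred["rank"], pred["scores"]))

for row in rows:
    print(row)
```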
Get predictions as a pandas DataFrame
Replace `.inspect()` with `.to_df()` to materialize the query result as a pandas DataFrame:
```python
df = (
    select(
        Customer.customer_id,
        Customer.predictions.probs,
        Customer.predictions.predicted_labels,
    )
    .where(Customer.predictions)
    .to_df()
)
```

The same swap works for any of the queries above.
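Once the result is a DataFrame, ordinary pandas operations apply. The sketch below assumes a time-aware binary churn query like the one in the classification section, where a customer can have one prediction per date. The column names, dates, probabilities, and the 0.5 cutoff are all hypothetical stand-ins for what `.to_df()` would return. It keeps each customer's most recent prediction and then lists the customers whose predicted probability exceeds the cutoff.

```python
import pandas as pd

# Hypothetical stand-in for the DataFrame returned by .to_df() on a
# time-aware binary classification query; real column names may differ.
df = pd.DataFrame(
    {
        "customer_id": [1, 1, 2, 2, 3],
        "date": pd.to_datetime(
            ["2024-01-01", "2024-02-01", "2024-01-01", "2024-02-01", "2024-02-01"]
        ),
        "probs": [0.20, 0.85, 0.40, 0.10, 0.65],
        "predicted_labels": [0, 1, 0, 0, 1],
    }
)

# Keep only the most recent prediction for each customer.
latest = (
    df.sort_values("date")
    .groupby("customer_id")
    .tail(1)
    .sort_values("customer_id")
    .reset_index(drop=True)
)

# Customers whose latest predicted probability exceeds the assumed 0.5 cutoff.
at_risk = latest.loc[latest["probs"] > 0.5, "customer_id"].tolist()
print(at_risk)  # [1, 3]
```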