
This feature is currently in Preview.

Use a pretrained model

When a GNN has been trained in an earlier workflow, you can skip training and go straight to prediction by loading the previously trained model. This is useful when you want to generate predictions on new test data without retraining.

There are two ways to identify a model to load:

  • By model run ID: model_run_id is the job ID of the training run that produced the model. It’s printed in the training logs and shown on the run in Snowflake Experiment Tracking (as the suffix in RUN_<job_id>, or under model_run_id in the run details). See Monitor training for more.
  • By registry name: for registered models, the database, schema, name, and version under which the model was registered. See Register a model for how to register one.
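If you copy the run name from Experiment Tracking rather than the bare ID, a small helper can strip the prefix. This is a minimal sketch: run_id_from_run_name is a hypothetical helper, not part of the GNN API, and it assumes the run name has exactly the RUN_<job_id> form described above.

```python
def run_id_from_run_name(run_name: str) -> str:
    """Return the job ID suffix of a run name of the form RUN_<job_id>.

    Hypothetical helper for illustration; not part of the GNN API.
    """
    prefix = "RUN_"
    # Reject names without the prefix, and the bare prefix with no ID after it.
    if not run_name.startswith(prefix) or len(run_name) == len(prefix):
        raise ValueError(f"expected a name like RUN_<job_id>, got {run_name!r}")
    return run_name[len(prefix):]
```

The returned string is what you would pass as model_run_id.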

A GNN constructed with either of these is in load mode. In this mode, you cannot call .fit() — call .load() instead, then .predictions() on the test domain.

At its core, the flow is just three calls:

# Configure: pass the graph, source concept, and a model identifier.
gnn = GNN(source_concept=SourceConcept, task_type=task_type, ...)
# Fetch the model.
gnn.load()
# Predict on the test domain.
SourceConcept.predictions = gnn.predictions(domain=Test)

The sections below show full examples — one per way of identifying the model.

In load mode, the GNN doesn’t need labeled training or validation splits — those have already been consumed to produce the trained model. But it still needs the structural setup so it can extract features and traverse the graph at prediction time. Pass the following arguments to the constructor:

  • exp_database and exp_schema: where the experiment artifacts for the trained model live. They should match the database and schema where the training Experiment was created. For a registered model, these may differ from model_database and model_schema.
  • source_concept: the concept whose instances will receive predictions. In the fit workflow this is inferred from train; in load mode you specify it explicitly.
  • target_concept: for link-prediction tasks only, the concept whose instances are candidate destinations.
  • task_type and has_time_column: must match the task the model was trained on.
  • graph: the graph on which predictions are computed. If omitted, the graph used to train the model is used.

Then provide either model_run_id or the four registry parameters (model_database, model_schema, model_name, version_name).
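The either/or rule above can be checked before the constructor is called. The following is a sketch only: model_identifier_kwargs is a hypothetical helper, and treating "both supplied" as an error is an assumption based on the wording above; the library's own validation may behave differently.

```python
REGISTRY_KEYS = ("model_database", "model_schema", "model_name", "version_name")


def model_identifier_kwargs(model_run_id=None, **registry):
    """Return the model-identifier keyword arguments for a load-mode GNN.

    Hypothetical helper: accepts either model_run_id or all four registry
    parameters, and rejects a mix of the two (an assumption, see lead-in).
    """
    unknown = set(registry) - set(REGISTRY_KEYS)
    if unknown:
        raise TypeError(f"unexpected arguments: {sorted(unknown)}")
    # Ignore registry parameters explicitly passed as None.
    given = {k: v for k, v in registry.items() if v is not None}
    if model_run_id is not None:
        if given:
            raise ValueError("pass model_run_id or the registry parameters, not both")
        return {"model_run_id": model_run_id}
    missing = [k for k in REGISTRY_KEYS if k not in given]
    if missing:
        raise ValueError(f"missing registry parameters: {missing}")
    return given
```

The returned dictionary can be unpacked into the constructor alongside the other arguments, for example GNN(..., **model_identifier_kwargs(model_run_id="<model_run_id>")).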

To load a model by its run ID, pass the model_run_id argument. Then call .load() to fetch the model and .predictions() to score the test domain:

gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="EXPERIMENTS_SC",
    source_concept=Customer,
    task_type="binary_classification",
    has_time_column=True,
    model_run_id="<model_run_id>",
)
gnn.load()
Customer.predictions = gnn.predictions(domain=Test)

The Test relationship is the same kind of split fragment you’d build for the fit workflow — see Define a learning task for how to construct it.

If the model was registered with gnn.register_model(), identify it by the database, schema, name, and version it was registered under. See Register a model for more on the registration parameters. These replace model_run_id in the constructor:

gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="EXPERIMENTS_SC",
    source_concept=Customer,
    task_type="binary_classification",
    has_time_column=True,
    model_database="PRODUCTION_DB",
    model_schema="MODEL_REGISTRY",
    model_name="my_dataset_model",
    version_name="v1",
)
gnn.load()
Customer.predictions = gnn.predictions(domain=Test)

The four registry arguments are the same ones you passed to register_model() at registration time.

For link_prediction or repeated_link_prediction tasks, the model also needs to know which concept produces destination candidates. Pass it via target_concept:

gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="MODEL_REGISTRY",
    source_concept=Customer,
    target_concept=Product,
    task_type="repeated_link_prediction",
    has_time_column=True,
    model_run_id="<model_run_id>",
)
gnn.load()
Customer.predictions = gnn.predictions(domain=Test)
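Because target_concept applies only to the two link-prediction task types, a pre-flight check can catch a mismatch before the model is loaded. This is a sketch under assumptions: check_target_concept is a hypothetical helper, and raising an error when target_concept is passed for a non-link task is a reading of "only for link-prediction tasks", not confirmed library behavior.

```python
# Task types named in this page as requiring a target concept.
LINK_TASKS = {"link_prediction", "repeated_link_prediction"}


def check_target_concept(task_type, target_concept=None):
    """Raise ValueError if target_concept does not fit the task type.

    Hypothetical helper for illustration; not part of the GNN API.
    """
    needs_target = task_type in LINK_TASKS
    if needs_target and target_concept is None:
        raise ValueError(f"{task_type} requires target_concept")
    if not needs_target and target_concept is not None:
        # Assumption: target_concept is rejected for non-link tasks.
        raise ValueError(f"target_concept does not apply to {task_type}")
```

Calling it with the same task_type and target_concept you are about to pass to the constructor returns None when the combination is consistent.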

The prediction relationship attached to the source concept exposes different attributes depending on the task type — for link prediction tasks it carries .rank, .scores, and a .predicted_<target> reference. See Make predictions for how to query them.

For the full list of valid and invalid GNN method-call sequences, see Understand GNN workflows.