Use a pretrained model
When a GNN has been trained in an earlier workflow, you can skip training and go straight to prediction by loading the previously trained model. This is useful when you want to generate predictions on new test data without retraining.
There are two ways to identify a model to load:
- By model run id: model_run_id is the job ID of the training run that produced the model. It's printed in the training logs and shown for the run in Snowflake Experiment Tracking (as the suffix in RUN_<job_id>, or under model_run_id in the run details). See Monitor training for more.
- By registry name: for registered models, the database, schema, name, and version under which the model was registered. See Register a model for how to register one.
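If you copy the run name from Experiment Tracking instead of the training logs, the model_run_id is the part after the RUN_ prefix. The helper below is a minimal sketch of that step. It is hypothetical (not part of the GNN API) and assumes the run name has exactly the form RUN_<job_id>; the job ID shown is a made-up placeholder.

```python
RUN_PREFIX = "RUN_"


def run_id_from_run_name(run_name: str) -> str:
    """Return the <job_id> suffix of an Experiment Tracking run name 'RUN_<job_id>'."""
    if not run_name.startswith(RUN_PREFIX) or len(run_name) == len(RUN_PREFIX):
        raise ValueError(f"expected a name of the form 'RUN_<job_id>', got {run_name!r}")
    # Strip only the leading prefix; the rest is the job ID.
    return run_name[len(RUN_PREFIX):]


print(run_id_from_run_name("RUN_01b2c3d4"))  # prints 01b2c3d4
```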
A GNN constructed with either identifier is in load mode. In this mode you cannot call .fit(); call .load() instead, then call .predictions() on the test domain.
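The call order can be pictured as a tiny state machine. The class below is a hypothetical toy, not the GNN class or its implementation. It models a rejected .fit() as a raised error and assumes that predicting before loading is also rejected; the real library's behavior and error types may differ.

```python
class LoadModeSketch:
    """Toy model of load mode: fit() is unavailable, load() must precede predictions()."""

    def __init__(self) -> None:
        self.loaded = False

    def fit(self) -> None:
        # In load mode the model was already trained elsewhere.
        raise RuntimeError("fit() is not available in load mode; call load() instead")

    def load(self) -> None:
        # Stands in for fetching the trained model.
        self.loaded = True

    def predictions(self, domain: str) -> str:
        if not self.loaded:
            raise RuntimeError("call load() before predictions()")
        return f"predictions for {domain}"


gnn = LoadModeSketch()
gnn.load()
print(gnn.predictions(domain="Test"))  # prints predictions for Test
```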
At its core, the flow is just three calls:
# Configure: pass the graph, source concept, and a model identifier.
gnn = GNN(source_concept=SourceConcept, task_type=task_type, ...)

# Fetch the model.
gnn.load()

# Predict on the test domain.
SourceConcept.predictions = gnn.predictions(domain=Test)

The sections below show full examples, one per way of identifying the model.
Configure a GNN in load mode
In load mode, the GNN doesn't need labeled training or validation splits, because those were already consumed to produce the trained model. It still needs the structural setup so that it can extract features and traverse the graph at prediction time. Pass the following arguments to the constructor:
- exp_database and exp_schema: where the experiment artifacts for the trained model live. They should match the database and schema where the training Experiment was created. For a registered model, they may differ from model_database and model_schema.
- source_concept: the concept whose instances will receive predictions. In the fit workflow this is inferred from train; in load mode you specify it explicitly.
- target_concept: only for link-prediction tasks; the concept whose instances are candidate destinations.
- task_type and has_time_column: must match the task the model was trained on.
- graph: the graph used to compute predictions. If omitted, it defaults to the graph used to train the model.
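The rules in this list can be summarized as a small configuration check. This is a minimal sketch under stated assumptions: LoadModeConfig and prediction_graph are hypothetical names, concept objects are replaced by plain strings, and the sketch enforces only what the list says (link-prediction tasks need a target_concept; a missing graph falls back to the training graph).

```python
from dataclasses import dataclass
from typing import Optional

# Task types that, per this page, need a target_concept.
LINK_TASKS = {"link_prediction", "repeated_link_prediction"}


@dataclass
class LoadModeConfig:
    """Hypothetical stand-in for the structural arguments of a load-mode GNN."""

    exp_database: str
    exp_schema: str
    source_concept: str          # a concept name stands in for the concept object
    task_type: str
    has_time_column: bool
    target_concept: Optional[str] = None
    graph: Optional[str] = None  # None means "reuse the training graph"

    def __post_init__(self) -> None:
        # Link-prediction tasks need to know where destination candidates come from.
        if self.task_type in LINK_TASKS and self.target_concept is None:
            raise ValueError(f"task_type={self.task_type!r} requires target_concept")

    def prediction_graph(self, training_graph: str) -> str:
        """Return the graph to predict on, falling back to the training graph."""
        return self.graph if self.graph is not None else training_graph


cfg = LoadModeConfig("EXPERIMENTS_DB", "EXPERIMENTS_SC", "Customer",
                     "binary_classification", has_time_column=True)
print(cfg.prediction_graph(training_graph="training_graph"))  # prints training_graph
```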
Then provide either model_run_id or the four registry parameters (model_database, model_schema, model_name, version_name).
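One way to picture this "one or the other" rule is a resolver that accepts either a run ID or all four registry parameters. The function below is a hypothetical sketch, not the library's validation logic. Rejecting a mix of both is an assumption drawn from the "either ... or" wording above.

```python
from typing import Dict, Optional

REGISTRY_KEYS = ("model_database", "model_schema", "model_name", "version_name")


def resolve_model_identifier(
    model_run_id: Optional[str] = None, **registry: Optional[str]
) -> Dict[str, str]:
    """Return the identifying arguments, or raise if they are ambiguous or incomplete."""
    unknown = set(registry) - set(REGISTRY_KEYS)
    if unknown:
        raise TypeError(f"unexpected arguments: {sorted(unknown)}")
    given = {key: value for key, value in registry.items() if value is not None}

    if model_run_id is not None:
        if given:
            # Assumption: mixing both ways of identifying a model is rejected.
            raise ValueError("pass model_run_id or the registry parameters, not both")
        return {"model_run_id": model_run_id}

    if not given:
        raise ValueError("pass model_run_id or all four registry parameters")
    missing = [key for key in REGISTRY_KEYS if key not in given]
    if missing:
        raise ValueError(f"missing registry parameters: {missing}")
    return {key: given[key] for key in REGISTRY_KEYS}


print(resolve_model_identifier(model_run_id="<model_run_id>"))
```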
Load a model by run id
Use the model_run_id argument. Then call .load() to fetch the model and .predictions() to score the test domain:
gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="EXPERIMENTS_SC",
    source_concept=Customer,
    task_type="binary_classification",
    has_time_column=True,
    model_run_id="<model_run_id>",
)
gnn.load()
Customer.predictions = gnn.predictions(domain=Test)

The Test relationship is the same kind of split fragment you'd build for the fit workflow. See Define a learning task for how to construct it.
Load a model from the registry
If the model was registered with gnn.register_model(), identify it by the database, schema, name, and version it was registered under (see Register a model for more on the registration parameters). These replace model_run_id in the constructor:
gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="EXPERIMENTS_SC",
    source_concept=Customer,
    task_type="binary_classification",
    has_time_column=True,
    model_database="PRODUCTION_DB",
    model_schema="MODEL_REGISTRY",
    model_name="my_dataset_model",
    version_name="v1",
)
gnn.load()
Customer.predictions = gnn.predictions(domain=Test)

The four registry arguments are the same ones you passed to register_model() at registration time.
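Because the same four values appear at registration time and at load time, it can help to keep them in one place. The frozen dataclass below is a hypothetical pattern, not part of the GNN API. It assumes that register_model() and the GNN constructor both accept these four keyword arguments, which the paragraph above implies but this sketch does not verify.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class RegistryRef:
    """Hypothetical holder for the four registry coordinates of one model version."""

    model_database: str
    model_schema: str
    model_name: str
    version_name: str

    def as_kwargs(self) -> dict:
        # Field names match the keyword arguments used on this page.
        return asdict(self)


PROD_MODEL = RegistryRef("PRODUCTION_DB", "MODEL_REGISTRY", "my_dataset_model", "v1")
print(PROD_MODEL.as_kwargs())

# Assumed usage (not executed here):
#   gnn.register_model(**PROD_MODEL.as_kwargs())
#   gnn = GNN(..., **PROD_MODEL.as_kwargs())
```

Freezing the dataclass keeps a shared reference from being mutated between the registration and load steps.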
Predict with a link-prediction model
For link_prediction or repeated_link_prediction tasks, the model also needs to know which concept produces destination candidates. Pass it via target_concept:
gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="EXPERIMENTS_SC",
    source_concept=Customer,
    target_concept=Product,
    task_type="repeated_link_prediction",
    has_time_column=True,
    model_run_id="<model_run_id>",
)
gnn.load()
Customer.predictions = gnn.predictions(domain=Test)

The prediction relationship attached to the source concept exposes different attributes depending on the task type. For link-prediction tasks it carries .rank, .scores, and a .predicted_<target> reference. See Make predictions for how to query them.
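To build intuition for how a rank relates to scores, here is a toy, self-contained ranking of candidate targets for a single source entity. It is not how the library computes .rank. It assumes that rank 1 is the highest-scoring candidate, and the product names and scores are made-up illustration values.

```python
from typing import Dict, List, Tuple


def rank_candidates(scores: Dict[str, float], k: int) -> List[Tuple[int, str, float]]:
    """Return (rank, target, score) for the k highest-scoring targets, rank 1 first."""
    # Sort by descending score; break ties by target name for a stable order.
    ordered = sorted(scores.items(), key=lambda item: (-item[1], item[0]))
    return [(rank, target, score) for rank, (target, score) in enumerate(ordered[:k], start=1)]


# Hypothetical scores for one customer over three candidate products.
scores = {"product_a": 0.12, "product_b": 0.87, "product_c": 0.55}
for rank, product, score in rank_candidates(scores, k=2):
    print(rank, product, score)
# prints:
# 1 product_b 0.87
# 2 product_c 0.55
```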
For the full list of valid and invalid GNN method-call sequences, see Understand GNN workflows.