This feature is currently in Preview.

Understand GNN workflows

A GNN instance runs in one of two workflows:

  • A fit workflow trains a new model on labeled splits, then predicts (and optionally registers the trained model).
  • A load workflow retrieves an already-trained model, then predicts.

The workflow is implicit: PyRel determines which one a GNN is in from the constructor arguments. If you pass training arguments (train, validation), you’re in the fit workflow. If you pass a model identifier (either model_run_id or the four registry parameters model_database / model_schema / model_name / version_name), you’re in the load workflow. Mixing arguments from both workflows raises an error at construction time.
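The inference rule above can be sketched in plain Python. This is only an illustration of the rule as described on this page, not PyRel's implementation: the function name infer_workflow, the use of ValueError, and the behavior when neither kind of argument is passed are all assumptions.

```python
# Hypothetical sketch of the rule described above; not PyRel's actual code.
FIT_ARGS = {"train", "validation"}
REGISTRY_ARGS = {"model_database", "model_schema", "model_name", "version_name"}
LOAD_ARGS = {"model_run_id"} | REGISTRY_ARGS


def infer_workflow(**kwargs):
    """Return "fit" or "load" depending on which workflow arguments are set."""
    # Treat an argument as "passed" only if it has a non-None value.
    passed = {name for name, value in kwargs.items() if value is not None}
    has_fit = bool(passed & FIT_ARGS)
    has_load = bool(passed & LOAD_ARGS)
    if has_fit and has_load:
        # The page says mixing raises an error; the exception type is assumed.
        raise ValueError("Cannot mix fit-workflow and load-workflow arguments.")
    if has_fit:
        return "fit"
    if has_load:
        return "load"
    # Assumption: the page does not say what happens when neither is passed.
    raise ValueError("Pass either training arguments or a model identifier.")
```

For example, infer_workflow(train="Train", validation="Validation") returns "fit", and infer_workflow(model_run_id="<model_run_id>") returns "load".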

Once the workflow is fixed, only certain method-call sequences are valid. The rest of this page enumerates them.

Construct the GNN with training arguments, call fit() to train, then call predictions() on the test domain. You can optionally call register_model() at the end to save the trained model to the Snowflake Model Registry:

gnn = GNN(
    exp_database=...,
    exp_schema=...,
    graph=gnn_graph,
    property_transformer=pt,
    train=Train,
    validation=Validation,
    task_type=...,
    eval_metric=...,
)
gnn.fit()
Source.predictions = gnn.predictions(domain=Test)
gnn.register_model(...) # optional

See Configure and train a GNN for the full set of fit-workflow arguments.

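Construct the GNN with a model identifier, call load() to retrieve the trained model, then call predictions(). If the model was loaded by model_run_id, you can optionally call register_model() after load() to save it to the Snowflake Model Registry: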

gnn = GNN(
    exp_database=...,
    exp_schema=...,
    graph=gnn_graph,
    property_transformer=pt,
    source_concept=SourceConcept,
    task_type=...,
    has_time_column=...,
    model_run_id="<model_run_id>",  # or the four registry parameters
)
gnn.load()
gnn.register_model(...)  # optional; only valid when the model was loaded via model_run_id
Source.predictions = gnn.predictions(domain=Test)

See Use a pretrained model for the full set of load-workflow arguments.

The following call sequences raise errors:

  • Calling fit() in a load workflow. Once a GNN is constructed with load arguments, gnn.fit() raises an error. To train, construct a separate fit-mode GNN instance.
  • Calling load() in a fit workflow. Once a GNN is constructed with training arguments, gnn.load() raises an error. To load a previously trained model, construct a separate load-mode GNN instance.
  • Re-registering a model loaded from the registry. If a GNN was loaded by model_database / model_schema / model_name / version_name, calling gnn.register_model(...) raises an error, because that model is already registered under those parameters. To register a different version, train and register a fresh model.
  • Calling predictions() before fit() (or load()). A GNN can only predict on a test domain after it has been fitted (fit workflow) or loaded (load workflow). Calling gnn.predictions(...) before either raises an error.
  • Calling register_model() before fit(). A model must be trained before it can be registered. Calling gnn.register_model(...) before gnn.fit() raises “Model is not fitted. Call gnn.fit() first.”
  • Calling register_model() twice on the same GNN instance. Once a GNN has been registered, it cannot be registered again, even under a different name or version. To register another version, train a fresh model and register that.

The following call sequence has no effect:

  • Calling fit() twice on the same GNN instance. A GNN can be trained at most once; the second fit() call produces a warning and has no effect. To retrain with different hyperparameters, construct a new GNN instance.
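To make the sequencing rules on this page concrete, here is a toy state tracker that mirrors them. It is a sketch only: the class WorkflowSketch, its attributes, and the choice of RuntimeError are hypothetical, it trains and loads nothing, and it is not how PyRel implements the GNN class. The only text taken from this page is the "Model is not fitted" message; reusing that message when load() has not yet been called is an assumption.

```python
import warnings


class WorkflowSketch:
    """Toy stand-in for a GNN instance; it tracks call order only."""

    def __init__(self, workflow, loaded_from_registry=False):
        self.workflow = workflow                  # "fit" or "load"
        self.loaded_from_registry = loaded_from_registry
        self.ready = False                        # True after fit() or load()
        self.registered = False

    def fit(self):
        if self.workflow != "fit":
            raise RuntimeError("fit() is not valid in a load workflow.")
        if self.ready:
            # A second fit() warns and has no effect.
            warnings.warn("Model is already fitted; ignoring fit().")
            return
        self.ready = True

    def load(self):
        if self.workflow != "load":
            raise RuntimeError("load() is not valid in a fit workflow.")
        self.ready = True

    def predictions(self):
        if not self.ready:
            raise RuntimeError("Call fit() or load() before predictions().")
        return "predictions"                      # placeholder result

    def register_model(self):
        if not self.ready:
            raise RuntimeError("Model is not fitted. Call gnn.fit() first.")
        if self.loaded_from_registry:
            raise RuntimeError("Model is already registered.")
        if self.registered:
            raise RuntimeError("This instance has already been registered.")
        self.registered = True
```

With this sketch, WorkflowSketch("fit") accepts fit(), predictions(), and a single register_model() call in that order, while WorkflowSketch("load", loaded_from_registry=True) accepts load() and predictions() but rejects register_model().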