spared.graph_operations.get_sin_cos_positional_embeddings

spared.graph_operations.get_sin_cos_positional_embeddings(graph_dict: dict, max_d_pos: int) → dict

Get positional encodings for a neighbor graph. This function adds transformer-like sine/cosine positional encodings to each graph in a graph dict and stores them under the attribute ‘positional_embeddings’ of each graph.
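For reference, the general transformer-style sine/cosine encoding of a single scalar position can be sketched as follows. This is a minimal, stdlib-only illustration of the technique, not SpaRED's own code. The helper name `sinusoidal_encoding`, the embedding size `dim`, and the base 10000 (borrowed from the original Transformer) are assumptions.

```python
import math


def sinusoidal_encoding(position: float, dim: int, base: float = 10000.0) -> list:
    """Hypothetical helper: sine/cosine encoding of one scalar position.

    Each (sin, cos) pair uses frequency 1 / base**(i / dim), so the
    wavelengths form a geometric progression starting at 2*pi.
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even so channels pair up as (sin, cos)")
    encoding = []
    for i in range(0, dim, 2):
        freq = 1.0 / (base ** (i / dim))
        encoding.append(math.sin(position * freq))  # even channel: sine
        encoding.append(math.cos(position * freq))  # odd channel: cosine
    return encoding


print(sinusoidal_encoding(3, 8))
```

Because every channel is a smooth periodic function of the position, nearby positions get similar vectors, and the fixed frequencies need no training.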

Parameters:
  • graph_dict (dict) – A dictionary whose keys are patch names and whose values are the corresponding PyTorch Geometric graphs.

  • max_d_pos (int) – Maximum absolute value in the relative position matrix, i.e. relative positions lie in [-max_d_pos, max_d_pos].

Returns:

The input graph dict, with positional encodings added to each graph.

Return type:

dict
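The dict-in, dict-out contract above can be illustrated with a toy sketch. This is not the library's implementation. `ToyGraph` stands in for a PyTorch Geometric graph, and `add_sin_cos_embeddings` is a hypothetical function. The following details are all assumptions: the attribute `d_pos` holding each node's integer (row, col) offset from the central patch, splitting channels between rows and columns, and the embedding size. Because offsets are bounded by `max_d_pos`, each axis has only 2 * max_d_pos + 1 distinct encodings. This is one plausible reason to pass the bound: it allows a small lookup table to be precomputed.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyGraph:
    # Hypothetical stand-in for a PyTorch Geometric graph.
    d_pos: list  # assumed: (row, col) offset of each node from the center
    positional_embeddings: Optional[list] = None


def encoding_table(max_d_pos: int, dim: int, base: float = 10000.0) -> list:
    """Sine/cosine encodings for every integer offset in [-max_d_pos, max_d_pos]."""
    table = []
    for pos in range(-max_d_pos, max_d_pos + 1):
        row = []
        for i in range(0, dim, 2):
            freq = base ** (-i / dim)
            row += [math.sin(pos * freq), math.cos(pos * freq)]
        table.append(row)
    return table


def add_sin_cos_embeddings(graph_dict: dict, max_d_pos: int, dim: int = 8) -> dict:
    """Hypothetical sketch: attach positional embeddings to every graph in place."""
    if dim % 4 != 0:
        raise ValueError("dim must be divisible by 4 (two axes, sin/cos pairs)")
    half = dim // 2  # half the channels encode the row offset, half the column
    table = encoding_table(max_d_pos, half)
    for graph in graph_dict.values():
        embeddings = []
        for dr, dc in graph.d_pos:
            if abs(dr) > max_d_pos or abs(dc) > max_d_pos:
                raise ValueError("relative position exceeds max_d_pos")
            # Shift offsets from [-max_d_pos, max_d_pos] to indices [0, 2*max_d_pos]
            embeddings.append(table[dr + max_d_pos] + table[dc + max_d_pos])
        graph.positional_embeddings = embeddings
    return graph_dict


graphs = {"patch_0": ToyGraph(d_pos=[(0, 0), (1, -1)])}
out = add_sin_cos_embeddings(graphs, max_d_pos=1)
print(len(out["patch_0"].positional_embeddings[0]))
```

As the return type suggests, the same dict comes back with each graph carrying its new attribute.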