@@ -181,6 +181,7 @@ def to_spektral_graphs(self, include_object_ids: bool = False) -> List[Graph]:
181181 y = d ["y" ],
182182 id = d ["id" ],
183183 frame_id = d ["frame_id" ],
184+ ball_owning_team_id = d .get ("ball_owning_team_id" , None ),
184185 ** ({"object_ids" : d ["object_ids" ]} if include_object_ids else {}),
185186 )
186187 for d in self .graph_frames
@@ -216,12 +217,19 @@ def to_pickle(
216217 with gzip .open (file_path , "wb" ) as file :
217218 pickle .dump (self .graph_frames , file )
218219
219- def to_custom_dataset (self ) -> GraphDataset :
220+ def to_custom_dataset (self , include_object_ids : bool = False ) -> GraphDataset :
220221 """
221222 Spektral requires a spektral Dataset to load the data
222223 for docs see https://graphneural.network/creating-dataset/
223224 """
224- return GraphDataset (graphs = self .to_spektral_graphs ())
225+ return GraphDataset (graphs = self .to_spektral_graphs (include_object_ids ))
226+
227+ def to_graph_dataset (self , include_object_ids : bool = False ) -> GraphDataset :
228+ """
229+ Spektral requires a spektral Dataset to load the data
230+ for docs see https://graphneural.network/creating-dataset/
231+ """
232+ return GraphDataset (graphs = self .to_spektral_graphs (include_object_ids ))
225233
226234 def _verify_feature_funcs (self , funcs , feature_type : Literal ["edge" , "node" ]):
227235 for i , func in enumerate (funcs ):
@@ -266,6 +274,10 @@ def return_dtypes(self):
266274
267275 def to_graph_frames (self , include_object_ids : bool = False ) -> List [dict ]:
268276 def process_chunk (chunk : pl .DataFrame ) -> List [dict ]:
277+ def __convert_object_ids (objects ):
278+ # convert padded players to None
279+ return [x if x != "" else None for x in objects ]
280+
269281 return [
270282 {
271283 ** {
@@ -285,9 +297,18 @@ def process_chunk(chunk: pl.DataFrame) -> List[dict]:
285297 "y" : np .asarray ([chunk [self .label_column ][i ]]),
286298 "id" : chunk [self .graph_id_column ][i ],
287299 "frame_id" : chunk ["frame_id" ][i ],
300+ "ball_owning_team_id" : (
301+ chunk ["ball_owning_team_id" ][i ]
302+ if "ball_owning_team_id" in chunk .columns
303+ else None
304+ ),
288305 },
289306 ** (
290- {"object_ids" : list (chunk ["object_ids" ][i ][0 ])}
307+ {
308+ "object_ids" : __convert_object_ids (
309+ list (chunk ["object_ids" ][i ][0 ])
310+ )
311+ }
291312 if include_object_ids
292313 else {}
293314 ),
0 commit comments