@@ -832,19 +832,19 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
832832
833833 data = spektral_graphs
834834 assert data [0 ].id == "2417-1524"
835- assert len (data ) == 384
835+ assert len (data ) == 383
836836 assert isinstance (data [0 ], Graph )
837837
838838 assert data [0 ].frame_id == 1524
839- assert data [- 1 ].frame_id == 2131
839+ assert data [- 1 ].frame_id == 2097
840840
841841 dataset = GraphDataset (graphs = spektral_graphs )
842842 N , F , S , n_out , n = dataset .dimensions ()
843843 assert N == 20
844844 assert F == 15
845845 assert S == 6
846846 assert n_out == 1
847- assert n == 384
847+ assert n == 383
848848
849849 train , test , val = dataset .split_test_train_validation (
850850 split_train = 4 ,
@@ -853,9 +853,9 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
853853 by_graph_id = True ,
854854 random_seed = 42 ,
855855 )
856- assert train .n_graphs == 256
857- assert test .n_graphs == 64
858- assert val .n_graphs == 64
856+ assert train .n_graphs == 255
857+ assert test .n_graphs == 63
858+ assert val .n_graphs == 65
859859
860860 train , test , val = dataset .split_test_train_validation (
861861 split_train = 4 ,
@@ -864,9 +864,9 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
864864 by_graph_id = False ,
865865 random_seed = 42 ,
866866 )
867- assert train .n_graphs == 256
868- assert test .n_graphs == 64
869- assert val .n_graphs == 64
867+ assert train .n_graphs == 255
868+ assert test .n_graphs == 63
869+ assert val .n_graphs == 65
870870
871871 train , test , val = dataset .split_test_train_validation (
872872 split_train = 4 ,
@@ -879,21 +879,21 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
879879 val_label_ratio = (1 / 2 ),
880880 )
881881
882- assert train .n_graphs == 164
883- assert test .n_graphs == 52
882+ assert train .n_graphs == 161
883+ assert test .n_graphs == 50
884884 assert val .n_graphs == 62
885885
886886 train , test = dataset .split_test_train (
887887 split_train = 4 , split_test = 1 , by_graph_id = False , random_seed = 42
888888 )
889- assert train .n_graphs == 307
889+ assert train .n_graphs == 306
890890 assert test .n_graphs == 77
891891
892892 train , test = dataset .split_test_train (
893893 split_train = 4 , split_test = 5 , by_graph_id = False , random_seed = 42
894894 )
895895 assert train .n_graphs == 170
896- assert test .n_graphs == 214
896+ assert test .n_graphs == 213
897897
898898 with pytest .raises (
899899 NotImplementedError ,
@@ -929,7 +929,7 @@ def test_to_spektral_graph_level_features(
929929
930930 data = spektral_graphs
931931 assert data [5 ].id == "2417-1529"
932- assert len (data ) == 384
932+ assert len (data ) == 383
933933 assert isinstance (data [0 ], Graph )
934934
935935 x = data [5 ].x
0 commit comments