Skip to content

Commit 9a69c0e

Browse files
Merge pull request #50 from UnravelSports/feat/formations
kloppy > 3.17.0
2 parents 0f4fd2e + 7d3bf93 commit 9a69c0e

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
numpy==1.26.4
22
spektral==1.2.0
3-
kloppy==3.16.0
3+
kloppy==3.17.0
44
tensorflow>=2.14.0; platform_machine != 'arm64' or platform_system != 'Darwin'
55
tensorflow-macos>=2.14.0; platform_machine == 'arm64' and platform_system == 'Darwin'
66
keras==2.14.0

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def read_version():
3232
python_requires="~=3.11",
3333
install_requires=[
3434
"spektral==1.2.0",
35-
"kloppy==3.16.0",
35+
"kloppy==3.17.0",
3636
"tensorflow>=2.14.0;platform_machine != 'arm64' or platform_system != 'Darwin'",
3737
"tensorflow-macos>=2.14.0;platform_machine == 'arm64' and platform_system == 'Darwin'",
3838
"keras==2.14.0",

tests/test_soccer.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -832,19 +832,19 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
832832

833833
data = spektral_graphs
834834
assert data[0].id == "2417-1524"
835-
assert len(data) == 384
835+
assert len(data) == 383
836836
assert isinstance(data[0], Graph)
837837

838838
assert data[0].frame_id == 1524
839-
assert data[-1].frame_id == 2131
839+
assert data[-1].frame_id == 2097
840840

841841
dataset = GraphDataset(graphs=spektral_graphs)
842842
N, F, S, n_out, n = dataset.dimensions()
843843
assert N == 20
844844
assert F == 15
845845
assert S == 6
846846
assert n_out == 1
847-
assert n == 384
847+
assert n == 383
848848

849849
train, test, val = dataset.split_test_train_validation(
850850
split_train=4,
@@ -853,9 +853,9 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
853853
by_graph_id=True,
854854
random_seed=42,
855855
)
856-
assert train.n_graphs == 256
857-
assert test.n_graphs == 64
858-
assert val.n_graphs == 64
856+
assert train.n_graphs == 255
857+
assert test.n_graphs == 63
858+
assert val.n_graphs == 65
859859

860860
train, test, val = dataset.split_test_train_validation(
861861
split_train=4,
@@ -864,9 +864,9 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
864864
by_graph_id=False,
865865
random_seed=42,
866866
)
867-
assert train.n_graphs == 256
868-
assert test.n_graphs == 64
869-
assert val.n_graphs == 64
867+
assert train.n_graphs == 255
868+
assert test.n_graphs == 63
869+
assert val.n_graphs == 65
870870

871871
train, test, val = dataset.split_test_train_validation(
872872
split_train=4,
@@ -879,21 +879,21 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
879879
val_label_ratio=(1 / 2),
880880
)
881881

882-
assert train.n_graphs == 164
883-
assert test.n_graphs == 52
882+
assert train.n_graphs == 161
883+
assert test.n_graphs == 50
884884
assert val.n_graphs == 62
885885

886886
train, test = dataset.split_test_train(
887887
split_train=4, split_test=1, by_graph_id=False, random_seed=42
888888
)
889-
assert train.n_graphs == 307
889+
assert train.n_graphs == 306
890890
assert test.n_graphs == 77
891891

892892
train, test = dataset.split_test_train(
893893
split_train=4, split_test=5, by_graph_id=False, random_seed=42
894894
)
895895
assert train.n_graphs == 170
896-
assert test.n_graphs == 214
896+
assert test.n_graphs == 213
897897

898898
with pytest.raises(
899899
NotImplementedError,
@@ -929,7 +929,7 @@ def test_to_spektral_graph_level_features(
929929

930930
data = spektral_graphs
931931
assert data[5].id == "2417-1529"
932-
assert len(data) == 384
932+
assert len(data) == 383
933933
assert isinstance(data[0], Graph)
934934

935935
x = data[5].x

tests/test_spektral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def test_soccer_prediction(self, soccer_converter_preds: SoccerGraphConverter):
264264
).sort_values(by=["frame_id"])
265265

266266
assert df["frame_id"].iloc[0] == "2417-1524"
267-
assert df["frame_id"].iloc[-1] == "2417-1622"
267+
assert df["frame_id"].iloc[-1] == "2417-1621"
268268

269269
def test_bdb_training(self, bdb_converter: AmericanFootballGraphConverter):
270270
train = GraphDataset(graphs=bdb_converter.to_spektral_graphs())

0 commit comments

Comments
 (0)