Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy==1.26.4
spektral==1.2.0
kloppy==3.16.0
kloppy==3.17.0
tensorflow>=2.14.0; platform_machine != 'arm64' or platform_system != 'Darwin'
tensorflow-macos>=2.14.0; platform_machine == 'arm64' and platform_system == 'Darwin'
keras==2.14.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def read_version():
python_requires="~=3.11",
install_requires=[
"spektral==1.2.0",
"kloppy==3.16.0",
"kloppy==3.17.0",
"tensorflow>=2.14.0;platform_machine != 'arm64' or platform_system != 'Darwin'",
"tensorflow-macos>=2.14.0;platform_machine == 'arm64' and platform_system == 'Darwin'",
"keras==2.14.0",
Expand Down
28 changes: 14 additions & 14 deletions tests/test_soccer.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,19 +832,19 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):

data = spektral_graphs
assert data[0].id == "2417-1524"
assert len(data) == 384
assert len(data) == 383
assert isinstance(data[0], Graph)

assert data[0].frame_id == 1524
assert data[-1].frame_id == 2131
assert data[-1].frame_id == 2097

dataset = GraphDataset(graphs=spektral_graphs)
N, F, S, n_out, n = dataset.dimensions()
assert N == 20
assert F == 15
assert S == 6
assert n_out == 1
assert n == 384
assert n == 383

train, test, val = dataset.split_test_train_validation(
split_train=4,
Expand All @@ -853,9 +853,9 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
by_graph_id=True,
random_seed=42,
)
assert train.n_graphs == 256
assert test.n_graphs == 64
assert val.n_graphs == 64
assert train.n_graphs == 255
assert test.n_graphs == 63
assert val.n_graphs == 65

train, test, val = dataset.split_test_train_validation(
split_train=4,
Expand All @@ -864,9 +864,9 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
by_graph_id=False,
random_seed=42,
)
assert train.n_graphs == 256
assert test.n_graphs == 64
assert val.n_graphs == 64
assert train.n_graphs == 255
assert test.n_graphs == 63
assert val.n_graphs == 65

train, test, val = dataset.split_test_train_validation(
split_train=4,
Expand All @@ -879,21 +879,21 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter):
val_label_ratio=(1 / 2),
)

assert train.n_graphs == 164
assert test.n_graphs == 52
assert train.n_graphs == 161
assert test.n_graphs == 50
assert val.n_graphs == 62

train, test = dataset.split_test_train(
split_train=4, split_test=1, by_graph_id=False, random_seed=42
)
assert train.n_graphs == 307
assert train.n_graphs == 306
assert test.n_graphs == 77

train, test = dataset.split_test_train(
split_train=4, split_test=5, by_graph_id=False, random_seed=42
)
assert train.n_graphs == 170
assert test.n_graphs == 214
assert test.n_graphs == 213

with pytest.raises(
NotImplementedError,
Expand Down Expand Up @@ -929,7 +929,7 @@ def test_to_spektral_graph_level_features(

data = spektral_graphs
assert data[5].id == "2417-1529"
assert len(data) == 384
assert len(data) == 383
assert isinstance(data[0], Graph)

x = data[5].x
Expand Down
2 changes: 1 addition & 1 deletion tests/test_spektral.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def test_soccer_prediction(self, soccer_converter_preds: SoccerGraphConverter):
).sort_values(by=["frame_id"])

assert df["frame_id"].iloc[0] == "2417-1524"
assert df["frame_id"].iloc[-1] == "2417-1622"
assert df["frame_id"].iloc[-1] == "2417-1621"

def test_bdb_training(self, bdb_converter: AmericanFootballGraphConverter):
train = GraphDataset(graphs=bdb_converter.to_spektral_graphs())
Expand Down