69 changes: 34 additions & 35 deletions tests/test_american_football.py
@@ -80,7 +80,7 @@ def raw_dataset(self, coordinates: str):

@pytest.fixture
def edge_feature_values(self):
item_idx = 260
item_idx = 56

assert_values = {
"dist": 0.031333127237586675,
@@ -99,35 +99,35 @@ def edge_feature_values(self):
def adj_matrix_values(self):
return np.asarray(
[
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
]
)

@pytest.fixture
def node_feature_values(self):
item_idx = 6
item_idx = 14

assert_values = {
"x_normed": 0.6679999999999999,
@@ -315,7 +315,7 @@ def test_conversion(

assert len(results_df) == 263

row_4 = results_df[4].to_dict()
row_4 = results_df.filter(pl.col("frame_id") == 484500005).to_dict()

x, x0, x1 = row_4["x"][0], row_4["x_shape_0"][0], row_4["x_shape_1"][0]
a, a0, a1 = row_4["a"][0], row_4["a_shape_0"][0], row_4["a_shape_1"][0]
@@ -350,22 +350,21 @@ def test_conversion(
assert e[item_idx_e][idx] == pytest.approx(
edge_feature_assert_values.get(edge_feature), abs=1e-5
)

np.testing.assert_array_equal(a, adj_matrix_values)

def test_to_graph_frames(
def test_to_graph_frames_1(
self, gnnc: AmericanFootballGraphConverter, node_feature_values
):
graph_frames = gnnc.to_graph_frames()

data = graph_frames
assert len(data) == 263
assert isinstance(data[0], dict)
assert isinstance(data[44], dict)
# note: these shape tests fail if we add more features (ie. metabolicpower)

item_idx_x, node_feature_assert_values = node_feature_values

x = data[4]["x"]
x = data[44]["x"]
assert x.shape == (23, len(node_feature_assert_values.keys()))

for idx, node_feature in enumerate(node_feature_assert_values.keys()):
@@ -392,20 +391,20 @@ def test_to_spektral_graph(

data = spektral_graphs
assert len(data) == 263
assert isinstance(data[0], Graph)
assert isinstance(data[44], Graph)

assert data[0].frame_id == 484500001
assert data[-1].frame_id == 5400039
assert data[0].frame_id == 5400045
assert data[-1].frame_id == 5400023

x = data[4].x
x = data[44].x
assert x.shape == (23, len(node_feature_assert_values.keys()))

for idx, node_feature in enumerate(node_feature_assert_values.keys()):
assert x[item_idx_x][idx] == pytest.approx(
node_feature_assert_values.get(node_feature), abs=1e-5
)

e = data[4].e
e = data[44].e
for idx, edge_feature in enumerate(edge_feature_assert_values.keys()):
assert e[item_idx_e][idx] == pytest.approx(
edge_feature_assert_values.get(edge_feature), abs=1e-5
@@ -419,7 +418,7 @@ def __are_csr_matrices_equal(mat1, mat2):
and np.array_equal(mat1.indptr, mat2.indptr)
)

a = data[4].a
a = data[44].a
assert __are_csr_matrices_equal(a, make_sparse(adj_matrix_values))

dataset = GraphDataset(graphs=spektral_graphs)
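Note on the `results_df[4]` -> `results_df.filter(pl.col("frame_id") == 484500005)` change above: looking the row up by `frame_id` is order-independent, which presumably matters now that this PR changes how frames and objects are sorted and shuffled. A minimal sketch of that lookup pattern on toy data (column names mirror the test; the values are invented):

import polars as pl

# toy stand-in for the converter output used in the test
results_df = pl.DataFrame({"frame_id": [484500001, 484500005], "x_shape_0": [23, 23]})

# order-independent row lookup: filter by frame_id, then read the first element
row = results_df.filter(pl.col("frame_id") == 484500005).to_dict()
assert row["x_shape_0"][0] == 23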
47 changes: 2 additions & 45 deletions tests/test_spektral.py
@@ -90,26 +90,6 @@ def kloppy_polars_dataset(
def soccer_converter(
self, kloppy_polars_dataset: KloppyPolarsDataset
) -> SoccerGraphConverter:
# return SoccerGraphConverterDeprecated(
# dataset=kloppy_dataset,
# labels=dummy_labels(kloppy_dataset),
# graph_ids=dummy_graph_ids(kloppy_dataset),
# ball_carrier_treshold=25.0,
# max_player_speed=12.0,
# max_ball_speed=28.0,
# boundary_correction=None,
# self_loop_ball=True,
# adjacency_matrix_connect_type="ball",
# adjacency_matrix_type="split_by_team",
# label_type="binary",
# defending_team_node_value=0.0,
# non_potential_receiver_node_value=0.1,
# infer_ball_ownership=True,
# infer_goalkeepers=True,
# random_seed=42,
# pad=False,
# verbose=False,
# )
return SoccerGraphConverter(
dataset=kloppy_polars_dataset,
chunk_size=2_0000,
@@ -128,29 +108,6 @@ def soccer_converter_preds(
def soccer_converter_preds(
self, kloppy_polars_dataset: KloppyPolarsDataset
) -> SoccerGraphConverter:
# @pytest.fixture()
# def soccer_converter_preds(
# self, kloppy_dataset: TrackingDataset
# ) -> SoccerGraphConverterDeprecated:
# return SoccerGraphConverterDeprecated(
# dataset=kloppy_dataset,
# prediction=True,
# ball_carrier_treshold=25.0,
# max_player_speed=12.0,
# max_ball_speed=28.0,
# boundary_correction=None,
# self_loop_ball=True,
# adjacency_matrix_connect_type="ball",
# adjacency_matrix_type="split_by_team",
# label_type="binary",
# defending_team_node_value=0.0,
# non_potential_receiver_node_value=0.1,
# infer_ball_ownership=True,
# infer_goalkeepers=True,
# random_seed=42,
# pad=False,
# verbose=False,
# )
return SoccerGraphConverter(
dataset=kloppy_polars_dataset,
prediction=True,
@@ -332,5 +289,5 @@ def test_dbd_prediction(self, bdb_converter_preds: AmericanFootballGraphConverte
{"frame_id": [x.id for x in pred_dataset], "y": preds.flatten()}
)

assert df["frame_id"].iloc[0] == "2021091300-4845"
assert df["frame_id"].iloc[-1] == "2021103108-54"
assert df["frame_id"].iloc[0] == "2021092612-54"
assert df["frame_id"].iloc[-1] == "2021092609-54"
10 changes: 10 additions & 0 deletions unravel/american_football/graphs/graph_converter.py
@@ -76,6 +76,16 @@ def __init__(
self._sample()
self._shuffle()

@staticmethod
def _sort(df):
sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - (
(pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID))
& (pl.col(Column.TEAM_ID) != Constant.BALL)
).cast(int)

df = df.sort([*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)])
return df

def _sample(self):
if self.sample_rate is None:
return
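A quick illustration of what the new `_sort` key above evaluates to, as a minimal sketch with simplified stand-ins (plain column names and a literal "ball" replace the library's `Column.*` / `Constant.BALL` identifiers): the ball gets key 2, ball-owning-team rows get key -1, everything else gets key 0, so within each frame rows sort as attackers, then defenders, then the ball, with `object_id` as a deterministic tie-breaker.

import polars as pl

# simplified stand-ins for Column.TEAM_ID, Column.BALL_OWNING_TEAM_ID, Column.OBJECT_ID
df = pl.DataFrame(
    {
        "frame_id": [1, 1, 1, 1, 1],
        "team_id": ["away", "ball", "home", "away", "home"],
        "ball_owning_team_id": ["home"] * 5,
        "object_id": ["a1", "ball", "h2", "a2", "h1"],
    }
)

sort_key = (pl.col("team_id") == "ball").cast(int) * 2 - (
    (pl.col("ball_owning_team_id") == pl.col("team_id")) & (pl.col("team_id") != "ball")
).cast(int)

# key values: ball-owning team -> -1, defending team -> 0, ball -> 2
print(df.sort(["frame_id", sort_key, pl.col("object_id")]))  # rows: h1, h2, a1, a2, ball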
6 changes: 6 additions & 0 deletions unravel/soccer/graphs/graph_converter.py
@@ -441,6 +441,7 @@ def _exprs_variables(self):
Column.POSITION_NAME,
Column.BALL_OWNING_TEAM_ID,
Column.IS_BALL_CARRIER,
Column.OBJECT_ID,
self.graph_id_column,
self.label_column,
]
@@ -578,6 +579,7 @@ def _compute(self, args: List[pl.Series]) -> dict:
global_feature_type=self.global_feature_type,
**frame_data,
)

return {
"e": pl.Series(
[edge_features.tolist()], dtype=pl.List(pl.List(pl.Float64))
@@ -597,6 +599,9 @@ def _compute(self, args: List[pl.Series]) -> dict:
self.graph_id_column: frame_data[self.graph_id_column][0],
self.label_column: frame_data[self.label_column][0],
"frame_id": frame_id,
"object_ids": pl.Series(
[frame_data[Column.OBJECT_ID].tolist()], dtype=pl.List(pl.String)
),
}

def _convert(self):
Expand All @@ -621,6 +626,7 @@ def _convert(self):
self.graph_id_column,
self.label_column,
"frame_id",
"object_ids",
]
],
*[
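For reference, the new `object_ids` value returned from `_compute` above packs a frame's object identifiers into a single list entry per row. A minimal polars sketch of that pattern (the ids below are invented):

import polars as pl

frame_object_ids = ["home_1", "home_2", "away_1", "ball"]  # hypothetical ids for one frame
s = pl.Series("object_ids", [frame_object_ids], dtype=pl.List(pl.String))
print(s)  # a single row holding the full list of ids for that frame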
59 changes: 36 additions & 23 deletions unravel/utils/objects/default_graph_converter.py
@@ -147,14 +147,14 @@ def __post_init__(self):
raise Exception("'verbose' should be of type boolean (bool)")

def _shuffle(self):
if self.settings.random_seed is None or self.settings.random_seed == False:
if self.settings.random_seed is None or self.settings.random_seed is False:
self.dataset = self._sort(self.dataset)
if isinstance(self.settings.random_seed, int):
elif self.settings.random_seed is True:
self.dataset = self.dataset.sample(fraction=1.0, shuffle=True)
elif isinstance(self.settings.random_seed, int):
self.dataset = self.dataset.sample(
fraction=1.0, seed=self.settings.random_seed
fraction=1.0, seed=self.settings.random_seed, shuffle=True
)
elif self.settings.random_seed == True:
self.dataset = self.dataset.sample(fraction=1.0)
else:
self.dataset = self._sort(self.dataset)

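The reordering in `_shuffle` above matters because of a Python detail: `bool` is a subclass of `int`, so the old `isinstance(self.settings.random_seed, int)` branch also captured `random_seed=True` before any `== True` check could run; the new version handles the explicit boolean cases first. A one-line illustration:

# bool is a subclass of int, so the isinstance(int) branch must come after explicit bool checks
assert isinstance(True, int) and True == 1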
@@ -169,9 +169,9 @@ def _apply_graph_settings(self):
def _convert(self):
raise NotImplementedError()

def to_spektral_graphs(self) -> List[Graph]:
def to_spektral_graphs(self, include_object_ids: bool = False) -> List[Graph]:
if not self.graph_frames:
self.to_graph_frames()
self.to_graph_frames(include_object_ids)

return [
Graph(
@@ -181,11 +181,14 @@ def to_pickle(self, file_path: str, verbose: bool = False) -> None:
y=d["y"],
id=d["id"],
frame_id=d["frame_id"],
**({"object_ids": d["object_ids"]} if include_object_ids else {}),
)
for d in self.graph_frames
]

def to_pickle(self, file_path: str, verbose: bool = False) -> None:
def to_pickle(
self, file_path: str, verbose: bool = False, include_object_ids: bool = False
) -> None:
"""
We store the 'dict' version of the Graphs to pickle each graph is now a dict with keys x, a, e, and y
To use for training with Spektral feed the loaded pickle data to CustomDataset(data=pickled_data)
@@ -196,7 +199,7 @@ def to_pickle(self, file_path: str, verbose: bool = False) -> None:
)

if not self.graph_frames:
self.to_graph_frames()
self.to_graph_frames(include_object_ids)

if verbose:
print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...")
@@ -256,28 +259,38 @@ def return_dtypes(self):
"a_shape_1": pl.Int64,
self.graph_id_column: pl.String,
self.label_column: pl.Int64,
"object_ids": pl.List(pl.List(pl.String)),
# "frame_id": pl.String
}
)

def to_graph_frames(self) -> List[dict]:
def to_graph_frames(self, include_object_ids: bool = False) -> List[dict]:
def process_chunk(chunk: pl.DataFrame) -> List[dict]:
return [
{
"a": make_sparse(
reshape_from_size(
chunk["a"][i], chunk["a_shape_0"][i], chunk["a_shape_1"][i]
)
**{
"a": make_sparse(
reshape_from_size(
chunk["a"][i],
chunk["a_shape_0"][i],
chunk["a_shape_1"][i],
)
),
"x": reshape_from_size(
chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i]
),
"e": reshape_from_size(
chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i]
),
"y": np.asarray([chunk[self.label_column][i]]),
"id": chunk[self.graph_id_column][i],
"frame_id": chunk["frame_id"][i],
},
**(
{"object_ids": list(chunk["object_ids"][i][0])}
if include_object_ids
else {}
),
"x": reshape_from_size(
chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i]
),
"e": reshape_from_size(
chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i]
),
"y": np.asarray([chunk[self.label_column][i]]),
"id": chunk[self.graph_id_column][i],
"frame_id": chunk["frame_id"][i],
}
for i in range(len(chunk))
]
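A usage sketch for the new `include_object_ids` flag, based only on the signatures shown above. The converter construction is deliberately elided (see the test fixtures in this PR); the function and file name below are illustrative, and the converter enforces its own file-extension check in `to_pickle`.

def export_with_object_ids(converter, file_path: str = "graphs.pickle.gz"):
    # converter: an already-constructed SoccerGraphConverter / AmericanFootballGraphConverter
    frames = converter.to_graph_frames(include_object_ids=True)
    print(frames[0]["frame_id"], frames[0]["object_ids"])  # dict gains an "object_ids" list

    # the same flag threads through to Spektral graphs and to pickling
    graphs = converter.to_spektral_graphs(include_object_ids=True)
    print(graphs[0].frame_id, graphs[0].object_ids)

    converter.to_pickle(file_path, include_object_ids=True)
    return graphs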
1 change: 1 addition & 0 deletions unravel/utils/objects/graph_dataset.py
@@ -91,6 +91,7 @@ def __convert(self, data) -> List[Graph]:
y=g["y"],
id=g["id"],
frame_id=g.get("frame_id", None),
object_ids=g.get("object_ids", None),
)
for i, g in enumerate(data)
if i % self.sample == 0
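Background for the `object_ids=g.get("object_ids", None)` line above: `spektral.data.Graph` stores extra keyword arguments as attributes, which is what lets `frame_id` and now `object_ids` ride along with `x`, `a`, `e`, `y`. A small self-contained check (shapes and ids are arbitrary):

import numpy as np
from spektral.data import Graph

g = Graph(x=np.zeros((23, 4)), y=np.array([0]), frame_id=484500001, object_ids=["h1", "a1"])
print(g.frame_id, g.object_ids)  # extra kwargs become attributes on the Graph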