Skip to content

Commit b27042e

Browse files
Merge pull request #57 from UnravelSports/feat/player-ids
object ids
2 parents 06ac7c7 + 6c4d58d commit b27042e

File tree

6 files changed

+89
-103
lines changed

6 files changed

+89
-103
lines changed

tests/test_american_football.py

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def raw_dataset(self, coordinates: str):
8080

8181
@pytest.fixture
8282
def edge_feature_values(self):
83-
item_idx = 260
83+
item_idx = 56
8484

8585
assert_values = {
8686
"dist": 0.031333127237586675,
@@ -99,35 +99,35 @@ def edge_feature_values(self):
9999
def adj_matrix_values(self):
100100
return np.asarray(
101101
[
102-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
103-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
104-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
105-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
106-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
107-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
108-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
109-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
110-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
111-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
112-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
113-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
114-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
115-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
116-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
117-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
118-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
119-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
120-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
121-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
122-
[0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1],
123-
[1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1],
102+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
103+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
124104
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
105+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
106+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
107+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
108+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
109+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
110+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
111+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
112+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
113+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
114+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
115+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
116+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
117+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
118+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
119+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
120+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
121+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
122+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
123+
[0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0],
124+
[1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1],
125125
]
126126
)
127127

128128
@pytest.fixture
129129
def node_feature_values(self):
130-
item_idx = 6
130+
item_idx = 14
131131

132132
assert_values = {
133133
"x_normed": 0.6679999999999999,
@@ -315,7 +315,7 @@ def test_conversion(
315315

316316
assert len(results_df) == 263
317317

318-
row_4 = results_df[4].to_dict()
318+
row_4 = results_df.filter(pl.col("frame_id") == 484500005).to_dict()
319319

320320
x, x0, x1 = row_4["x"][0], row_4["x_shape_0"][0], row_4["x_shape_1"][0]
321321
a, a0, a1 = row_4["a"][0], row_4["a_shape_0"][0], row_4["a_shape_1"][0]
@@ -350,22 +350,21 @@ def test_conversion(
350350
assert e[item_idx_e][idx] == pytest.approx(
351351
edge_feature_assert_values.get(edge_feature), abs=1e-5
352352
)
353-
354353
np.testing.assert_array_equal(a, adj_matrix_values)
355354

356-
def test_to_graph_frames(
355+
def test_to_graph_frames_1(
357356
self, gnnc: AmericanFootballGraphConverter, node_feature_values
358357
):
359358
graph_frames = gnnc.to_graph_frames()
360359

361360
data = graph_frames
362361
assert len(data) == 263
363-
assert isinstance(data[0], dict)
362+
assert isinstance(data[44], dict)
364363
# note: these shape tests fail if we add more features (i.e. metabolicpower)
365364

366365
item_idx_x, node_feature_assert_values = node_feature_values
367366

368-
x = data[4]["x"]
367+
x = data[44]["x"]
369368
assert x.shape == (23, len(node_feature_assert_values.keys()))
370369

371370
for idx, node_feature in enumerate(node_feature_assert_values.keys()):
@@ -392,20 +391,20 @@ def test_to_spektral_graph(
392391

393392
data = spektral_graphs
394393
assert len(data) == 263
395-
assert isinstance(data[0], Graph)
394+
assert isinstance(data[44], Graph)
396395

397-
assert data[0].frame_id == 484500001
398-
assert data[-1].frame_id == 5400039
396+
assert data[0].frame_id == 5400045
397+
assert data[-1].frame_id == 5400023
399398

400-
x = data[4].x
399+
x = data[44].x
401400
assert x.shape == (23, len(node_feature_assert_values.keys()))
402401

403402
for idx, node_feature in enumerate(node_feature_assert_values.keys()):
404403
assert x[item_idx_x][idx] == pytest.approx(
405404
node_feature_assert_values.get(node_feature), abs=1e-5
406405
)
407406

408-
e = data[4].e
407+
e = data[44].e
409408
for idx, edge_feature in enumerate(edge_feature_assert_values.keys()):
410409
assert e[item_idx_e][idx] == pytest.approx(
411410
edge_feature_assert_values.get(edge_feature), abs=1e-5
@@ -419,7 +418,7 @@ def __are_csr_matrices_equal(mat1, mat2):
419418
and np.array_equal(mat1.indptr, mat2.indptr)
420419
)
421420

422-
a = data[4].a
421+
a = data[44].a
423422
assert __are_csr_matrices_equal(a, make_sparse(adj_matrix_values))
424423

425424
dataset = GraphDataset(graphs=spektral_graphs)

tests/test_spektral.py

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -90,26 +90,6 @@ def kloppy_polars_dataset(
9090
def soccer_converter(
9191
self, kloppy_polars_dataset: KloppyPolarsDataset
9292
) -> SoccerGraphConverter:
93-
# return SoccerGraphConverterDeprecated(
94-
# dataset=kloppy_dataset,
95-
# labels=dummy_labels(kloppy_dataset),
96-
# graph_ids=dummy_graph_ids(kloppy_dataset),
97-
# ball_carrier_treshold=25.0,
98-
# max_player_speed=12.0,
99-
# max_ball_speed=28.0,
100-
# boundary_correction=None,
101-
# self_loop_ball=True,
102-
# adjacency_matrix_connect_type="ball",
103-
# adjacency_matrix_type="split_by_team",
104-
# label_type="binary",
105-
# defending_team_node_value=0.0,
106-
# non_potential_receiver_node_value=0.1,
107-
# infer_ball_ownership=True,
108-
# infer_goalkeepers=True,
109-
# random_seed=42,
110-
# pad=False,
111-
# verbose=False,
112-
# )
11393
return SoccerGraphConverter(
11494
dataset=kloppy_polars_dataset,
11595
chunk_size=2_0000,
@@ -128,29 +108,6 @@ def soccer_converter(
128108
def soccer_converter_preds(
129109
self, kloppy_polars_dataset: KloppyPolarsDataset
130110
) -> SoccerGraphConverter:
131-
# @pytest.fixture()
132-
# def soccer_converter_preds(
133-
# self, kloppy_dataset: TrackingDataset
134-
# ) -> SoccerGraphConverterDeprecated:
135-
# return SoccerGraphConverterDeprecated(
136-
# dataset=kloppy_dataset,
137-
# prediction=True,
138-
# ball_carrier_treshold=25.0,
139-
# max_player_speed=12.0,
140-
# max_ball_speed=28.0,
141-
# boundary_correction=None,
142-
# self_loop_ball=True,
143-
# adjacency_matrix_connect_type="ball",
144-
# adjacency_matrix_type="split_by_team",
145-
# label_type="binary",
146-
# defending_team_node_value=0.0,
147-
# non_potential_receiver_node_value=0.1,
148-
# infer_ball_ownership=True,
149-
# infer_goalkeepers=True,
150-
# random_seed=42,
151-
# pad=False,
152-
# verbose=False,
153-
# )
154111
return SoccerGraphConverter(
155112
dataset=kloppy_polars_dataset,
156113
prediction=True,
@@ -332,5 +289,5 @@ def test_dbd_prediction(self, bdb_converter_preds: AmericanFootballGraphConverte
332289
{"frame_id": [x.id for x in pred_dataset], "y": preds.flatten()}
333290
)
334291

335-
assert df["frame_id"].iloc[0] == "2021091300-4845"
336-
assert df["frame_id"].iloc[-1] == "2021103108-54"
292+
assert df["frame_id"].iloc[0] == "2021092612-54"
293+
assert df["frame_id"].iloc[-1] == "2021092609-54"

unravel/american_football/graphs/graph_converter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ def __init__(
7676
self._sample()
7777
self._shuffle()
7878

79+
@staticmethod
80+
def _sort(df):
81+
sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - (
82+
(pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID))
83+
& (pl.col(Column.TEAM_ID) != Constant.BALL)
84+
).cast(int)
85+
86+
df = df.sort([*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)])
87+
return df
88+
7989
def _sample(self):
8090
if self.sample_rate is None:
8191
return

unravel/soccer/graphs/graph_converter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ def _exprs_variables(self):
441441
Column.POSITION_NAME,
442442
Column.BALL_OWNING_TEAM_ID,
443443
Column.IS_BALL_CARRIER,
444+
Column.OBJECT_ID,
444445
self.graph_id_column,
445446
self.label_column,
446447
]
@@ -578,6 +579,7 @@ def _compute(self, args: List[pl.Series]) -> dict:
578579
global_feature_type=self.global_feature_type,
579580
**frame_data,
580581
)
582+
581583
return {
582584
"e": pl.Series(
583585
[edge_features.tolist()], dtype=pl.List(pl.List(pl.Float64))
@@ -597,6 +599,9 @@ def _compute(self, args: List[pl.Series]) -> dict:
597599
self.graph_id_column: frame_data[self.graph_id_column][0],
598600
self.label_column: frame_data[self.label_column][0],
599601
"frame_id": frame_id,
602+
"object_ids": pl.Series(
603+
[frame_data[Column.OBJECT_ID].tolist()], dtype=pl.List(pl.String)
604+
),
600605
}
601606

602607
def _convert(self):
@@ -621,6 +626,7 @@ def _convert(self):
621626
self.graph_id_column,
622627
self.label_column,
623628
"frame_id",
629+
"object_ids",
624630
]
625631
],
626632
*[

unravel/utils/objects/default_graph_converter.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,14 @@ def __post_init__(self):
147147
raise Exception("'verbose' should be of type boolean (bool)")
148148

149149
def _shuffle(self):
150-
if self.settings.random_seed is None or self.settings.random_seed == False:
150+
if self.settings.random_seed is None or self.settings.random_seed is False:
151151
self.dataset = self._sort(self.dataset)
152-
if isinstance(self.settings.random_seed, int):
152+
elif self.settings.random_seed is True:
153+
self.dataset = self.dataset.sample(fraction=1.0, shuffle=True)
154+
elif isinstance(self.settings.random_seed, int):
153155
self.dataset = self.dataset.sample(
154-
fraction=1.0, seed=self.settings.random_seed
156+
fraction=1.0, seed=self.settings.random_seed, shuffle=True
155157
)
156-
elif self.settings.random_seed == True:
157-
self.dataset = self.dataset.sample(fraction=1.0)
158158
else:
159159
self.dataset = self._sort(self.dataset)
160160

@@ -169,9 +169,9 @@ def _apply_graph_settings(self):
169169
def _convert(self):
170170
raise NotImplementedError()
171171

172-
def to_spektral_graphs(self) -> List[Graph]:
172+
def to_spektral_graphs(self, include_object_ids: bool = False) -> List[Graph]:
173173
if not self.graph_frames:
174-
self.to_graph_frames()
174+
self.to_graph_frames(include_object_ids)
175175

176176
return [
177177
Graph(
@@ -181,11 +181,14 @@ def to_spektral_graphs(self) -> List[Graph]:
181181
y=d["y"],
182182
id=d["id"],
183183
frame_id=d["frame_id"],
184+
**({"object_ids": d["object_ids"]} if include_object_ids else {}),
184185
)
185186
for d in self.graph_frames
186187
]
187188

188-
def to_pickle(self, file_path: str, verbose: bool = False) -> None:
189+
def to_pickle(
190+
self, file_path: str, verbose: bool = False, include_object_ids: bool = False
191+
) -> None:
189192
"""
190193
We store the 'dict' version of the Graphs to pickle; each graph is now a dict with keys x, a, e, and y.
191194
To use for training with Spektral, feed the loaded pickle data to CustomDataset(data=pickled_data)
@@ -196,7 +199,7 @@ def to_pickle(self, file_path: str, verbose: bool = False) -> None:
196199
)
197200

198201
if not self.graph_frames:
199-
self.to_graph_frames()
202+
self.to_graph_frames(include_object_ids)
200203

201204
if verbose:
202205
print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...")
@@ -256,28 +259,38 @@ def return_dtypes(self):
256259
"a_shape_1": pl.Int64,
257260
self.graph_id_column: pl.String,
258261
self.label_column: pl.Int64,
262+
"object_ids": pl.List(pl.List(pl.String)),
259263
# "frame_id": pl.String
260264
}
261265
)
262266

263-
def to_graph_frames(self) -> List[dict]:
267+
def to_graph_frames(self, include_object_ids: bool = False) -> List[dict]:
264268
def process_chunk(chunk: pl.DataFrame) -> List[dict]:
265269
return [
266270
{
267-
"a": make_sparse(
268-
reshape_from_size(
269-
chunk["a"][i], chunk["a_shape_0"][i], chunk["a_shape_1"][i]
270-
)
271+
**{
272+
"a": make_sparse(
273+
reshape_from_size(
274+
chunk["a"][i],
275+
chunk["a_shape_0"][i],
276+
chunk["a_shape_1"][i],
277+
)
278+
),
279+
"x": reshape_from_size(
280+
chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i]
281+
),
282+
"e": reshape_from_size(
283+
chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i]
284+
),
285+
"y": np.asarray([chunk[self.label_column][i]]),
286+
"id": chunk[self.graph_id_column][i],
287+
"frame_id": chunk["frame_id"][i],
288+
},
289+
**(
290+
{"object_ids": list(chunk["object_ids"][i][0])}
291+
if include_object_ids
292+
else {}
271293
),
272-
"x": reshape_from_size(
273-
chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i]
274-
),
275-
"e": reshape_from_size(
276-
chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i]
277-
),
278-
"y": np.asarray([chunk[self.label_column][i]]),
279-
"id": chunk[self.graph_id_column][i],
280-
"frame_id": chunk["frame_id"][i],
281294
}
282295
for i in range(len(chunk))
283296
]

unravel/utils/objects/graph_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def __convert(self, data) -> List[Graph]:
9191
y=g["y"],
9292
id=g["id"],
9393
frame_id=g.get("frame_id", None),
94+
object_ids=g.get("object_ids", None),
9495
)
9596
for i, g in enumerate(data)
9697
if i % self.sample == 0

0 commit comments

Comments
 (0)