Skip to content

Commit 1cb24d0

Browse files
Merge pull request #58 from UnravelSports/feat/player-ids
Feat/player ids
2 parents b27042e + 63eed86 commit 1cb24d0

File tree

7 files changed

+83
-3
lines changed

7 files changed

+83
-3
lines changed

tests/test_soccer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,35 @@ def test_padding(self, spc_padding: SoccerGraphConverter):
801801
assert len(data) == 245
802802
assert isinstance(data[0], Graph)
803803

804+
def test_object_ids(self, spc_padding: SoccerGraphConverter):
805+
spektral_graphs = spc_padding.to_spektral_graphs(include_object_ids=True)
806+
807+
assert spektral_graphs[10].object_ids == [
808+
None, # padded players
809+
None,
810+
None,
811+
"10326",
812+
"1138",
813+
"11495",
814+
"12788",
815+
"5568",
816+
"5585",
817+
"6890",
818+
"7207",
819+
None,
820+
None,
821+
None,
822+
"10308",
823+
"1298",
824+
"17902",
825+
"2395",
826+
"4812",
827+
"5472",
828+
"6158",
829+
"9724",
830+
"ball",
831+
]
832+
804833
def test_conversion(self, spc_padding: SoccerGraphConverter):
805834
results_df = spc_padding._convert()
806835

unravel/soccer/dataset/kloppy_polars.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,16 @@ def __apply_settings(
649649
orientation=self.kloppy_dataset.metadata.orientation,
650650
home_team_id=home_team.team_id,
651651
away_team_id=away_team.team_id,
652+
players=[
653+
{
654+
"player_id": p.player_id,
655+
"team_id": p.team.team_id,
656+
"player": p.full_name,
657+
"team": p.team.name,
658+
"jersey_no": p.jersey_no,
659+
}
660+
for p in home_team.players + away_team.players
661+
],
652662
pitch_dimensions=pitch_dimensions,
653663
max_player_speed=self._max_player_speed,
654664
max_ball_speed=self._max_ball_speed,

unravel/soccer/graphs/graph_converter.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,12 @@ def _apply_graph_settings(self):
374374
return GraphSettingsPolars(
375375
home_team_id=str(self._kloppy_settings.home_team_id),
376376
away_team_id=str(self._kloppy_settings.away_team_id),
377+
players=self._kloppy_settings.players,
378+
features={
379+
"edge": [x.__name__ for x in self.edge_feature_funcs],
380+
"node": [x.__name__ for x in self.node_feature_funcs],
381+
"global": self.global_feature_cols,
382+
},
377383
orientation=self._kloppy_settings.orientation,
378384
pitch_dimensions=self.pitch_dimensions,
379385
max_player_speed=self.settings.max_player_speed,
@@ -520,6 +526,7 @@ def _compute(self, args: List[pl.Series]) -> dict:
520526
}
521527
frame_data = self.__add_additional_kwargs(frame_data)
522528
frame_id = args[-1][0]
529+
ball_owning_team_id = frame_data[Column.BALL_OWNING_TEAM_ID][0]
523530

524531
if not np.all(
525532
frame_data[self.graph_id_column] == frame_data[self.graph_id_column][0]
@@ -602,6 +609,7 @@ def _compute(self, args: List[pl.Series]) -> dict:
602609
"object_ids": pl.Series(
603610
[frame_data[Column.OBJECT_ID].tolist()], dtype=pl.List(pl.String)
604611
),
612+
"ball_owning_team_id": ball_owning_team_id,
605613
}
606614

607615
def _convert(self):
@@ -627,6 +635,7 @@ def _convert(self):
627635
self.label_column,
628636
"frame_id",
629637
"object_ids",
638+
"ball_owning_team_id",
630639
]
631640
],
632641
*[

unravel/soccer/graphs/graph_settings.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ class GraphSettingsPolars(DefaultGraphSettings):
2323
pitch_dimensions: MetricPitchDimensions = field(
2424
init=False, repr=False, default_factory=MetricPitchDimensions
2525
)
26+
features: dict = field(
27+
default_factory=lambda: {
28+
"edge": [],
29+
"node": [],
30+
"global": [],
31+
},
32+
repr=False,
33+
)
34+
players: list = field(default_factory=list, repr=False)
2635

2736
def __post_init__(self):
2837
self._sport_specific_checks()

unravel/utils/objects/default_graph_converter.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def to_spektral_graphs(self, include_object_ids: bool = False) -> List[Graph]:
181181
y=d["y"],
182182
id=d["id"],
183183
frame_id=d["frame_id"],
184+
ball_owning_team_id=d.get("ball_owning_team_id", None),
184185
**({"object_ids": d["object_ids"]} if include_object_ids else {}),
185186
)
186187
for d in self.graph_frames
@@ -216,12 +217,19 @@ def to_pickle(
216217
with gzip.open(file_path, "wb") as file:
217218
pickle.dump(self.graph_frames, file)
218219

219-
def to_custom_dataset(self) -> GraphDataset:
220+
def to_custom_dataset(self, include_object_ids: bool = False) -> GraphDataset:
220221
"""
221222
Spektral requires a spektral Dataset to load the data
222223
for docs see https://graphneural.network/creating-dataset/
223224
"""
224-
return GraphDataset(graphs=self.to_spektral_graphs())
225+
return GraphDataset(graphs=self.to_spektral_graphs(include_object_ids))
226+
227+
def to_graph_dataset(self, include_object_ids: bool = False) -> GraphDataset:
228+
"""
229+
Spektral requires a spektral Dataset to load the data
230+
for docs see https://graphneural.network/creating-dataset/
231+
"""
232+
return GraphDataset(graphs=self.to_spektral_graphs(include_object_ids))
225233

226234
def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]):
227235
for i, func in enumerate(funcs):
@@ -266,6 +274,10 @@ def return_dtypes(self):
266274

267275
def to_graph_frames(self, include_object_ids: bool = False) -> List[dict]:
268276
def process_chunk(chunk: pl.DataFrame) -> List[dict]:
277+
def __convert_object_ids(objects):
278+
# convert padded players to None
279+
return [x if x != "" else None for x in objects]
280+
269281
return [
270282
{
271283
**{
@@ -285,9 +297,18 @@ def process_chunk(chunk: pl.DataFrame) -> List[dict]:
285297
"y": np.asarray([chunk[self.label_column][i]]),
286298
"id": chunk[self.graph_id_column][i],
287299
"frame_id": chunk["frame_id"][i],
300+
"ball_owning_team_id": (
301+
chunk["ball_owning_team_id"][i]
302+
if "ball_owning_team_id" in chunk.columns
303+
else None
304+
),
288305
},
289306
**(
290-
{"object_ids": list(chunk["object_ids"][i][0])}
307+
{
308+
"object_ids": __convert_object_ids(
309+
list(chunk["object_ids"][i][0])
310+
)
311+
}
291312
if include_object_ids
292313
else {}
293314
),

unravel/utils/objects/default_settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class DefaultSettings:
4242
max_player_acceleration: float = 6.0
4343
max_ball_acceleration: float = 13.5
4444
ball_carrier_threshold: float = 25.0
45+
players: list = field(default_factory=list)
4546
frame_rate: int = 25
4647

4748
def to_dict(self) -> Dict[str, Any]:

unravel/utils/objects/graph_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __convert(self, data) -> List[Graph]:
9292
id=g["id"],
9393
frame_id=g.get("frame_id", None),
9494
object_ids=g.get("object_ids", None),
95+
ball_owning_team_id=g.get("ball_owning_team_id", None),
9596
)
9697
for i, g in enumerate(data)
9798
if i % self.sample == 0

0 commit comments

Comments
 (0)