Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions tests/test_soccer.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,35 @@ def test_padding(self, spc_padding: SoccerGraphConverter):
assert len(data) == 245
assert isinstance(data[0], Graph)

def test_object_ids(self, spc_padding: SoccerGraphConverter):
spektral_graphs = spc_padding.to_spektral_graphs(include_object_ids=True)

assert spektral_graphs[10].object_ids == [
None, # padded players
None,
None,
"10326",
"1138",
"11495",
"12788",
"5568",
"5585",
"6890",
"7207",
None,
None,
None,
"10308",
"1298",
"17902",
"2395",
"4812",
"5472",
"6158",
"9724",
"ball",
]

def test_conversion(self, spc_padding: SoccerGraphConverter):
results_df = spc_padding._convert()

Expand Down
10 changes: 10 additions & 0 deletions unravel/soccer/dataset/kloppy_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,16 @@ def __apply_settings(
orientation=self.kloppy_dataset.metadata.orientation,
home_team_id=home_team.team_id,
away_team_id=away_team.team_id,
players=[
{
"player_id": p.player_id,
"team_id": p.team.team_id,
"player": p.full_name,
"team": p.team.name,
"jersey_no": p.jersey_no,
}
for p in home_team.players + away_team.players
],
pitch_dimensions=pitch_dimensions,
max_player_speed=self._max_player_speed,
max_ball_speed=self._max_ball_speed,
Expand Down
9 changes: 9 additions & 0 deletions unravel/soccer/graphs/graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,12 @@ def _apply_graph_settings(self):
return GraphSettingsPolars(
home_team_id=str(self._kloppy_settings.home_team_id),
away_team_id=str(self._kloppy_settings.away_team_id),
players=self._kloppy_settings.players,
features={
"edge": [x.__name__ for x in self.edge_feature_funcs],
"node": [x.__name__ for x in self.node_feature_funcs],
"global": self.global_feature_cols,
},
orientation=self._kloppy_settings.orientation,
pitch_dimensions=self.pitch_dimensions,
max_player_speed=self.settings.max_player_speed,
Expand Down Expand Up @@ -520,6 +526,7 @@ def _compute(self, args: List[pl.Series]) -> dict:
}
frame_data = self.__add_additional_kwargs(frame_data)
frame_id = args[-1][0]
ball_owning_team_id = frame_data[Column.BALL_OWNING_TEAM_ID][0]

if not np.all(
frame_data[self.graph_id_column] == frame_data[self.graph_id_column][0]
Expand Down Expand Up @@ -602,6 +609,7 @@ def _compute(self, args: List[pl.Series]) -> dict:
"object_ids": pl.Series(
[frame_data[Column.OBJECT_ID].tolist()], dtype=pl.List(pl.String)
),
"ball_owning_team_id": ball_owning_team_id,
}

def _convert(self):
Expand All @@ -627,6 +635,7 @@ def _convert(self):
self.label_column,
"frame_id",
"object_ids",
"ball_owning_team_id",
]
],
*[
Expand Down
9 changes: 9 additions & 0 deletions unravel/soccer/graphs/graph_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ class GraphSettingsPolars(DefaultGraphSettings):
pitch_dimensions: MetricPitchDimensions = field(
init=False, repr=False, default_factory=MetricPitchDimensions
)
features: dict = field(
default_factory=lambda: {
"edge": [],
"node": [],
"global": [],
},
repr=False,
)
players: list = field(default_factory=list, repr=False)

def __post_init__(self):
self._sport_specific_checks()
Expand Down
27 changes: 24 additions & 3 deletions unravel/utils/objects/default_graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def to_spektral_graphs(self, include_object_ids: bool = False) -> List[Graph]:
y=d["y"],
id=d["id"],
frame_id=d["frame_id"],
ball_owning_team_id=d.get("ball_owning_team_id", None),
**({"object_ids": d["object_ids"]} if include_object_ids else {}),
)
for d in self.graph_frames
Expand Down Expand Up @@ -216,12 +217,19 @@ def to_pickle(
with gzip.open(file_path, "wb") as file:
pickle.dump(self.graph_frames, file)

def to_custom_dataset(self) -> GraphDataset:
def to_custom_dataset(self, include_object_ids: bool = False) -> GraphDataset:
"""
Spektral requires a spektral Dataset to load the data
for docs see https://graphneural.network/creating-dataset/
"""
return GraphDataset(graphs=self.to_spektral_graphs())
return GraphDataset(graphs=self.to_spektral_graphs(include_object_ids))

def to_graph_dataset(self, include_object_ids: bool = False) -> GraphDataset:
"""
Spektral requires a spektral Dataset to load the data
for docs see https://graphneural.network/creating-dataset/
"""
return GraphDataset(graphs=self.to_spektral_graphs(include_object_ids))

def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]):
for i, func in enumerate(funcs):
Expand Down Expand Up @@ -266,6 +274,10 @@ def return_dtypes(self):

def to_graph_frames(self, include_object_ids: bool = False) -> List[dict]:
def process_chunk(chunk: pl.DataFrame) -> List[dict]:
def __convert_object_ids(objects):
# convert padded players to None
return [x if x != "" else None for x in objects]

return [
{
**{
Expand All @@ -285,9 +297,18 @@ def process_chunk(chunk: pl.DataFrame) -> List[dict]:
"y": np.asarray([chunk[self.label_column][i]]),
"id": chunk[self.graph_id_column][i],
"frame_id": chunk["frame_id"][i],
"ball_owning_team_id": (
chunk["ball_owning_team_id"][i]
if "ball_owning_team_id" in chunk.columns
else None
),
},
**(
{"object_ids": list(chunk["object_ids"][i][0])}
{
"object_ids": __convert_object_ids(
list(chunk["object_ids"][i][0])
)
}
if include_object_ids
else {}
),
Expand Down
1 change: 1 addition & 0 deletions unravel/utils/objects/default_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class DefaultSettings:
max_player_acceleration: float = 6.0
max_ball_acceleration: float = 13.5
ball_carrier_threshold: float = 25.0
players: list = field(default_factory=list)
frame_rate: int = 25

def to_dict(self) -> Dict[str, Any]:
Expand Down
1 change: 1 addition & 0 deletions unravel/utils/objects/graph_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __convert(self, data) -> List[Graph]:
id=g["id"],
frame_id=g.get("frame_id", None),
object_ids=g.get("object_ids", None),
ball_owning_team_id=g.get("ball_owning_team_id", None),
)
for i, g in enumerate(data)
if i % self.sample == 0
Expand Down
Loading