Skip to content

Commit 6244a56

Browse files
Merge pull request #38 from UnravelSports/feat/unravel_graph
fix ball padding
2 parents fea3716 + febdeff commit 6244a56

File tree

5 files changed

+83
-21
lines changed

5 files changed

+83
-21
lines changed

tests/files/plot/test-1.mp4

-3.6 KB
Binary file not shown.
-19 KB
Loading

tests/files/plot/test-png.png

-19 KB
Loading

tests/test_kloppy_polars.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,10 @@ def test_padding(self, spc_padding: SoccerGraphConverterPolars):
805805
assert 1 == 1
806806

807807
data = spektral_graphs
808-
assert len(data) == 192
808+
for graph in data:
809+
assert graph.n_nodes == 23
810+
811+
assert len(data) == 245
809812
assert isinstance(data[0], Graph)
810813

811814
def spektral_graph(self, soccer_polars_converter: SoccerGraphConverterPolars):

unravel/soccer/graphs/graph_converter_pl.py

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,15 @@ def _shuffle(self):
161161
elif self.settings.random_seed == True:
162162
self.dataset = self.dataset.sample(fraction=1.0)
163163
else:
164+
165+
sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - (
166+
(pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID))
167+
& (pl.col(Column.TEAM_ID) != Constant.BALL)
168+
).cast(int)
169+
170+
self.dataset = self.dataset.sort(
171+
[*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)]
172+
)
164173
self.dataset = self.dataset.sort(Group.BY_FRAME + [Column.OBJECT_ID])
165174

166175
def _remove_incomplete_frames(self) -> pl.DataFrame:
@@ -185,11 +194,11 @@ def _apply_padding(self) -> pl.DataFrame:
185194
keep_columns = [
186195
Column.TIMESTAMP,
187196
Column.BALL_STATE,
188-
Column.POSITION_NAME,
189197
self.label_column,
190198
self.graph_id_column,
191199
]
192200
empty_columns = [
201+
Column.POSITION_NAME,
193202
Column.OBJECT_ID,
194203
Column.IS_BALL_CARRIER,
195204
Column.X,
@@ -223,7 +232,11 @@ def _apply_padding(self) -> pl.DataFrame:
223232
]
224233

225234
counts = df.group_by(group_by_columns).agg(
226-
pl.len().alias("count"), *[pl.first(col).alias(col) for col in keep_columns]
235+
pl.len().alias("count"),
236+
*[
237+
pl.first(col).alias(col)
238+
for col in keep_columns + self.global_feature_cols
239+
],
227240
)
228241

229242
counts = counts.with_columns(
@@ -241,14 +254,67 @@ def _apply_padding(self) -> pl.DataFrame:
241254
pl.col("count") < pl.col("target_length")
242255
).with_columns((pl.col("target_length") - pl.col("count")).alias("repeats"))
243256

244-
if len(groups_to_pad) == 0:
245-
return df
246-
247257
padding_rows = []
258+
# This is where we pad players (missing balls get skipped because of 'target_length')
248259
for row in groups_to_pad.iter_rows(named=True):
249-
base_row = {col: row[col] for col in keep_columns + group_by_columns}
260+
base_row = {
261+
col: row[col]
262+
for col in keep_columns + group_by_columns + self.global_feature_cols
263+
}
250264
padding_rows.extend([base_row] * row["repeats"])
251265

266+
# Now check if there are frames without ball rows
267+
# Get all unique frames
268+
all_frames = df.select(
269+
[
270+
Column.GAME_ID,
271+
Column.PERIOD_ID,
272+
Column.FRAME_ID,
273+
Column.BALL_OWNING_TEAM_ID,
274+
]
275+
+ keep_columns
276+
+ self.global_feature_cols
277+
).unique()
278+
279+
# Get frames that have ball rows
280+
frames_with_ball = (
281+
df.filter(pl.col(Column.TEAM_ID) == Constant.BALL)
282+
.select([Column.GAME_ID, Column.PERIOD_ID, Column.FRAME_ID])
283+
.unique()
284+
)
285+
286+
# Find frames missing ball rows
287+
frames_missing_ball = all_frames.join(
288+
frames_with_ball,
289+
on=[Column.GAME_ID, Column.PERIOD_ID, Column.FRAME_ID],
290+
how="anti",
291+
)
292+
293+
# Create a dataframe of ball rows to add with appropriate columns
294+
if frames_missing_ball.height > 0:
295+
# Create base rows for missing balls
296+
ball_rows_to_add = frames_missing_ball.with_columns(
297+
[
298+
pl.lit(Constant.BALL).alias(Column.TEAM_ID),
299+
pl.lit(Constant.BALL).alias(Column.POSITION_NAME),
300+
]
301+
)
302+
303+
# Add to padding rows using same pattern as for players
304+
for row in ball_rows_to_add.iter_rows(named=True):
305+
base_row = {
306+
col: row[col]
307+
for col in keep_columns
308+
+ group_by_columns
309+
+ [Column.POSITION_NAME]
310+
+ self.global_feature_cols
311+
if col in row
312+
}
313+
padding_rows.append(base_row)
314+
315+
if len(padding_rows) == 0:
316+
return df
317+
252318
padding_df = pl.DataFrame(padding_rows)
253319

254320
schema = df.schema
@@ -260,7 +326,6 @@ def _apply_padding(self) -> pl.DataFrame:
260326
for col in user_defined_columns
261327
]
262328
)
263-
264329
padding_df = padding_df.with_columns(
265330
[pl.col(col).cast(df.schema[col]).alias(col) for col in group_by_columns]
266331
)
@@ -643,7 +708,6 @@ def plot(
643708
team_color_b: str = "#0066CC",
644709
ball_color: str = "black",
645710
color_by: Literal["ball_owning", "static_home_away"] = "ball_owning",
646-
sort: bool = True,
647711
):
648712
"""
649713
Plot tracking data as a static image or video file.
@@ -890,7 +954,7 @@ def direction_of_play_arrow(ax):
890954
arrow_dy = 15
891955

892956
if self.settings.orientation == Orientation.STATIC_HOME_AWAY:
893-
if self._ball_owning_team_id != str(self.settings.home_team_id):
957+
if self._ball_owning_team_id != self.settings.home_team_id:
894958
arrow_y = arrow_y * -1
895959
arrow_dy = arrow_dy * -1
896960
elif self.settings.orientation == Orientation.BALL_OWNING_TEAM:
@@ -917,7 +981,7 @@ def player_and_ball(frame_data, ax):
917981
if self._color_by == "ball_owning":
918982
team_id = self._ball_owning_team_id
919983
elif self._color_by == "static_home_away":
920-
team_id = str(self.settings.home_team_id)
984+
team_id = self.settings.home_team_id
921985
else:
922986
raise ValueError(f"Unsupported color_by {self._color_by}")
923987

@@ -931,7 +995,7 @@ def player_and_ball(frame_data, ax):
931995
r[Column.X],
932996
r[Column.Y],
933997
)
934-
is_ball = True if r[Column.TEAM_ID] == Constant.BALL else False
998+
is_ball = True if r[Column.TEAM_ID] == self.settings.ball_id else False
935999

9361000
if not is_ball:
9371001
if team_id is None:
@@ -978,10 +1042,7 @@ def player_and_ball(frame_data, ax):
9781042
path_effects.Normal(),
9791043
]
9801044
)
981-
ax.set_xlabel(
982-
f"Label: {frame_data[self.label_column][0]} - Ball Owning Team Id: {frame_data[Column.BALL_OWNING_TEAM_ID][0]}",
983-
fontsize=22,
984-
)
1045+
ax.set_xlabel(f"Label: {frame_data['label'][0]}", fontsize=22)
9851046

9861047
def frame_plot(self, frame_data):
9871048
self._gs = GridSpec(
@@ -1021,12 +1082,10 @@ def frame_plot(self, frame_data):
10211082
if generate_video:
10221083
writer = animation.FFMpegWriter(fps=fps, bitrate=1800)
10231084

1024-
if sort:
1025-
df = df.sort(Group.BY_FRAME + [Column.OBJECT_ID])
10261085
with writer.saving(self._fig, file_path, dpi=300):
1027-
for group_id, frame_data in df.group_by(
1028-
Group.BY_FRAME, maintain_order=True
1029-
):
1086+
for group_id, frame_data in df.sort(
1087+
Group.BY_FRAME + [Column.OBJECT_ID]
1088+
).group_by(Group.BY_FRAME, maintain_order=True):
10301089
self._fig.clear()
10311090
frame_plot(self, frame_data)
10321091
writer.grab_frame()

0 commit comments

Comments
 (0)