@@ -161,6 +161,15 @@ def _shuffle(self):
161161 elif self .settings .random_seed == True :
162162 self .dataset = self .dataset .sample (fraction = 1.0 )
163163 else :
164+
165+ sort_expr = (pl .col (Column .TEAM_ID ) == Constant .BALL ).cast (int ) * 2 - (
166+ (pl .col (Column .BALL_OWNING_TEAM_ID ) == pl .col (Column .TEAM_ID ))
167+ & (pl .col (Column .TEAM_ID ) != Constant .BALL )
168+ ).cast (int )
169+
170+ self .dataset = self .dataset .sort (
171+ [* Group .BY_FRAME , sort_expr , pl .col (Column .OBJECT_ID )]
172+ )
164173 self .dataset = self .dataset .sort (Group .BY_FRAME + [Column .OBJECT_ID ])
165174
166175 def _remove_incomplete_frames (self ) -> pl .DataFrame :
@@ -185,11 +194,11 @@ def _apply_padding(self) -> pl.DataFrame:
185194 keep_columns = [
186195 Column .TIMESTAMP ,
187196 Column .BALL_STATE ,
188- Column .POSITION_NAME ,
189197 self .label_column ,
190198 self .graph_id_column ,
191199 ]
192200 empty_columns = [
201+ Column .POSITION_NAME ,
193202 Column .OBJECT_ID ,
194203 Column .IS_BALL_CARRIER ,
195204 Column .X ,
@@ -223,7 +232,11 @@ def _apply_padding(self) -> pl.DataFrame:
223232 ]
224233
225234 counts = df .group_by (group_by_columns ).agg (
226- pl .len ().alias ("count" ), * [pl .first (col ).alias (col ) for col in keep_columns ]
235+ pl .len ().alias ("count" ),
236+ * [
237+ pl .first (col ).alias (col )
238+ for col in keep_columns + self .global_feature_cols
239+ ],
227240 )
228241
229242 counts = counts .with_columns (
@@ -241,14 +254,67 @@ def _apply_padding(self) -> pl.DataFrame:
241254 pl .col ("count" ) < pl .col ("target_length" )
242255 ).with_columns ((pl .col ("target_length" ) - pl .col ("count" )).alias ("repeats" ))
243256
244- if len (groups_to_pad ) == 0 :
245- return df
246-
247257 padding_rows = []
258+ # This is where we pad players (missing balls get skipped because of 'target_length')
248259 for row in groups_to_pad .iter_rows (named = True ):
249- base_row = {col : row [col ] for col in keep_columns + group_by_columns }
260+ base_row = {
261+ col : row [col ]
262+ for col in keep_columns + group_by_columns + self .global_feature_cols
263+ }
250264 padding_rows .extend ([base_row ] * row ["repeats" ])
251265
266+ # Now check if there are frames without ball rows
267+ # Get all unique frames
268+ all_frames = df .select (
269+ [
270+ Column .GAME_ID ,
271+ Column .PERIOD_ID ,
272+ Column .FRAME_ID ,
273+ Column .BALL_OWNING_TEAM_ID ,
274+ ]
275+ + keep_columns
276+ + self .global_feature_cols
277+ ).unique ()
278+
279+ # Get frames that have ball rows
280+ frames_with_ball = (
281+ df .filter (pl .col (Column .TEAM_ID ) == Constant .BALL )
282+ .select ([Column .GAME_ID , Column .PERIOD_ID , Column .FRAME_ID ])
283+ .unique ()
284+ )
285+
286+ # Find frames missing ball rows
287+ frames_missing_ball = all_frames .join (
288+ frames_with_ball ,
289+ on = [Column .GAME_ID , Column .PERIOD_ID , Column .FRAME_ID ],
290+ how = "anti" ,
291+ )
292+
293+ # Create a dataframe of ball rows to add with appropriate columns
294+ if frames_missing_ball .height > 0 :
295+ # Create base rows for missing balls
296+ ball_rows_to_add = frames_missing_ball .with_columns (
297+ [
298+ pl .lit (Constant .BALL ).alias (Column .TEAM_ID ),
299+ pl .lit (Constant .BALL ).alias (Column .POSITION_NAME ),
300+ ]
301+ )
302+
303+ # Add to padding rows using same pattern as for players
304+ for row in ball_rows_to_add .iter_rows (named = True ):
305+ base_row = {
306+ col : row [col ]
307+ for col in keep_columns
308+ + group_by_columns
309+ + [Column .POSITION_NAME ]
310+ + self .global_feature_cols
311+ if col in row
312+ }
313+ padding_rows .append (base_row )
314+
315+ if len (padding_rows ) == 0 :
316+ return df
317+
252318 padding_df = pl .DataFrame (padding_rows )
253319
254320 schema = df .schema
@@ -260,7 +326,6 @@ def _apply_padding(self) -> pl.DataFrame:
260326 for col in user_defined_columns
261327 ]
262328 )
263-
264329 padding_df = padding_df .with_columns (
265330 [pl .col (col ).cast (df .schema [col ]).alias (col ) for col in group_by_columns ]
266331 )
@@ -643,7 +708,6 @@ def plot(
643708 team_color_b : str = "#0066CC" ,
644709 ball_color : str = "black" ,
645710 color_by : Literal ["ball_owning" , "static_home_away" ] = "ball_owning" ,
646- sort : bool = True ,
647711 ):
648712 """
649713 Plot tracking data as a static image or video file.
@@ -890,7 +954,7 @@ def direction_of_play_arrow(ax):
890954 arrow_dy = 15
891955
892956 if self .settings .orientation == Orientation .STATIC_HOME_AWAY :
893- if self ._ball_owning_team_id != str ( self .settings .home_team_id ) :
957+ if self ._ball_owning_team_id != self .settings .home_team_id :
894958 arrow_y = arrow_y * - 1
895959 arrow_dy = arrow_dy * - 1
896960 elif self .settings .orientation == Orientation .BALL_OWNING_TEAM :
@@ -917,7 +981,7 @@ def player_and_ball(frame_data, ax):
917981 if self ._color_by == "ball_owning" :
918982 team_id = self ._ball_owning_team_id
919983 elif self ._color_by == "static_home_away" :
920- team_id = str ( self .settings .home_team_id )
984+ team_id = self .settings .home_team_id
921985 else :
922986 raise ValueError (f"Unsupported color_by { self ._color_by } " )
923987
@@ -931,7 +995,7 @@ def player_and_ball(frame_data, ax):
931995 r [Column .X ],
932996 r [Column .Y ],
933997 )
934- is_ball = True if r [Column .TEAM_ID ] == Constant . BALL else False
998+ is_ball = True if r [Column .TEAM_ID ] == self . settings . ball_id else False
935999
9361000 if not is_ball :
9371001 if team_id is None :
@@ -978,10 +1042,7 @@ def player_and_ball(frame_data, ax):
9781042 path_effects .Normal (),
9791043 ]
9801044 )
981- ax .set_xlabel (
982- f"Label: { frame_data [self .label_column ][0 ]} - Ball Owning Team Id: { frame_data [Column .BALL_OWNING_TEAM_ID ][0 ]} " ,
983- fontsize = 22 ,
984- )
1045+ ax .set_xlabel (f"Label: { frame_data ['label' ][0 ]} " , fontsize = 22 )
9851046
9861047 def frame_plot (self , frame_data ):
9871048 self ._gs = GridSpec (
@@ -1021,12 +1082,10 @@ def frame_plot(self, frame_data):
10211082 if generate_video :
10221083 writer = animation .FFMpegWriter (fps = fps , bitrate = 1800 )
10231084
1024- if sort :
1025- df = df .sort (Group .BY_FRAME + [Column .OBJECT_ID ])
10261085 with writer .saving (self ._fig , file_path , dpi = 300 ):
1027- for group_id , frame_data in df .group_by (
1028- Group .BY_FRAME , maintain_order = True
1029- ):
1086+ for group_id , frame_data in df .sort (
1087+ Group .BY_FRAME + [ Column . OBJECT_ID ]
1088+ ). group_by ( Group . BY_FRAME , maintain_order = True ) :
10301089 self ._fig .clear ()
10311090 frame_plot (self , frame_data )
10321091 writer .grab_frame ()
0 commit comments