@@ -153,6 +153,17 @@ def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]):
153153 "Function has an incorrect feature type edge features should be 'edge', node features should be 'node'. "
154154 )
155155
156+ @staticmethod
157+ def _sort (df ):
158+ sort_expr = (pl .col (Column .TEAM_ID ) == Constant .BALL ).cast (int ) * 2 - (
159+ (pl .col (Column .BALL_OWNING_TEAM_ID ) == pl .col (Column .TEAM_ID ))
160+ & (pl .col (Column .TEAM_ID ) != Constant .BALL )
161+ ).cast (int )
162+
163+ df = df .sort ([* Group .BY_FRAME , sort_expr , pl .col (Column .OBJECT_ID )])
164+ df = df .sort (Group .BY_FRAME + [Column .OBJECT_ID ])
165+ return df
166+
156167 def _shuffle (self ):
157168 if isinstance (self .settings .random_seed , int ):
158169 self .dataset = self .dataset .sample (
@@ -161,16 +172,7 @@ def _shuffle(self):
161172 elif self .settings .random_seed == True :
162173 self .dataset = self .dataset .sample (fraction = 1.0 )
163174 else :
164-
165- sort_expr = (pl .col (Column .TEAM_ID ) == Constant .BALL ).cast (int ) * 2 - (
166- (pl .col (Column .BALL_OWNING_TEAM_ID ) == pl .col (Column .TEAM_ID ))
167- & (pl .col (Column .TEAM_ID ) != Constant .BALL )
168- ).cast (int )
169-
170- self .dataset = self .dataset .sort (
171- [* Group .BY_FRAME , sort_expr , pl .col (Column .OBJECT_ID )]
172- )
173- self .dataset = self .dataset .sort (Group .BY_FRAME + [Column .OBJECT_ID ])
175+ self .dataset = self ._sort (self .dataset )
174176
175177 def _remove_incomplete_frames (self ) -> pl .DataFrame :
176178 df = self .dataset
@@ -707,6 +709,7 @@ def plot(
707709 team_color_a : str = "#CD0E61" ,
708710 team_color_b : str = "#0066CC" ,
709711 ball_color : str = "black" ,
712+ sort : bool = True ,
710713 color_by : Literal ["ball_owning" , "static_home_away" ] = "ball_owning" ,
711714 ):
712715 """
@@ -1079,13 +1082,16 @@ def frame_plot(self, frame_data):
10791082 self ._fig = plt .figure (figsize = (25 , 18 ))
10801083 self ._fig .subplots_adjust (left = 0.06 , right = 1.0 , bottom = 0.05 )
10811084
1085+ if sort :
1086+ df = self ._sort (df )
1087+
10821088 if generate_video :
10831089 writer = animation .FFMpegWriter (fps = fps , bitrate = 1800 )
10841090
10851091 with writer .saving (self ._fig , file_path , dpi = 300 ):
1086- for group_id , frame_data in df .sort (
1087- Group .BY_FRAME + [ Column . OBJECT_ID ]
1088- ). group_by ( Group . BY_FRAME , maintain_order = True ) :
1092+ for group_id , frame_data in df .group_by (
1093+ Group .BY_FRAME , maintain_order = True
1094+ ):
10891095 self ._fig .clear ()
10901096 frame_plot (self , frame_data )
10911097 writer .grab_frame ()
0 commit comments