@@ -673,6 +673,10 @@ def plot(
673673 ball_color : str = "black" ,
674674 sort : bool = True ,
675675 color_by : Literal ["ball_owning" , "static_home_away" ] = "ball_owning" ,
676+ anonymous : bool = False ,
677+ plot_type : Literal ["pitch_only" , "graph_only" , "full" ] = "full" ,
678+ show_label : bool = True ,
679+ show_timestamp : bool = True ,
676680 ):
677681 """
678682 Plot tracking data as a static image or video file.
@@ -703,6 +707,17 @@ def plot(
703707 Method for coloring the teams:
704708 - "ball_owning": Colors teams based on ball possession
705709 - "static_home_away": Uses static colors for home and away teams
710+ anonymous : bool, default False
711+ Whether to anonymize player labels
712+ plot_type : Literal["pitch_only", "graph_only", "full"], default "full"
713+ Type of plot to generate:
714+ - "pitch_only": Shows only the soccer pitch visualization
715+ - "graph_only": Shows only the graph features (node features, adjacency matrix, edge features)
716+ - "full": Shows both pitch and graph visualizations
717+ show_pitch_label : bool, default True
718+ Whether to show the label on the pitch visualization
719+ show_pitch_timestamp : bool, default True
720+ Whether to show the timestamp on the pitch visualization
706721
707722 Returns
708723 -------
@@ -771,6 +786,11 @@ def plot(
771786 self ._team_color_b = team_color_b
772787 self ._ball_color = ball_color
773788 self ._color_by = color_by
789+ self ._plot_type = plot_type
790+ self ._show_pitch_label = show_label
791+ self ._show_pitch_timestamp = show_timestamp
792+
793+ self ._ball_carrier_color = "black"
774794
775795 if period_id is not None and not isinstance (period_id , int ):
776796 raise TypeError ("period_id should be of type integer" )
@@ -810,25 +830,72 @@ def plot(
810830 if df .is_empty ():
811831 raise ValueError ("Selection is empty, please try different timestamp(s)" )
812832
833+ def setup_gridspec ():
834+ """Setup GridSpec based on plot_type"""
835+ if self ._plot_type == "pitch_only" :
836+ return GridSpec (1 , 1 , left = 0.05 , right = 0.95 , bottom = 0.05 , top = 0.95 )
837+ elif self ._plot_type == "graph_only" :
838+ return GridSpec (
839+ 2 ,
840+ 2 ,
841+ width_ratios = [1.2 , 0.8 ],
842+ height_ratios = [1 , 1 ],
843+ wspace = 0.2 , # Increased spacing
844+ hspace = 0.3 , # Increased spacing
845+ left = 0.08 ,
846+ right = 0.92 ,
847+ bottom = 0.1 , # More bottom margin
848+ top = 0.9 ,
849+ ) # More top margin
850+ else : # "full"
851+ return GridSpec (
852+ 2 ,
853+ 3 ,
854+ width_ratios = [2 , 1 , 3 ],
855+ height_ratios = [1 , 1 ],
856+ wspace = 0.15 , # Increased spacing
857+ hspace = 0.1 , # Increased spacing
858+ left = 0.05 ,
859+ right = 0.98 , # Slightly reduced right margin
860+ bottom = 0.08 , # More bottom margin
861+ top = 0.95 ,
862+ )
863+
813864 def plot_graph ():
865+ """Plot graph features (node features, adjacency matrix, edge features)"""
814866 import matplotlib .pyplot as plt
815867
816- labels = [
817- (
818- self .get_player_by_id (pid )["jersey_no" ]
819- if pid != Constant .BALL
820- else Constant .BALL
821- )
822- for pid in self ._graph .object_ids
823- ]
868+ num_rows = self ._graph .x .shape [0 ]
824869
825- # Plot node features in top-left
826- ax1 = self ._fig .add_subplot (self ._gs [0 , 0 ])
870+ labels = (
871+ [
872+ (
873+ self .get_player_by_id (pid )["jersey_no" ]
874+ if pid != Constant .BALL
875+ else Constant .BALL
876+ )
877+ for pid in self ._graph .object_ids
878+ ]
879+ if not anonymous
880+ else [str (i ) for i in range (num_rows )]
881+ )
882+
883+ # Determine subplot positions based on plot_type
884+ if self ._plot_type == "graph_only" :
885+ node_pos = (0 , 0 )
886+ adj_pos = (1 , 0 )
887+ edge_pos = (slice (None ), 1 )
888+ else : # "full"
889+ node_pos = (0 , 0 )
890+ adj_pos = (1 , 0 )
891+ edge_pos = (slice (None ), 1 )
892+
893+ # Plot node features
894+ ax1 = self ._fig .add_subplot (self ._gs [node_pos ])
827895 ax1 .imshow (self ._graph .x , aspect = "auto" , cmap = "YlOrRd" )
828896 ax1 .set_xlabel (f"Node Features { self ._graph .x .shape } " )
829897
830898 # Set y labels to integers
831- num_rows = self ._graph .x .shape [0 ]
832899 ax1 .set_yticks (range (num_rows ))
833900 ax1 .set_yticklabels (labels )
834901
@@ -837,8 +904,8 @@ def plot_graph():
837904 ax1 .set_xticks (range (len (node_feature_yticklabels )))
838905 ax1 .set_xticklabels (node_feature_yticklabels , rotation = 45 , ha = "left" )
839906
840- # Plot ajacency matrix in bottom-left
841- ax2 = self ._fig .add_subplot (self ._gs [1 , 0 ])
907+ # Plot adjacency matrix
908+ ax2 = self ._fig .add_subplot (self ._gs [adj_pos ])
842909 ax2 .imshow (self ._graph .a .toarray (), aspect = "auto" , cmap = "YlOrRd" )
843910 ax2 .set_xlabel (f"Adjacency Matrix { self ._graph .a .shape } " )
844911
@@ -852,8 +919,8 @@ def plot_graph():
852919 ax2 .set_xticks (range (num_cols_a ))
853920 ax2 .set_xticklabels (labels )
854921
855- # Plot Edge Features on the right (spanning both rows)
856- ax3 = self ._fig .add_subplot (self ._gs [:, 1 ])
922+ # Plot Edge Features
923+ ax3 = self ._fig .add_subplot (self ._gs [edge_pos ])
857924
858925 _ , size_a = non_zeros (self ._graph .a .toarray ()[0 : self ._ball_carrier_idx ])
859926 ball_carrier_edge_idx , num_rows_e = non_zeros (
@@ -893,6 +960,7 @@ def plot_graph():
893960 plt .colorbar (im3 , ax = ax3 , fraction = 0.1 , pad = 0.2 )
894961
895962 def plot_vertical_pitch (frame_data : pl .DataFrame ):
963+ """Plot the soccer pitch visualization"""
896964 try :
897965 from mplsoccer import VerticalPitch
898966 except ImportError :
@@ -901,7 +969,13 @@ def plot_vertical_pitch(frame_data: pl.DataFrame):
901969 " install it using: pip install mplsoccer"
902970 )
903971
904- ax4 = self ._fig .add_subplot (self ._gs [:, 2 ])
972+ # Determine subplot position based on plot_type
973+ if self ._plot_type == "pitch_only" :
974+ pitch_pos = (0 , 0 )
975+ else : # "full"
976+ pitch_pos = (slice (None ), 2 )
977+
978+ ax4 = self ._fig .add_subplot (self ._gs [pitch_pos ])
905979 pitch = VerticalPitch (
906980 pitch_type = "secondspectrum" ,
907981 pitch_length = self .pitch_dimensions .pitch_length ,
@@ -960,8 +1034,6 @@ def player_and_ball(frame_data, ax):
9601034 else :
9611035 raise ValueError (f"Unsupported color_by { self ._color_by } " )
9621036
963- self ._ball_carrier_color = None
964-
9651037 for i , r in enumerate (frame_data .iter_rows (named = True )):
9661038 v , vy , vx , y , x = (
9671039 r [Column .SPEED ],
@@ -997,14 +1069,19 @@ def player_and_ball(frame_data, ax):
9971069
9981070 else :
9991071 ax .scatter (x , y , color = self ._ball_color , s = 250 , zorder = 10 )
1000- # # Text with white border
1072+
1073+ # Text with white border
10011074 text = ax .text (
10021075 x + (- 1.2 if is_ball else 0.0 ),
10031076 y + (- 1.2 if is_ball else 0.0 ),
10041077 (
1005- self .get_player_by_id (r [Column .OBJECT_ID ])["jersey_no" ]
1006- if r [Column .OBJECT_ID ] != Constant .BALL
1007- else Constant .BALL
1078+ (
1079+ self .get_player_by_id (r [Column .OBJECT_ID ])["jersey_no" ]
1080+ if r [Column .OBJECT_ID ] != Constant .BALL
1081+ else Constant .BALL
1082+ )
1083+ if not anonymous
1084+ else str (i )
10081085 ),
10091086 color = self ._ball_color if is_ball else color ,
10101087 fontsize = 12 ,
@@ -1021,7 +1098,11 @@ def player_and_ball(frame_data, ax):
10211098 path_effects .Normal (),
10221099 ]
10231100 )
1101+
1102+ # Add label and timestamp to pitch if enabled
1103+ if self ._show_pitch_label :
10241104 ax .set_xlabel (f"Label: { frame_data ['label' ][0 ]} " , fontsize = 22 )
1105+ if self ._show_pitch_timestamp :
10251106 ax .set_title (self ._gameclock , fontsize = 22 )
10261107
10271108 def frame_plot (self , frame_data ):
@@ -1034,58 +1115,67 @@ def timestamp_to_gameclock(timestamp, period_id):
10341115
10351116 return f"[{ period_id } ] - { minutes } :{ seconds :02d} :{ milliseconds :03d} "
10361117
1037- self ._gs = GridSpec (
1038- 2 ,
1039- 3 ,
1040- width_ratios = [2 , 1 , 3 ],
1041- height_ratios = [1 , 1 ],
1042- wspace = 0.1 ,
1043- hspace = 0.06 ,
1044- left = 0.05 ,
1045- right = 1.0 ,
1046- bottom = 0.05 ,
1047- )
1118+ # Setup GridSpec based on plot_type
1119+ self ._gs = setup_gridspec ()
10481120
1049- # Process the current frame
1050- features = self ._compute ([frame_data [col ] for col in self ._exprs_variables ])
1051- a = make_sparse (
1052- reshape_from_size (
1053- features ["a" ], features ["a_shape_0" ], features ["a_shape_1" ]
1121+ # Only process graph data if we need to show graphs
1122+ if self ._plot_type in ["graph_only" , "full" ]:
1123+ # Process the current frame
1124+ features = self ._compute (
1125+ [frame_data [col ] for col in self ._exprs_variables ]
1126+ )
1127+ a = make_sparse (
1128+ reshape_from_size (
1129+ features ["a" ], features ["a_shape_0" ], features ["a_shape_1" ]
1130+ )
1131+ )
1132+ x = reshape_from_size (
1133+ features ["x" ], features ["x_shape_0" ], features ["x_shape_1" ]
1134+ )
1135+ e = reshape_from_size (
1136+ features ["e" ], features ["e_shape_0" ], features ["e_shape_1" ]
1137+ )
1138+ y = np .asarray ([features [self .label_column ]])
1139+
1140+ self ._graph = Graph (
1141+ a = a ,
1142+ x = x ,
1143+ e = e ,
1144+ y = y ,
1145+ frame_id = features ["frame_id" ],
1146+ object_ids = frame_data [Column .OBJECT_ID ],
1147+ ball_owning_team_id = frame_data [Column .BALL_OWNING_TEAM_ID ][0 ],
10541148 )
1055- )
1056- x = reshape_from_size (
1057- features ["x" ], features ["x_shape_0" ], features ["x_shape_1" ]
1058- )
1059- e = reshape_from_size (
1060- features ["e" ], features ["e_shape_0" ], features ["e_shape_1" ]
1061- )
1062- y = np .asarray ([features [self .label_column ]])
1063-
1064- self ._graph = Graph (
1065- a = a ,
1066- x = x ,
1067- e = e ,
1068- y = y ,
1069- frame_id = features ["frame_id" ],
1070- object_ids = frame_data [Column .OBJECT_ID ],
1071- ball_owning_team_id = frame_data [Column .BALL_OWNING_TEAM_ID ][0 ],
1072- )
10731149
1074- self ._ball_carrier_idx = np .where (
1075- frame_data [Column .IS_BALL_CARRIER ] == True
1076- )[0 ][0 ]
1150+ self ._ball_carrier_idx = np .where (
1151+ frame_data [Column .IS_BALL_CARRIER ] == True
1152+ )[0 ][0 ]
1153+
10771154 self ._ball_owning_team_id = list (frame_data [Column .BALL_OWNING_TEAM_ID ])[0 ]
10781155 self ._gameclock = timestamp_to_gameclock (
10791156 timestamp = list (frame_data ["timestamp" ])[0 ],
10801157 period_id = list (frame_data ["period_id" ])[0 ],
10811158 )
10821159
1083- plot_vertical_pitch (frame_data )
1084- plot_graph ()
1160+ # Plot based on plot_type
1161+ if self ._plot_type == "pitch_only" :
1162+ plot_vertical_pitch (frame_data )
1163+ elif self ._plot_type == "graph_only" :
1164+ plot_graph ()
1165+ else : # "full"
1166+ plot_vertical_pitch (frame_data )
1167+ plot_graph ()
10851168
10861169 plt .tight_layout ()
10871170
1088- self ._fig = plt .figure (figsize = (25 , 18 ))
1171+ # Adjust figure size based on plot_type
1172+ if self ._plot_type == "pitch_only" :
1173+ self ._fig = plt .figure (figsize = (8 , 12 ))
1174+ elif self ._plot_type == "graph_only" :
1175+ self ._fig = plt .figure (figsize = (14 , 10 ))
1176+ else : # "full"
1177+ self ._fig = plt .figure (figsize = (25 , 18 ))
1178+
10891179 self ._fig .subplots_adjust (left = 0.06 , right = 1.0 , bottom = 0.05 )
10901180
10911181 if sort :
0 commit comments