Skip to content

Commit fc993bb

Browse files
author
UnravelSports [JB]
committed
plot options
1 parent 8274ae4 commit fc993bb

File tree

1 file changed

+152
-62
lines changed

1 file changed

+152
-62
lines changed

unravel/soccer/graphs/graph_converter.py

Lines changed: 152 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,10 @@ def plot(
673673
ball_color: str = "black",
674674
sort: bool = True,
675675
color_by: Literal["ball_owning", "static_home_away"] = "ball_owning",
676+
anonymous: bool = False,
677+
plot_type: Literal["pitch_only", "graph_only", "full"] = "full",
678+
show_label: bool = True,
679+
show_timestamp: bool = True,
676680
):
677681
"""
678682
Plot tracking data as a static image or video file.
@@ -703,6 +707,17 @@ def plot(
703707
Method for coloring the teams:
704708
- "ball_owning": Colors teams based on ball possession
705709
- "static_home_away": Uses static colors for home and away teams
710+
anonymous : bool, default False
711+
Whether to anonymize player labels
712+
plot_type : Literal["pitch_only", "graph_only", "full"], default "full"
713+
Type of plot to generate:
714+
- "pitch_only": Shows only the soccer pitch visualization
715+
- "graph_only": Shows only the graph features (node features, adjacency matrix, edge features)
716+
- "full": Shows both pitch and graph visualizations
717+
show_pitch_label : bool, default True
718+
Whether to show the label on the pitch visualization
719+
show_pitch_timestamp : bool, default True
720+
Whether to show the timestamp on the pitch visualization
706721
707722
Returns
708723
-------
@@ -771,6 +786,11 @@ def plot(
771786
self._team_color_b = team_color_b
772787
self._ball_color = ball_color
773788
self._color_by = color_by
789+
self._plot_type = plot_type
790+
self._show_pitch_label = show_label
791+
self._show_pitch_timestamp = show_timestamp
792+
793+
self._ball_carrier_color = "black"
774794

775795
if period_id is not None and not isinstance(period_id, int):
776796
raise TypeError("period_id should be of type integer")
@@ -810,25 +830,72 @@ def plot(
810830
if df.is_empty():
811831
raise ValueError("Selection is empty, please try different timestamp(s)")
812832

833+
def setup_gridspec():
834+
"""Setup GridSpec based on plot_type"""
835+
if self._plot_type == "pitch_only":
836+
return GridSpec(1, 1, left=0.05, right=0.95, bottom=0.05, top=0.95)
837+
elif self._plot_type == "graph_only":
838+
return GridSpec(
839+
2,
840+
2,
841+
width_ratios=[1.2, 0.8],
842+
height_ratios=[1, 1],
843+
wspace=0.2, # Increased spacing
844+
hspace=0.3, # Increased spacing
845+
left=0.08,
846+
right=0.92,
847+
bottom=0.1, # More bottom margin
848+
top=0.9,
849+
) # More top margin
850+
else: # "full"
851+
return GridSpec(
852+
2,
853+
3,
854+
width_ratios=[2, 1, 3],
855+
height_ratios=[1, 1],
856+
wspace=0.15, # Increased spacing
857+
hspace=0.1, # Increased spacing
858+
left=0.05,
859+
right=0.98, # Slightly reduced right margin
860+
bottom=0.08, # More bottom margin
861+
top=0.95,
862+
)
863+
813864
def plot_graph():
865+
"""Plot graph features (node features, adjacency matrix, edge features)"""
814866
import matplotlib.pyplot as plt
815867

816-
labels = [
817-
(
818-
self.get_player_by_id(pid)["jersey_no"]
819-
if pid != Constant.BALL
820-
else Constant.BALL
821-
)
822-
for pid in self._graph.object_ids
823-
]
868+
num_rows = self._graph.x.shape[0]
824869

825-
# Plot node features in top-left
826-
ax1 = self._fig.add_subplot(self._gs[0, 0])
870+
labels = (
871+
[
872+
(
873+
self.get_player_by_id(pid)["jersey_no"]
874+
if pid != Constant.BALL
875+
else Constant.BALL
876+
)
877+
for pid in self._graph.object_ids
878+
]
879+
if not anonymous
880+
else [str(i) for i in range(num_rows)]
881+
)
882+
883+
# Determine subplot positions based on plot_type
884+
if self._plot_type == "graph_only":
885+
node_pos = (0, 0)
886+
adj_pos = (1, 0)
887+
edge_pos = (slice(None), 1)
888+
else: # "full"
889+
node_pos = (0, 0)
890+
adj_pos = (1, 0)
891+
edge_pos = (slice(None), 1)
892+
893+
# Plot node features
894+
ax1 = self._fig.add_subplot(self._gs[node_pos])
827895
ax1.imshow(self._graph.x, aspect="auto", cmap="YlOrRd")
828896
ax1.set_xlabel(f"Node Features {self._graph.x.shape}")
829897

830898
# Set y labels to integers
831-
num_rows = self._graph.x.shape[0]
832899
ax1.set_yticks(range(num_rows))
833900
ax1.set_yticklabels(labels)
834901

@@ -837,8 +904,8 @@ def plot_graph():
837904
ax1.set_xticks(range(len(node_feature_yticklabels)))
838905
ax1.set_xticklabels(node_feature_yticklabels, rotation=45, ha="left")
839906

840-
# Plot ajacency matrix in bottom-left
841-
ax2 = self._fig.add_subplot(self._gs[1, 0])
907+
# Plot adjacency matrix
908+
ax2 = self._fig.add_subplot(self._gs[adj_pos])
842909
ax2.imshow(self._graph.a.toarray(), aspect="auto", cmap="YlOrRd")
843910
ax2.set_xlabel(f"Adjacency Matrix {self._graph.a.shape}")
844911

@@ -852,8 +919,8 @@ def plot_graph():
852919
ax2.set_xticks(range(num_cols_a))
853920
ax2.set_xticklabels(labels)
854921

855-
# Plot Edge Features on the right (spanning both rows)
856-
ax3 = self._fig.add_subplot(self._gs[:, 1])
922+
# Plot Edge Features
923+
ax3 = self._fig.add_subplot(self._gs[edge_pos])
857924

858925
_, size_a = non_zeros(self._graph.a.toarray()[0 : self._ball_carrier_idx])
859926
ball_carrier_edge_idx, num_rows_e = non_zeros(
@@ -893,6 +960,7 @@ def plot_graph():
893960
plt.colorbar(im3, ax=ax3, fraction=0.1, pad=0.2)
894961

895962
def plot_vertical_pitch(frame_data: pl.DataFrame):
963+
"""Plot the soccer pitch visualization"""
896964
try:
897965
from mplsoccer import VerticalPitch
898966
except ImportError:
@@ -901,7 +969,13 @@ def plot_vertical_pitch(frame_data: pl.DataFrame):
901969
" install it using: pip install mplsoccer"
902970
)
903971

904-
ax4 = self._fig.add_subplot(self._gs[:, 2])
972+
# Determine subplot position based on plot_type
973+
if self._plot_type == "pitch_only":
974+
pitch_pos = (0, 0)
975+
else: # "full"
976+
pitch_pos = (slice(None), 2)
977+
978+
ax4 = self._fig.add_subplot(self._gs[pitch_pos])
905979
pitch = VerticalPitch(
906980
pitch_type="secondspectrum",
907981
pitch_length=self.pitch_dimensions.pitch_length,
@@ -960,8 +1034,6 @@ def player_and_ball(frame_data, ax):
9601034
else:
9611035
raise ValueError(f"Unsupported color_by {self._color_by}")
9621036

963-
self._ball_carrier_color = None
964-
9651037
for i, r in enumerate(frame_data.iter_rows(named=True)):
9661038
v, vy, vx, y, x = (
9671039
r[Column.SPEED],
@@ -997,14 +1069,19 @@ def player_and_ball(frame_data, ax):
9971069

9981070
else:
9991071
ax.scatter(x, y, color=self._ball_color, s=250, zorder=10)
1000-
# # Text with white border
1072+
1073+
# Text with white border
10011074
text = ax.text(
10021075
x + (-1.2 if is_ball else 0.0),
10031076
y + (-1.2 if is_ball else 0.0),
10041077
(
1005-
self.get_player_by_id(r[Column.OBJECT_ID])["jersey_no"]
1006-
if r[Column.OBJECT_ID] != Constant.BALL
1007-
else Constant.BALL
1078+
(
1079+
self.get_player_by_id(r[Column.OBJECT_ID])["jersey_no"]
1080+
if r[Column.OBJECT_ID] != Constant.BALL
1081+
else Constant.BALL
1082+
)
1083+
if not anonymous
1084+
else str(i)
10081085
),
10091086
color=self._ball_color if is_ball else color,
10101087
fontsize=12,
@@ -1021,7 +1098,11 @@ def player_and_ball(frame_data, ax):
10211098
path_effects.Normal(),
10221099
]
10231100
)
1101+
1102+
# Add label and timestamp to pitch if enabled
1103+
if self._show_pitch_label:
10241104
ax.set_xlabel(f"Label: {frame_data['label'][0]}", fontsize=22)
1105+
if self._show_pitch_timestamp:
10251106
ax.set_title(self._gameclock, fontsize=22)
10261107

10271108
def frame_plot(self, frame_data):
@@ -1034,58 +1115,67 @@ def timestamp_to_gameclock(timestamp, period_id):
10341115

10351116
return f"[{period_id}] - {minutes}:{seconds:02d}:{milliseconds:03d}"
10361117

1037-
self._gs = GridSpec(
1038-
2,
1039-
3,
1040-
width_ratios=[2, 1, 3],
1041-
height_ratios=[1, 1],
1042-
wspace=0.1,
1043-
hspace=0.06,
1044-
left=0.05,
1045-
right=1.0,
1046-
bottom=0.05,
1047-
)
1118+
# Setup GridSpec based on plot_type
1119+
self._gs = setup_gridspec()
10481120

1049-
# Process the current frame
1050-
features = self._compute([frame_data[col] for col in self._exprs_variables])
1051-
a = make_sparse(
1052-
reshape_from_size(
1053-
features["a"], features["a_shape_0"], features["a_shape_1"]
1121+
# Only process graph data if we need to show graphs
1122+
if self._plot_type in ["graph_only", "full"]:
1123+
# Process the current frame
1124+
features = self._compute(
1125+
[frame_data[col] for col in self._exprs_variables]
1126+
)
1127+
a = make_sparse(
1128+
reshape_from_size(
1129+
features["a"], features["a_shape_0"], features["a_shape_1"]
1130+
)
1131+
)
1132+
x = reshape_from_size(
1133+
features["x"], features["x_shape_0"], features["x_shape_1"]
1134+
)
1135+
e = reshape_from_size(
1136+
features["e"], features["e_shape_0"], features["e_shape_1"]
1137+
)
1138+
y = np.asarray([features[self.label_column]])
1139+
1140+
self._graph = Graph(
1141+
a=a,
1142+
x=x,
1143+
e=e,
1144+
y=y,
1145+
frame_id=features["frame_id"],
1146+
object_ids=frame_data[Column.OBJECT_ID],
1147+
ball_owning_team_id=frame_data[Column.BALL_OWNING_TEAM_ID][0],
10541148
)
1055-
)
1056-
x = reshape_from_size(
1057-
features["x"], features["x_shape_0"], features["x_shape_1"]
1058-
)
1059-
e = reshape_from_size(
1060-
features["e"], features["e_shape_0"], features["e_shape_1"]
1061-
)
1062-
y = np.asarray([features[self.label_column]])
1063-
1064-
self._graph = Graph(
1065-
a=a,
1066-
x=x,
1067-
e=e,
1068-
y=y,
1069-
frame_id=features["frame_id"],
1070-
object_ids=frame_data[Column.OBJECT_ID],
1071-
ball_owning_team_id=frame_data[Column.BALL_OWNING_TEAM_ID][0],
1072-
)
10731149

1074-
self._ball_carrier_idx = np.where(
1075-
frame_data[Column.IS_BALL_CARRIER] == True
1076-
)[0][0]
1150+
self._ball_carrier_idx = np.where(
1151+
frame_data[Column.IS_BALL_CARRIER] == True
1152+
)[0][0]
1153+
10771154
self._ball_owning_team_id = list(frame_data[Column.BALL_OWNING_TEAM_ID])[0]
10781155
self._gameclock = timestamp_to_gameclock(
10791156
timestamp=list(frame_data["timestamp"])[0],
10801157
period_id=list(frame_data["period_id"])[0],
10811158
)
10821159

1083-
plot_vertical_pitch(frame_data)
1084-
plot_graph()
1160+
# Plot based on plot_type
1161+
if self._plot_type == "pitch_only":
1162+
plot_vertical_pitch(frame_data)
1163+
elif self._plot_type == "graph_only":
1164+
plot_graph()
1165+
else: # "full"
1166+
plot_vertical_pitch(frame_data)
1167+
plot_graph()
10851168

10861169
plt.tight_layout()
10871170

1088-
self._fig = plt.figure(figsize=(25, 18))
1171+
# Adjust figure size based on plot_type
1172+
if self._plot_type == "pitch_only":
1173+
self._fig = plt.figure(figsize=(8, 12))
1174+
elif self._plot_type == "graph_only":
1175+
self._fig = plt.figure(figsize=(14, 10))
1176+
else: # "full"
1177+
self._fig = plt.figure(figsize=(25, 18))
1178+
10891179
self._fig.subplots_adjust(left=0.06, right=1.0, bottom=0.05)
10901180

10911181
if sort:

0 commit comments

Comments
 (0)