Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 152 additions & 62 deletions unravel/soccer/graphs/graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,10 @@ def plot(
ball_color: str = "black",
sort: bool = True,
color_by: Literal["ball_owning", "static_home_away"] = "ball_owning",
anonymous: bool = False,
plot_type: Literal["pitch_only", "graph_only", "full"] = "full",
show_label: bool = True,
show_timestamp: bool = True,
):
"""
Plot tracking data as a static image or video file.
Expand Down Expand Up @@ -703,6 +707,17 @@ def plot(
Method for coloring the teams:
- "ball_owning": Colors teams based on ball possession
- "static_home_away": Uses static colors for home and away teams
anonymous : bool, default False
Whether to anonymize player labels
plot_type : Literal["pitch_only", "graph_only", "full"], default "full"
Type of plot to generate:
- "pitch_only": Shows only the soccer pitch visualization
- "graph_only": Shows only the graph features (node features, adjacency matrix, edge features)
- "full": Shows both pitch and graph visualizations
show_pitch_label : bool, default True
Whether to show the label on the pitch visualization
show_pitch_timestamp : bool, default True
Whether to show the timestamp on the pitch visualization

Returns
-------
Expand Down Expand Up @@ -771,6 +786,11 @@ def plot(
self._team_color_b = team_color_b
self._ball_color = ball_color
self._color_by = color_by
self._plot_type = plot_type
self._show_pitch_label = show_label
self._show_pitch_timestamp = show_timestamp

self._ball_carrier_color = "black"

if period_id is not None and not isinstance(period_id, int):
raise TypeError("period_id should be of type integer")
Expand Down Expand Up @@ -810,25 +830,72 @@ def plot(
if df.is_empty():
raise ValueError("Selection is empty, please try different timestamp(s)")

def setup_gridspec():
"""Setup GridSpec based on plot_type"""
if self._plot_type == "pitch_only":
return GridSpec(1, 1, left=0.05, right=0.95, bottom=0.05, top=0.95)
elif self._plot_type == "graph_only":
return GridSpec(
2,
2,
width_ratios=[1.2, 0.8],
height_ratios=[1, 1],
wspace=0.2, # Increased spacing
hspace=0.3, # Increased spacing
left=0.08,
right=0.92,
bottom=0.1, # More bottom margin
top=0.9,
) # More top margin
else: # "full"
return GridSpec(
2,
3,
width_ratios=[2, 1, 3],
height_ratios=[1, 1],
wspace=0.15, # Increased spacing
hspace=0.1, # Increased spacing
left=0.05,
right=0.98, # Slightly reduced right margin
bottom=0.08, # More bottom margin
top=0.95,
)

def plot_graph():
"""Plot graph features (node features, adjacency matrix, edge features)"""
import matplotlib.pyplot as plt

labels = [
(
self.get_player_by_id(pid)["jersey_no"]
if pid != Constant.BALL
else Constant.BALL
)
for pid in self._graph.object_ids
]
num_rows = self._graph.x.shape[0]

# Plot node features in top-left
ax1 = self._fig.add_subplot(self._gs[0, 0])
labels = (
[
(
self.get_player_by_id(pid)["jersey_no"]
if pid != Constant.BALL
else Constant.BALL
)
for pid in self._graph.object_ids
]
if not anonymous
else [str(i) for i in range(num_rows)]
)

# Determine subplot positions based on plot_type
if self._plot_type == "graph_only":
node_pos = (0, 0)
adj_pos = (1, 0)
edge_pos = (slice(None), 1)
else: # "full"
node_pos = (0, 0)
adj_pos = (1, 0)
edge_pos = (slice(None), 1)

# Plot node features
ax1 = self._fig.add_subplot(self._gs[node_pos])
ax1.imshow(self._graph.x, aspect="auto", cmap="YlOrRd")
ax1.set_xlabel(f"Node Features {self._graph.x.shape}")

# Set y labels to integers
num_rows = self._graph.x.shape[0]
ax1.set_yticks(range(num_rows))
ax1.set_yticklabels(labels)

Expand All @@ -837,8 +904,8 @@ def plot_graph():
ax1.set_xticks(range(len(node_feature_yticklabels)))
ax1.set_xticklabels(node_feature_yticklabels, rotation=45, ha="left")

# Plot ajacency matrix in bottom-left
ax2 = self._fig.add_subplot(self._gs[1, 0])
# Plot adjacency matrix
ax2 = self._fig.add_subplot(self._gs[adj_pos])
ax2.imshow(self._graph.a.toarray(), aspect="auto", cmap="YlOrRd")
ax2.set_xlabel(f"Adjacency Matrix {self._graph.a.shape}")

Expand All @@ -852,8 +919,8 @@ def plot_graph():
ax2.set_xticks(range(num_cols_a))
ax2.set_xticklabels(labels)

# Plot Edge Features on the right (spanning both rows)
ax3 = self._fig.add_subplot(self._gs[:, 1])
# Plot Edge Features
ax3 = self._fig.add_subplot(self._gs[edge_pos])

_, size_a = non_zeros(self._graph.a.toarray()[0 : self._ball_carrier_idx])
ball_carrier_edge_idx, num_rows_e = non_zeros(
Expand Down Expand Up @@ -893,6 +960,7 @@ def plot_graph():
plt.colorbar(im3, ax=ax3, fraction=0.1, pad=0.2)

def plot_vertical_pitch(frame_data: pl.DataFrame):
"""Plot the soccer pitch visualization"""
try:
from mplsoccer import VerticalPitch
except ImportError:
Expand All @@ -901,7 +969,13 @@ def plot_vertical_pitch(frame_data: pl.DataFrame):
" install it using: pip install mplsoccer"
)

ax4 = self._fig.add_subplot(self._gs[:, 2])
# Determine subplot position based on plot_type
if self._plot_type == "pitch_only":
pitch_pos = (0, 0)
else: # "full"
pitch_pos = (slice(None), 2)

ax4 = self._fig.add_subplot(self._gs[pitch_pos])
pitch = VerticalPitch(
pitch_type="secondspectrum",
pitch_length=self.pitch_dimensions.pitch_length,
Expand Down Expand Up @@ -960,8 +1034,6 @@ def player_and_ball(frame_data, ax):
else:
raise ValueError(f"Unsupported color_by {self._color_by}")

self._ball_carrier_color = None

for i, r in enumerate(frame_data.iter_rows(named=True)):
v, vy, vx, y, x = (
r[Column.SPEED],
Expand Down Expand Up @@ -997,14 +1069,19 @@ def player_and_ball(frame_data, ax):

else:
ax.scatter(x, y, color=self._ball_color, s=250, zorder=10)
# # Text with white border

# Text with white border
text = ax.text(
x + (-1.2 if is_ball else 0.0),
y + (-1.2 if is_ball else 0.0),
(
self.get_player_by_id(r[Column.OBJECT_ID])["jersey_no"]
if r[Column.OBJECT_ID] != Constant.BALL
else Constant.BALL
(
self.get_player_by_id(r[Column.OBJECT_ID])["jersey_no"]
if r[Column.OBJECT_ID] != Constant.BALL
else Constant.BALL
)
if not anonymous
else str(i)
),
color=self._ball_color if is_ball else color,
fontsize=12,
Expand All @@ -1021,7 +1098,11 @@ def player_and_ball(frame_data, ax):
path_effects.Normal(),
]
)

# Add label and timestamp to pitch if enabled
if self._show_pitch_label:
ax.set_xlabel(f"Label: {frame_data['label'][0]}", fontsize=22)
if self._show_pitch_timestamp:
ax.set_title(self._gameclock, fontsize=22)

def frame_plot(self, frame_data):
Expand All @@ -1034,58 +1115,67 @@ def timestamp_to_gameclock(timestamp, period_id):

return f"[{period_id}] - {minutes}:{seconds:02d}:{milliseconds:03d}"

self._gs = GridSpec(
2,
3,
width_ratios=[2, 1, 3],
height_ratios=[1, 1],
wspace=0.1,
hspace=0.06,
left=0.05,
right=1.0,
bottom=0.05,
)
# Setup GridSpec based on plot_type
self._gs = setup_gridspec()

# Process the current frame
features = self._compute([frame_data[col] for col in self._exprs_variables])
a = make_sparse(
reshape_from_size(
features["a"], features["a_shape_0"], features["a_shape_1"]
# Only process graph data if we need to show graphs
if self._plot_type in ["graph_only", "full"]:
# Process the current frame
features = self._compute(
[frame_data[col] for col in self._exprs_variables]
)
a = make_sparse(
reshape_from_size(
features["a"], features["a_shape_0"], features["a_shape_1"]
)
)
x = reshape_from_size(
features["x"], features["x_shape_0"], features["x_shape_1"]
)
e = reshape_from_size(
features["e"], features["e_shape_0"], features["e_shape_1"]
)
y = np.asarray([features[self.label_column]])

self._graph = Graph(
a=a,
x=x,
e=e,
y=y,
frame_id=features["frame_id"],
object_ids=frame_data[Column.OBJECT_ID],
ball_owning_team_id=frame_data[Column.BALL_OWNING_TEAM_ID][0],
)
)
x = reshape_from_size(
features["x"], features["x_shape_0"], features["x_shape_1"]
)
e = reshape_from_size(
features["e"], features["e_shape_0"], features["e_shape_1"]
)
y = np.asarray([features[self.label_column]])

self._graph = Graph(
a=a,
x=x,
e=e,
y=y,
frame_id=features["frame_id"],
object_ids=frame_data[Column.OBJECT_ID],
ball_owning_team_id=frame_data[Column.BALL_OWNING_TEAM_ID][0],
)

self._ball_carrier_idx = np.where(
frame_data[Column.IS_BALL_CARRIER] == True
)[0][0]
self._ball_carrier_idx = np.where(
frame_data[Column.IS_BALL_CARRIER] == True
)[0][0]

self._ball_owning_team_id = list(frame_data[Column.BALL_OWNING_TEAM_ID])[0]
self._gameclock = timestamp_to_gameclock(
timestamp=list(frame_data["timestamp"])[0],
period_id=list(frame_data["period_id"])[0],
)

plot_vertical_pitch(frame_data)
plot_graph()
# Plot based on plot_type
if self._plot_type == "pitch_only":
plot_vertical_pitch(frame_data)
elif self._plot_type == "graph_only":
plot_graph()
else: # "full"
plot_vertical_pitch(frame_data)
plot_graph()

plt.tight_layout()

self._fig = plt.figure(figsize=(25, 18))
# Adjust figure size based on plot_type
if self._plot_type == "pitch_only":
self._fig = plt.figure(figsize=(8, 12))
elif self._plot_type == "graph_only":
self._fig = plt.figure(figsize=(14, 10))
else: # "full"
self._fig = plt.figure(figsize=(25, 18))

self._fig.subplots_adjust(left=0.06, right=1.0, bottom=0.05)

if sort:
Expand Down
Loading