Skip to content

Commit fe2870e

Browse files
Merge pull request #39 from UnravelSports/feat/unravel_graph
sort
2 parents 6244a56 + 5f3467a commit fe2870e

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

unravel/soccer/graphs/graph_converter_pl.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]):
153153
"Function has an incorrect feature type edge features should be 'edge', node features should be 'node'. "
154154
)
155155

156+
@staticmethod
157+
def _sort(df):
158+
sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - (
159+
(pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID))
160+
& (pl.col(Column.TEAM_ID) != Constant.BALL)
161+
).cast(int)
162+
163+
df = df.sort([*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)])
164+
df = df.sort(Group.BY_FRAME + [Column.OBJECT_ID])
165+
return df
166+
156167
def _shuffle(self):
157168
if isinstance(self.settings.random_seed, int):
158169
self.dataset = self.dataset.sample(
@@ -161,16 +172,7 @@ def _shuffle(self):
161172
elif self.settings.random_seed == True:
162173
self.dataset = self.dataset.sample(fraction=1.0)
163174
else:
164-
165-
sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - (
166-
(pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID))
167-
& (pl.col(Column.TEAM_ID) != Constant.BALL)
168-
).cast(int)
169-
170-
self.dataset = self.dataset.sort(
171-
[*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)]
172-
)
173-
self.dataset = self.dataset.sort(Group.BY_FRAME + [Column.OBJECT_ID])
175+
self.dataset = self._sort(self.dataset)
174176

175177
def _remove_incomplete_frames(self) -> pl.DataFrame:
176178
df = self.dataset
@@ -707,6 +709,7 @@ def plot(
707709
team_color_a: str = "#CD0E61",
708710
team_color_b: str = "#0066CC",
709711
ball_color: str = "black",
712+
sort: bool = True,
710713
color_by: Literal["ball_owning", "static_home_away"] = "ball_owning",
711714
):
712715
"""
@@ -1079,13 +1082,16 @@ def frame_plot(self, frame_data):
10791082
self._fig = plt.figure(figsize=(25, 18))
10801083
self._fig.subplots_adjust(left=0.06, right=1.0, bottom=0.05)
10811084

1085+
if sort:
1086+
df = self._sort(df)
1087+
10821088
if generate_video:
10831089
writer = animation.FFMpegWriter(fps=fps, bitrate=1800)
10841090

10851091
with writer.saving(self._fig, file_path, dpi=300):
1086-
for group_id, frame_data in df.sort(
1087-
Group.BY_FRAME + [Column.OBJECT_ID]
1088-
).group_by(Group.BY_FRAME, maintain_order=True):
1092+
for group_id, frame_data in df.group_by(
1093+
Group.BY_FRAME, maintain_order=True
1094+
):
10891095
self._fig.clear()
10901096
frame_plot(self, frame_data)
10911097
writer.grab_frame()

0 commit comments

Comments
 (0)