Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
bb7d854
added BasketballDataset class
not-heavychevy Apr 10, 2025
2abeeff
added BasketballPitchDimensions class
not-heavychevy Apr 10, 2025
bd59522
added graph settings
not-heavychevy Apr 10, 2025
8a83938
added optimized graph converter
not-heavychevy Apr 10, 2025
f5071c6
added ball handling
not-heavychevy Apr 10, 2025
26d6d85
added init files
not-heavychevy Apr 10, 2025
f2d164b
bugfix dataset load() bug
not-heavychevy Apr 10, 2025
d86c0af
added tests
not-heavychevy Apr 10, 2025
d1c0c73
added additional fields computation
not-heavychevy Apr 10, 2025
64f5ee3
BasketballDataset inherits from DefaultDataset
not-heavychevy Apr 12, 2025
835cd59
bugfix
not-heavychevy Apr 12, 2025
98f09ae
files read with kloppy.io
not-heavychevy Apr 19, 2025
0502aa7
added norm parameters
not-heavychevy Apr 19, 2025
d2f6b52
refactor: move get_dataframe to DefaultDataset
not-heavychevy Apr 20, 2025
53ea444
created post_init
not-heavychevy Apr 20, 2025
3482bf9
added self.settings to BasketballDataset
not-heavychevy Apr 20, 2025
51a6657
added add_dummy_labels и add_graph_ids
not-heavychevy Apr 20, 2025
1352f80
rewritten tests for dataset.py
not-heavychevy Apr 21, 2025
b0fc5c1
Refactor BasketballPitchDimensions
not-heavychevy Apr 25, 2025
1e04bfd
added tests for BasketballPitchDimensions
not-heavychevy Apr 25, 2025
627fae8
Refactor BasketballGraphSettings
not-heavychevy Apr 25, 2025
1bdd740
added tests for BasketballGraphSettings
not-heavychevy Apr 25, 2025
7c64156
Merge PitchDimensions and GraphSettings
not-heavychevy Apr 25, 2025
a70739c
graph_settings test update
not-heavychevy Apr 25, 2025
ebe0914
import bugs fix
not-heavychevy Apr 25, 2025
2dcd3fb
graph_converter refactoring
not-heavychevy Apr 26, 2025
4b96024
dataset separator bugfix
not-heavychevy Apr 26, 2025
af3a02a
added tests for graph_converter
not-heavychevy Apr 26, 2025
8a47337
moved the functionality to “features”
not-heavychevy Apr 26, 2025
633afca
tests update
not-heavychevy Apr 26, 2025
7463b1e
tests fix
not-heavychevy Apr 26, 2025
dcfa8e4
Deprecate speed/acceleration thresholds
not-heavychevy Apr 26, 2025
1b5bc3b
unify data/settings access on DefaultDataset
not-heavychevy Apr 26, 2025
7eb2081
Refactor _convert to use polars methods
not-heavychevy Apr 26, 2025
b0b9d72
Add unified graph-export API to GraphConverter
not-heavychevy Apr 26, 2025
e55d30e
added new tests for public export API
not-heavychevy Apr 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
513 changes: 513 additions & 0 deletions tests/test_basketball.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions unravel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

from .soccer import *
from .american_football import *
from .basketball import *
from .utils import *
from .classifiers import *
6 changes: 6 additions & 0 deletions unravel/basketball/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .dataset import BasketballDataset
from .graphs import (
BasketballGraphConverter,
BasketballGraphSettings,
BasketballPitchDimensions,
)
1 change: 1 addition & 0 deletions unravel/basketball/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .dataset import BasketballDataset
224 changes: 224 additions & 0 deletions unravel/basketball/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from dataclasses import dataclass, field
from typing import Optional, List
import os
import json
import tempfile
import polars as pl

from kloppy.io import open_as_file

try:
import py7zr
except ImportError:
py7zr = None

# Import default base classes
from unravel.utils import DefaultDataset, DefaultSettings
# Import pitch dimensions from merged graph_settings module
from unravel.basketball.graphs.graph_settings import BasketballPitchDimensions

@dataclass(kw_only=True)
class BasketballDataset(DefaultDataset):
"""
Loads NBA tracking data.

Modes:
- URL: Loads from a 7zip archive (expects a JSON file inside).
- Local: Loads from a file path or game identifier.

Additional parameters:
- max_player_speed, max_ball_speed, max_player_acceleration, max_ball_acceleration:
Thresholds for normalizing speed/acceleration.
- orient_ball_owning:
If True, computes oriented direction for ball ownership.
- sample_rate:
Fraction of rows to sample (0.0–1.0).
"""
tracking_data: str
max_player_speed: float = 20.0
max_ball_speed: float = 30.0
max_player_acceleration: float = 10.0
max_ball_acceleration: float = 10.0
orient_ball_owning: bool = False
sample_rate: float = 1.0
data: Optional[pl.DataFrame] = field(default=None, init=False)
settings: DefaultSettings = field(init=False)

def __post_init__(self):
# Initialize default settings
self.settings = DefaultSettings(
pitch_dimensions=BasketballPitchDimensions(),
home_team_id="home",
away_team_id="away",
provider="nba",
orientation="attacking_home"
)
# Automatically load data
self.load()

def load(self) -> pl.DataFrame:
# Determine source: URL or local file
if self.tracking_data.startswith("http"):
with open_as_file(self.tracking_data) as tmp_file:
tmp_filename = tmp_file.name
if py7zr is None:
raise ImportError("py7zr is required to extract 7zip archives.")
json_file = None
with py7zr.SevenZipFile(tmp_filename, mode='r') as archive:
for fname in archive.getnames():
if fname.endswith('.json'):
extract_path = tempfile.mkdtemp()
archive.extract(targets=[fname], path=extract_path)
json_file = os.path.join(extract_path, fname)
break
os.unlink(tmp_filename)
if json_file is None:
raise FileNotFoundError("JSON file not found in archive.")
with open(json_file, 'r', encoding='utf-8') as jf:
json_data = json.load(jf)
else:
# Local file or game ID mapping
if os.path.isfile(self.tracking_data):
file_path = self.tracking_data
else:
file_path = os.path.join(
"data", "nba", f"{self.tracking_data}.json"
)
if not os.path.isfile(file_path):
raise FileNotFoundError(
f"Game file '{self.tracking_data}.json' not found at {file_path}"
)
with open_as_file(file_path) as f:
json_data = json.load(f)

# Parse JSON into rows
rows = []
if isinstance(json_data, dict):
game_id = json_data.get("gameid", "unknown")
for event_id, event in enumerate(json_data.get("events", [])):
for m_idx, moment in enumerate(event.get("moments", [])):
if len(moment) >= 6:
quarter, _, game_clock, shot_clock, *_ , entities = moment
for entity in entities:
if len(entity) >= 4:
rows.append({
"game_id": game_id,
"event_id": event_id,
"frame_id": m_idx,
"quarter": quarter,
"game_clock": float(game_clock) if game_clock is not None else None,
"shot_clock": float(shot_clock) if shot_clock is not None else None,
"team": entity[0],
"player": entity[1],
"x": float(entity[2]),
"y": float(entity[3])
})
elif isinstance(json_data, list):
for rec in json_data:
rows.append({
"game_id": rec.get("game_id", "unknown"),
"frame_id": rec.get("frame_id"),
"team": rec.get("team"),
"player": rec.get("player"),
"x": float(rec.get("x", 0)),
"y": float(rec.get("y", 0))
})
else:
raise ValueError("Unexpected JSON structure")

# Build DataFrame and apply sampling
self.data = pl.DataFrame(rows, strict=False)
if self.sample_rate < 1.0:
self.data = self.data.sample(
fraction=self.sample_rate, with_replacement=False
)

# Add computed fields and return
self.data = self.compute_additional_fields()
return self.data

def compute_additional_fields(self) -> pl.DataFrame:
if self.data is None:
raise ValueError("Data not loaded. Call load() first.")

df = self.data.sort(["game_id", "player", "frame_id"])

# Time delta for velocity
if "game_clock" in df.columns:
df = df.with_columns(
(pl.col("game_clock").shift(-1) - pl.col("game_clock"))
.abs().fill_null(1).alias("dt")
)
else:
df = df.with_columns(pl.lit(1).alias("dt"))

# Displacements
df = df.with_columns([
(pl.col("x") - pl.col("x").shift(1)).alias("dx"),
(pl.col("y") - pl.col("y").shift(1)).alias("dy")
])

# Velocity components
df = df.with_columns([
(pl.col("dx") / pl.col("dt")).alias("vx"),
(pl.col("dy") / pl.col("dt")).alias("vy")
])

# Speed
df = df.with_columns([
((pl.col("vx")**2 + pl.col("vy")**2)**0.5).alias("speed")
])

# Time delta for acceleration
if "game_clock" in df.columns:
df = df.with_columns(
(pl.col("game_clock") - pl.col("game_clock").shift(1))
.abs().fill_null(1).alias("dt_acc")
)
else:
df = df.with_columns(pl.lit(1).alias("dt_acc"))

# Acceleration
df = df.with_columns(
((pl.col("speed") - pl.col("speed").shift(1)) / pl.col("dt_acc"))
.alias("acceleration")
)

# Normalize speed and acceleration
df = df.with_columns([
pl.when(pl.col("player").str.contains("ball", literal=False))
.then(pl.col("speed") / self.max_ball_speed)
.otherwise(pl.col("speed") / self.max_player_speed)
.clip(0, 1)
.alias("normalized_speed"),
pl.when(pl.col("player").str.contains("ball", literal=False))
.then(pl.col("acceleration") / self.max_ball_acceleration)
.otherwise(pl.col("acceleration") / self.max_player_acceleration)
.alias("normalized_acceleration")
])

# Oriented direction if requested
if self.orient_ball_owning:
df = df.with_columns(pl.lit(None).alias("oriented_direction"))

return df

def add_graph_ids(self, by: List[str], column_name: str = "graph_id") -> None:
"""
Add a graph identifier column by concatenating specified fields.
"""
if self.data is None:
raise ValueError("Data not loaded. Call load() first.")
self.data = self.data.with_columns(
pl.concat_str([pl.col(c) for c in by], separator="-").alias(column_name)
)

def add_dummy_labels(self, by: List[str], column_name: str = "label") -> None:
"""
Add a dummy label column (zeros) for each row, compatible with graph splitting.
"""
if self.data is None:
raise ValueError("Data not loaded. Call load() first.")
self.data = self.data.with_columns(
pl.lit(0).alias(column_name)
)
2 changes: 2 additions & 0 deletions unravel/basketball/graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .graph_converter import BasketballGraphConverter
from .graph_settings import BasketballGraphSettings, BasketballPitchDimensions
9 changes: 9 additions & 0 deletions unravel/basketball/graphs/features/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .node_features import compute_node_features
from .adjacency_matrix import compute_adjacency_matrix
from .edge_features import compute_edge_features

__all__ = [
"compute_node_features",
"compute_adjacency_matrix",
"compute_edge_features",
]
30 changes: 30 additions & 0 deletions unravel/basketball/graphs/features/adjacency_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
from typing import List, Any


def compute_adjacency_matrix(
teams: List[Any],
self_loop: bool = True
) -> np.ndarray:
"""
Compute the adjacency matrix based on team membership.

Args:
teams: List of team identifiers of length n_nodes.
self_loop: If True, diagonal entries remain 1; if False, zero out self connections.

Returns:
A: NumPy array of shape (n_nodes, n_nodes) where A[i, j] = 1.0 if nodes
i and j belong to the same team, else 0.0. Diagonal set according to self_loop.
"""
# Convert team list into a NumPy array for vectorized comparisons
arr = np.array(teams)

# Create an n x n boolean matrix of team equality
A = (arr[:, None] == arr[None, :]).astype(float)

# Optionally remove self-connections by zeroing the diagonal
if not self_loop:
np.fill_diagonal(A, 0.0)

return A
24 changes: 24 additions & 0 deletions unravel/basketball/graphs/features/edge_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np
from typing import Any


def compute_edge_features(
x: np.ndarray
) -> np.ndarray:
"""
Compute pairwise edge features between nodes.

Args:
x: NumPy array of shape (n_nodes, n_node_features), typically the node feature matrix.

Returns:
e: NumPy array of shape (n_nodes, n_nodes) where e[i, j] is the Euclidean
distance between feature vectors of node i and node j.
"""
# Calculate difference between each pair of node feature vectors
# x[:, None, :] has shape (n, 1, f), x[None, :, :] has shape (1, n, f)
diff = x[:, None, :] - x[None, :, :]

# Compute Euclidean norm along the feature axis, resulting in (n, n)
e = np.linalg.norm(diff, axis=2)
return e
67 changes: 67 additions & 0 deletions unravel/basketball/graphs/features/node_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
from typing import List, Tuple, Any
from unravel.basketball.graphs.graph_settings import BasketballPitchDimensions


def compute_node_features(
records: List[dict],
normalize_coordinates: bool,
pitch_dimensions: BasketballPitchDimensions,
node_feature_cols: List[str] = None,
) -> Tuple[np.ndarray, List[Any]]:
"""
Build the node feature matrix and extract team labels.

Args:
records: List of dicts, each representing one entity in the frame, e.g.:
{
"x": float,
"y": float,
"vx": float,
"vy": float,
"speed": float,
"acceleration": float,
"team": Any,
...
}
normalize_coordinates: If True, scale x by court_length and y by court_width.
pitch_dimensions: BasketballPitchDimensions instance containing court dimensions.
node_feature_cols: List of keys from each record to include as features, in order.
Defaults to ["x", "y", "vx", "vy", "speed", "acceleration"].

Returns:
x_array: NumPy array of shape (n_nodes, n_node_features) with node features.
teams: List of length n_nodes containing the team label for each node.
"""
# Use default feature list if none provided
if node_feature_cols is None:
node_feature_cols = ["x", "y", "vx", "vy", "speed", "acceleration"]

x_list: List[List[float]] = []
teams: List[Any] = []

for rec in records:
features: List[float] = []
for col in node_feature_cols:
# Retrieve raw value (might be None)
val = rec.get(col, 0.0)
# Coerce None → 0.0 to avoid float(None)
if val is None:
val = 0.0

# If normalizing and the feature is a coordinate, scale it
if normalize_coordinates and col in ("x", "y"):
if col == "x":
val = val / pitch_dimensions.court_length
else: # col == "y"
val = val / pitch_dimensions.court_width

features.append(float(val))

x_list.append(features)
# Collect the team label for adjacency construction
teams.append(rec.get("team"))

# Convert list of feature lists to a 2D NumPy array
x_array = np.asarray(x_list, dtype=float)
return x_array, teams
Loading