Skip to content

Commit 230232f

Browse files
Merge pull request #30 from UnravelSports/fix/sample_rate
fix sample rate
2 parents de4511c + ade8052 commit 230232f

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

tests/test_bigdb.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -200,7 +200,7 @@ def test_settings(self, gnnc_non_default, non_default_arguments):
200200
assert 1 == 1
201201

202202
data = spektral_graphs
203-
assert len(data) == 132
203+
assert len(data) == 130
204204
assert isinstance(data[0], Graph)
205205

206206
assert settings.pitch_dimensions.pitch_length == 120.0

unravel/american_football/graphs/graph_converter.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(
6363
else dataset._graph_id_column
6464
)
6565

66-
self.sample = 1.0 / kwargs.get("sample_rate", 1.0)
66+
self.sample_rate = kwargs.get("sample_rate", None)
6767
self.chunk_size = chunk_size
6868
self.attacking_non_qb_node_value = attacking_non_qb_node_value
6969
self.graph_feature_cols = graph_feature_cols
@@ -73,6 +73,27 @@ def __init__(
7373

7474
self._sport_specific_checks()
7575

76+
self._sample()
77+
self._shuffle()
78+
79+
def _sample(self):
80+
if self.sample_rate is None:
81+
return
82+
else:
83+
self.dataset = self.dataset.filter(
84+
pl.col(Column.FRAME_ID) % (1.0 / self.sample_rate) == 0
85+
)
86+
87+
def _shuffle(self):
88+
if isinstance(self.settings.random_seed, int):
89+
self.dataset = self.dataset.sample(
90+
fraction=1.0, seed=self.settings.random_seed
91+
)
92+
elif self.settings.random_seed == True:
93+
self.dataset = self.dataset.sample(fraction=1.0)
94+
else:
95+
pass
96+
7697
def _sport_specific_checks(self):
7798
def __remove_with_missing_values(min_object_count: int = 10):
7899
cs = (
@@ -338,7 +359,6 @@ def process_chunk(chunk: pl.DataFrame) -> List[dict]:
338359
"id": chunk[self.graph_id_column][i],
339360
}
340361
for i in range(len(chunk))
341-
if i % self.sample == 0
342362
]
343363

344364
graph_df = self._convert()

unravel/soccer/graphs/graph_converter_pl.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ def __post_init__(self):
8080
if not isinstance(self.dataset, KloppyPolarsDataset):
8181
raise ValueError("dataset should be of type KloppyPolarsDataset...")
8282

83-
self.sample = 1.0 if self.sample_rate is None else 1.0 / self.sample_rate
84-
8583
self.pitch_dimensions: MetricPitchDimensions = (
8684
self.dataset.settings.pitch_dimensions
8785
)
@@ -114,8 +112,17 @@ def __post_init__(self):
114112
else:
115113
self.dataset = self._remove_incomplete_frames()
116114

115+
self._sample()
117116
self._shuffle()
118117

118+
def _sample(self):
119+
if self.sample_rate is None:
120+
return
121+
else:
122+
self.dataset = self.dataset.filter(
123+
pl.col(Column.FRAME_ID) % (1.0 / self.sample_rate) == 0
124+
)
125+
119126
def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]):
120127
for i, func in enumerate(funcs):
121128
# Check if it has the attributes added by the decorator
@@ -523,7 +530,6 @@ def process_chunk(chunk: pl.DataFrame) -> List[dict]:
523530
"id": chunk[self.graph_id_column][i],
524531
}
525532
for i in range(len(chunk))
526-
if i % self.sample == 0
527533
]
528534

529535
graph_df = self._convert()

0 commit comments

Comments
 (0)