|
36 | 36 | }, |
37 | 37 | { |
38 | 38 | "cell_type": "code", |
39 | | - "execution_count": null, |
| 39 | + "execution_count": 1, |
40 | 40 | "metadata": {}, |
41 | | - "outputs": [], |
| 41 | + "outputs": [ |
| 42 | + { |
| 43 | + "name": "stdout", |
| 44 | + "output_type": "stream", |
| 45 | + "text": [ |
| 46 | + "\n", |
| 47 | + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0.1\u001b[0m\n", |
| 48 | + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", |
| 49 | + "Note: you may need to restart the kernel to use updated packages.\n" |
| 50 | + ] |
| 51 | + } |
| 52 | + ], |
42 | 53 | "source": [ |
43 | 54 | "%pip install unravelsports --quiet" |
44 | 55 | ] |
|
51 | 62 | "\n", |
52 | 63 | "1. Load [Kloppy](https://github.com/PySport/kloppy) dataset. \n", |
53 | 64 | " See [in-depth Tutorial](1_kloppy_gnn_train.ipynb) on how do processes multiple match files, and to see an overview of all possible settings.\n", |
54 | | - "2. Convert to Graph format using `SoccerGraphConverter`\n", |
| 65 | + "2. Convert to Graph format using `SoccerGraphConverterPolars`\n", |
55 | 66 | "3. Create dataset for easy processing with [Spektral](https://graphneural.network/) using `CustomSpektralDataset`" |
56 | 67 | ] |
57 | 68 | }, |
58 | 69 | { |
59 | 70 | "cell_type": "code", |
60 | | - "execution_count": null, |
| 71 | + "execution_count": 2, |
61 | 72 | "metadata": {}, |
62 | 73 | "outputs": [], |
63 | 74 | "source": [ |
64 | | - "from unravel.soccer import SoccerGraphConverter\n", |
| 75 | + "from unravel.soccer import SoccerGraphConverterPolars, KloppyPolarsDataset\n", |
65 | 76 | "from unravel.utils import CustomSpektralDataset\n", |
66 | 77 | "\n", |
67 | | - "from kloppy import skillcorner\n", |
68 | | - "\n", |
69 | | - "from unravel.utils import dummy_labels\n", |
| 78 | + "from kloppy import sportec\n", |
70 | 79 | "\n", |
71 | 80 | "# Load Kloppy dataset\n", |
72 | | - "kloppy_dataset = skillcorner.load_open_data(\n", |
73 | | - " match_id=4039,\n", |
74 | | - " include_empty_frames=False,\n", |
75 | | - " limit=500, # limit to 500 frames in this example\n", |
| 81 | + "kloppy_dataset = sportec.load_open_tracking_data(only_alive=True, limit=500)\n", |
| 82 | + "kloppy_polars_dataset = KloppyPolarsDataset(\n", |
| 83 | + " kloppy_dataset=kloppy_dataset,\n", |
76 | 84 | ")\n", |
| 85 | + "kloppy_polars_dataset.add_dummy_labels()\n", |
| 86 | + "kloppy_polars_dataset.add_graph_ids(by=[\"frame_id\"])\n", |
77 | 87 | "\n", |
78 | | - "# Initialize the Graph Converter, with dataset and labels\n", |
| 88 | + "# Initialize the Graph Converter with dataset\n", |
79 | 89 | "# Here we use the default settings\n", |
80 | | - "converter = SoccerGraphConverter(\n", |
81 | | - " dataset=kloppy_dataset, labels=dummy_labels(kloppy_dataset)\n", |
82 | | - ")\n", |
| 90 | + "converter = SoccerGraphConverterPolars(dataset=kloppy_polars_dataset)\n", |
83 | 91 | "\n", |
84 | 92 | "# Compute the graphs and add them to the CustomSpektralDataset\n", |
85 | 93 | "dataset = CustomSpektralDataset(graphs=converter.to_spektral_graphs())" |
|
96 | 104 | }, |
97 | 105 | { |
98 | 106 | "cell_type": "code", |
99 | | - "execution_count": null, |
| 107 | + "execution_count": 3, |
100 | 108 | "metadata": {}, |
101 | 109 | "outputs": [], |
102 | 110 | "source": [ |
103 | 111 | "from spektral.data import DisjointLoader\n", |
104 | 112 | "\n", |
105 | 113 | "train, test, val = dataset.split_test_train_validation(\n", |
106 | | - " split_train=4, split_test=1, split_validation=1, random_seed=42\n", |
| 114 | + " split_train=4, split_test=1, split_validation=1, random_seed=43\n", |
107 | 115 | ")" |
108 | 116 | ] |
109 | 117 | }, |
|
121 | 129 | }, |
122 | 130 | { |
123 | 131 | "cell_type": "code", |
124 | | - "execution_count": null, |
| 132 | + "execution_count": 4, |
125 | 133 | "metadata": {}, |
126 | | - "outputs": [], |
| 134 | + "outputs": [ |
| 135 | + { |
| 136 | + "name": "stderr", |
| 137 | + "output_type": "stream", |
| 138 | + "text": [ |
| 139 | + "WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.\n" |
| 140 | + ] |
| 141 | + } |
| 142 | + ], |
127 | 143 | "source": [ |
128 | 144 | "from unravel.classifiers import CrystalGraphClassifier\n", |
129 | 145 | "\n", |
|
150 | 166 | }, |
151 | 167 | { |
152 | 168 | "cell_type": "code", |
153 | | - "execution_count": null, |
| 169 | + "execution_count": 5, |
154 | 170 | "metadata": {}, |
155 | | - "outputs": [], |
| 171 | + "outputs": [ |
| 172 | + { |
| 173 | + "name": "stdout", |
| 174 | + "output_type": "stream", |
| 175 | + "text": [ |
| 176 | + "Epoch 1/10\n" |
| 177 | + ] |
| 178 | + }, |
| 179 | + { |
| 180 | + "name": "stderr", |
| 181 | + "output_type": "stream", |
| 182 | + "text": [ |
| 183 | + "/Users/jbekkers/PycharmProjects/unravelsports/.venv311/lib/python3.11/site-packages/keras/src/initializers/initializers.py:120: UserWarning: The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.\n", |
| 184 | + " warnings.warn(\n" |
| 185 | + ] |
| 186 | + }, |
| 187 | + { |
| 188 | + "name": "stdout", |
| 189 | + "output_type": "stream", |
| 190 | + "text": [ |
| 191 | + "11/11 [==============================] - 1s 16ms/step - loss: 21.7806 - auc: 0.5278 - binary_accuracy: 0.5419 - val_loss: 5.1682 - val_auc: 0.5000 - val_binary_accuracy: 0.5000\n", |
| 192 | + "Epoch 2/10\n", |
| 193 | + " 1/11 [=>............................] - ETA: 0s - loss: 9.2846 - auc: 0.3651 - binary_accuracy: 0.5000WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 3 batches). You may need to use the repeat() function when building your dataset.\n" |
| 194 | + ] |
| 195 | + }, |
| 196 | + { |
| 197 | + "name": "stderr", |
| 198 | + "output_type": "stream", |
| 199 | + "text": [ |
| 200 | + "WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 3 batches). You may need to use the repeat() function when building your dataset.\n" |
| 201 | + ] |
| 202 | + }, |
| 203 | + { |
| 204 | + "name": "stdout", |
| 205 | + "output_type": "stream", |
| 206 | + "text": [ |
| 207 | + "11/11 [==============================] - 0s 6ms/step - loss: 4.5155 - auc: 0.5366 - binary_accuracy: 0.5449\n", |
| 208 | + "Epoch 3/10\n", |
| 209 | + "11/11 [==============================] - 0s 4ms/step - loss: 2.0773 - auc: 0.4515 - binary_accuracy: 0.4731\n", |
| 210 | + "Epoch 4/10\n", |
| 211 | + "11/11 [==============================] - 0s 5ms/step - loss: 1.1006 - auc: 0.5205 - binary_accuracy: 0.5150\n", |
| 212 | + "Epoch 5/10\n", |
| 213 | + "11/11 [==============================] - 0s 4ms/step - loss: 0.9159 - auc: 0.4915 - binary_accuracy: 0.5180\n", |
| 214 | + "Epoch 6/10\n", |
| 215 | + "11/11 [==============================] - 0s 5ms/step - loss: 0.8020 - auc: 0.4873 - binary_accuracy: 0.5060\n", |
| 216 | + "Epoch 7/10\n", |
| 217 | + "11/11 [==============================] - 0s 4ms/step - loss: 0.8067 - auc: 0.4960 - binary_accuracy: 0.5299\n", |
| 218 | + "Epoch 8/10\n", |
| 219 | + "11/11 [==============================] - 0s 6ms/step - loss: 0.7808 - auc: 0.5055 - binary_accuracy: 0.5299\n", |
| 220 | + "Epoch 9/10\n", |
| 221 | + "11/11 [==============================] - 0s 4ms/step - loss: 0.7661 - auc: 0.4937 - binary_accuracy: 0.5060\n", |
| 222 | + "Epoch 10/10\n", |
| 223 | + "11/11 [==============================] - 0s 5ms/step - loss: 0.7406 - auc: 0.5098 - binary_accuracy: 0.5329\n" |
| 224 | + ] |
| 225 | + }, |
| 226 | + { |
| 227 | + "data": { |
| 228 | + "text/plain": [ |
| 229 | + "<keras.src.callbacks.History at 0x39fe49d10>" |
| 230 | + ] |
| 231 | + }, |
| 232 | + "execution_count": 5, |
| 233 | + "metadata": {}, |
| 234 | + "output_type": "execute_result" |
| 235 | + } |
| 236 | + ], |
156 | 237 | "source": [ |
157 | 238 | "from tensorflow.keras.callbacks import EarlyStopping\n", |
158 | 239 | "\n", |
|
186 | 267 | }, |
187 | 268 | { |
188 | 269 | "cell_type": "code", |
189 | | - "execution_count": null, |
| 270 | + "execution_count": 6, |
190 | 271 | "metadata": {}, |
191 | | - "outputs": [], |
| 272 | + "outputs": [ |
| 273 | + { |
| 274 | + "name": "stdout", |
| 275 | + "output_type": "stream", |
| 276 | + "text": [ |
| 277 | + "3/3 [==============================] - 0s 6ms/step - loss: 0.7001 - auc: 0.5000 - binary_accuracy: 0.4819\n" |
| 278 | + ] |
| 279 | + } |
| 280 | + ], |
192 | 281 | "source": [ |
193 | 282 | "loader_te = DisjointLoader(test, epochs=1, shuffle=False, batch_size=batch_size)\n", |
194 | 283 | "results = model.evaluate(loader_te.load())" |
|
207 | 296 | }, |
208 | 297 | { |
209 | 298 | "cell_type": "code", |
210 | | - "execution_count": null, |
| 299 | + "execution_count": 7, |
211 | 300 | "metadata": {}, |
212 | | - "outputs": [], |
| 301 | + "outputs": [ |
| 302 | + { |
| 303 | + "name": "stdout", |
| 304 | + "output_type": "stream", |
| 305 | + "text": [ |
| 306 | + "3/3 [==============================] - 0s 5ms/step\n" |
| 307 | + ] |
| 308 | + } |
| 309 | + ], |
213 | 310 | "source": [ |
214 | 311 | "loader_te = DisjointLoader(test, batch_size=batch_size, epochs=1, shuffle=False)\n", |
215 | 312 | "loaded_pred = model.predict(loader_te.load(), use_multiprocessing=True)" |
|
218 | 315 | ], |
219 | 316 | "metadata": { |
220 | 317 | "kernelspec": { |
221 | | - "display_name": "venv", |
| 318 | + "display_name": ".venv311", |
222 | 319 | "language": "python", |
223 | 320 | "name": "python3" |
224 | 321 | }, |
|
232 | 329 | "name": "python", |
233 | 330 | "nbconvert_exporter": "python", |
234 | 331 | "pygments_lexer": "ipython3", |
235 | | - "version": "3.12.2" |
| 332 | + "version": "3.11.11" |
236 | 333 | } |
237 | 334 | }, |
238 | 335 | "nbformat": 4, |
|
0 commit comments