|
10 | 10 | from pathlib import Path |
11 | 11 | from typing import Callable |
12 | 12 |
|
| 13 | +import toml |
13 | 14 | import dill |
14 | 15 | import numpy as np |
15 | 16 | import pytest |
|
61 | 62 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
62 | 63 | REGISTERED_OBJECTS = registry.get_all() |
63 | 64 |
|
| 65 | +# Define a detector config dict |
| 66 | +MMD_CFG = { |
| 67 | + 'name': 'MMDDrift', |
| 68 | + 'x_ref': np.array([[-0.30074928], [1.50240758], [0.43135768], [2.11295779], [0.79684913]]), |
| 69 | + 'p_val': 0.05, |
| 70 | + 'n_permutations': 150, |
| 71 | + 'data_type': 'tabular' |
| 72 | +} |
| 73 | +CFGS = [MMD_CFG] |
| 74 | + |
64 | 75 | # TODO - future: Some of the fixtures can/should be moved elsewhere (i.e. if they can be recycled for use elsewhere) |
65 | 76 |
|
66 | 77 |
|
@@ -259,6 +270,32 @@ def preprocess_hiddenoutput(classifier, backend): |
259 | 270 | return preprocess_fn |
260 | 271 |
|
261 | 272 |
|
| 273 | +@parametrize('cfg', CFGS) |
| 274 | +def test_load_simple_config(cfg, tmp_path): |
| 275 | + """ |
| 276 | + Test that a bare-bones `config.toml` without a [meta] field can be loaded by `load_detector`. |
| 277 | + """ |
| 278 | + save_dir = tmp_path |
| 279 | + x_ref_path = str(save_dir.joinpath('x_ref.npy')) |
| 280 | + cfg_path = save_dir.joinpath('config.toml') |
| 281 | + # Save x_ref in config.toml |
| 282 | + x_ref = cfg['x_ref'] |
| 283 | + np.save(x_ref_path, x_ref) |
| 284 | + cfg['x_ref'] = 'x_ref.npy' |
| 285 | + # Save config.toml then load it |
| 286 | + with open(cfg_path, 'w') as f: |
| 287 | + toml.dump(cfg, f) |
| 288 | + cd = load_detector(cfg_path) |
| 289 | + assert cd.__class__.__name__ == cfg['name'] |
| 290 | + # Get config and compare to original (orginal cfg not fully spec'd so only compare items that are present) |
| 291 | + cfg_new = cd.get_config() |
| 292 | + for k, v in cfg.items(): |
| 293 | + if k == 'x_ref': |
| 294 | + assert v == 'x_ref.npy' |
| 295 | + else: |
| 296 | + assert v == cfg_new[k] |
| 297 | + |
| 298 | + |
262 | 299 | @parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput]) |
263 | 300 | @parametrize_with_cases("data", cases=ContinuousData, prefix='data_') |
264 | 301 | def test_save_ksdrift(data, preprocess_fn, tmp_path): |
|
0 commit comments