Skip to content

Commit cf03eb7

Browse files
author
Ashley Scillitoe
authored
Fix to allow config.toml to be loaded with [meta] not present (#591)
1 parent af57b12 commit cf03eb7

File tree

5 files changed

+60
-13
lines changed

5 files changed

+60
-13
lines changed

alibi_detect/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,10 @@ def from_config(cls, config: dict):
144144
config
145145
A config dictionary matching the schema's in :class:`~alibi_detect.saving.schemas`.
146146
"""
147-
# Check for exisiting version_warning. meta is pop'd as don't want to pass as arg/kwarg
148-
version_warning = config.pop('meta', {}).pop('version_warning', False)
147+
# Check for existing version_warning. meta is pop'd as don't want to pass as arg/kwarg
148+
meta = config.pop('meta', None)
149+
meta = {} if meta is None else meta # Needed because pydantic sets meta=None if it is missing from the config
150+
version_warning = meta.pop('version_warning', False)
149151
# Init detector
150152
detector = cls(**config)
151153
# Add version_warning

alibi_detect/saving/schemas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class DetectorConfig(CustomBaseModel):
100100
"Name of the detector e.g. `MMDDrift`."
101101
backend: Literal['tensorflow', 'pytorch', 'sklearn'] = 'tensorflow'
102102
"The detector backend."
103-
meta: Optional[MetaData]
103+
meta: Optional[MetaData] = None
104104
"Config metadata. Should not be edited."
105105
# Note: Although not all detectors have a backend, we define in base class as `backend` also determines
106106
# whether tf or torch models used for preprocess_fn.

alibi_detect/saving/tests/test_saving.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pathlib import Path
1111
from typing import Callable
1212

13+
import toml
1314
import dill
1415
import numpy as np
1516
import pytest
@@ -61,6 +62,16 @@
6162
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
6263
REGISTERED_OBJECTS = registry.get_all()
6364

65+
# Define a detector config dict
66+
MMD_CFG = {
67+
'name': 'MMDDrift',
68+
'x_ref': np.array([[-0.30074928], [1.50240758], [0.43135768], [2.11295779], [0.79684913]]),
69+
'p_val': 0.05,
70+
'n_permutations': 150,
71+
'data_type': 'tabular'
72+
}
73+
CFGS = [MMD_CFG]
74+
6475
# TODO - future: Some of the fixtures can/should be moved elsewhere (i.e. if they can be recycled for use elsewhere)
6576

6677

@@ -259,6 +270,32 @@ def preprocess_hiddenoutput(classifier, backend):
259270
return preprocess_fn
260271

261272

273+
@parametrize('cfg', CFGS)
274+
def test_load_simple_config(cfg, tmp_path):
275+
"""
276+
Test that a bare-bones `config.toml` without a [meta] field can be loaded by `load_detector`.
277+
"""
278+
save_dir = tmp_path
279+
x_ref_path = str(save_dir.joinpath('x_ref.npy'))
280+
cfg_path = save_dir.joinpath('config.toml')
281+
# Save x_ref in config.toml
282+
x_ref = cfg['x_ref']
283+
np.save(x_ref_path, x_ref)
284+
cfg['x_ref'] = 'x_ref.npy'
285+
# Save config.toml then load it
286+
with open(cfg_path, 'w') as f:
287+
toml.dump(cfg, f)
288+
cd = load_detector(cfg_path)
289+
assert cd.__class__.__name__ == cfg['name']
290+
# Get config and compare to original (orginal cfg not fully spec'd so only compare items that are present)
291+
cfg_new = cd.get_config()
292+
for k, v in cfg.items():
293+
if k == 'x_ref':
294+
assert v == 'x_ref.npy'
295+
else:
296+
assert v == cfg_new[k]
297+
298+
262299
@parametrize('preprocess_fn', [preprocess_custom, preprocess_hiddenoutput])
263300
@parametrize_with_cases("data", cases=ContinuousData, prefix='data_')
264301
def test_save_ksdrift(data, preprocess_fn, tmp_path):

alibi_detect/saving/tests/test_validate.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from alibi_detect.saving import validate_config
66
from alibi_detect.saving.saving import X_REF_FILENAME
77
from alibi_detect.version import __config_spec__, __version__
8+
from copy import deepcopy
89

910
# Define a detector config dict
1011
mmd_cfg = {
@@ -16,19 +17,14 @@
1617
'x_ref': np.array([[-0.30074928], [1.50240758], [0.43135768], [2.11295779], [0.79684913]]),
1718
'p_val': 0.05
1819
}
19-
cfgs = [mmd_cfg]
20-
n_tests = len(cfgs)
2120

21+
# Define a detector config dict without meta (as simple as it gets!)
22+
mmd_cfg_nometa = deepcopy(mmd_cfg)
23+
mmd_cfg_nometa.pop('meta')
2224

23-
@pytest.fixture
24-
def select_cfg(request):
25-
return cfgs[request.param]
26-
27-
28-
@pytest.mark.parametrize('select_cfg', list(range(n_tests)), indirect=True)
29-
def test_validate_config(select_cfg):
30-
cfg = select_cfg
3125

26+
@pytest.mark.parametrize('cfg', [mmd_cfg])
27+
def test_validate_config(cfg):
3228
# Original cfg
3329
# Check original cfg doesn't raise errors
3430
cfg_full = validate_config(cfg, resolved=True)
@@ -81,3 +77,14 @@ def test_validate_config(select_cfg):
8177
with pytest.raises(ValidationError):
8278
cfg_err = validate_config(cfg_err, resolved=True)
8379
assert not cfg_err.get('meta').get('version_warning')
80+
81+
82+
@pytest.mark.parametrize('cfg', [mmd_cfg_nometa])
83+
def test_validate_config_wo_meta(cfg):
84+
# Check a config w/o a meta dict can be validated
85+
_ = validate_config(cfg, resolved=True)
86+
87+
# Check the unresolved case
88+
cfg_unres = cfg.copy()
89+
cfg_unres['x_ref'] = X_REF_FILENAME
90+
_ = validate_config(cfg_unres)

alibi_detect/saving/validate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def validate_config(cfg: dict, resolved: bool = False) -> dict:
3838

3939
# Get meta data
4040
meta = cfg.get('meta')
41+
meta = {} if meta is None else meta # Needed because pydantic sets meta=None if it is missing from the config
4142
version_warning = meta.get('version_warning', False)
4243
version = meta.get('version', None)
4344
config_spec = meta.get('config_spec', None)

0 commit comments

Comments
 (0)