-
Notifications
You must be signed in to change notification settings - Fork 88
Expand file tree
/
Copy pathassert_util.py
More file actions
67 lines (51 loc) · 2.11 KB
/
assert_util.py
File metadata and controls
67 lines (51 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) 2019 Graphcore Ltd. All rights reserved.
import numpy as np
import re
def assert_result_equals_tensor_value(output, tensor):
    """Assert the first line of ``output`` is the stringified ``[tensor]``.

    The first line of ``output`` must be a stringified list such as
    ``[array([3., 8.], dtype=float32)]``. The text between the outer
    brackets must look like a numpy array repr that carries a ``dtype=``
    field, and is then compared character for character with
    ``np.array_repr(tensor)``.

    Args:
        output: String whose first line holds the string representation
            of a numpy tensor wrapped in a list.
        tensor: numpy tensor representing the expected result.

    Returns:
        None

    Raises:
        AssertionError: The first line is not in stringified-list format.
        AssertionError: The list contents are not a string representation
            of a numpy array with a dtype.
        AssertionError: The output array repr does not equal the repr of
            the expected numpy array.
    """
    # TODO - np representation over multiple lines
    # TODO - large np array output
    # TODO - multiple dimension np output
    list_regex = r"^\[.*?\]$"
    np_array_str_regex = r"array\(.*?, dtype=.*?\)$"

    # Only the first line is inspected; anything after it is ignored.
    first_line = output.split("\n")[0]
    if not re.match(list_regex, first_line):
        raise AssertionError(
            "Result not in expected string format."
            " Expecting stringified list "
            " eg. [array([3., 8.], dtype=float32)]"
        )

    # Strip the enclosing "[" and "]" to get the array repr itself.
    contents = first_line[1:-1]
    if not re.match(np_array_str_regex, contents):
        raise AssertionError(
            "Expecting numpy representation "
            "array with dtype "
            "eg. array([3., 8.], dtype=float32)"
        )

    # `contents` is already a string, so it is reported as-is; only the
    # expected tensor needs converting to its repr. An explicit raise (not
    # a bare assert) keeps the check active under `python -O`.
    expected_repr = np.array_repr(tensor)
    if contents != expected_repr:
        raise AssertionError(
            "Output value {} does not "
            "equal expected value {}".format(contents, expected_repr)
        )
def assert_result_equals_string(output, expected):
    """Check that the output equals the expected string.

    Args:
        output: String representing the output of a test.
        expected: String of expected result.

    Returns:
        None

    Raises:
        AssertionError: Output string does not equal the expected string.
    """
    # Raise explicitly instead of using a bare `assert`, so the check is
    # not compiled away when Python runs with -O.
    if output != expected:
        raise AssertionError(
            "Output string {} does not "
            "equal expected string {}".format(output, expected)
        )