diff --git a/tests/test_read_write.py b/tests/test_read_write.py index e658a08..2aea3ca 100644 --- a/tests/test_read_write.py +++ b/tests/test_read_write.py @@ -58,6 +58,40 @@ def test_round_trip_array_datatypes(): np.testing.assert_array_almost_equal(read_data, test_data, decimal=4) +def test_write_contiguous_array_succeeds(empty_temp_om_file): + data = np.arange(24, dtype=np.float32).reshape(4, 6) + assert data.flags["C_CONTIGUOUS"] + + writer = omfiles.OmFileWriter(empty_temp_om_file) + variable = writer.write_array(data, chunks=[2, 3], scale_factor=10000.0) + writer.close(variable) + + reader = omfiles.OmFileReader(empty_temp_om_file) + read_data = reader[:] + reader.close() + + assert read_data.shape == data.shape + assert read_data.dtype == data.dtype + np.testing.assert_array_almost_equal(read_data, data, decimal=4) + + +@pytest.mark.parametrize( + "data", + [ + np.arange(24, dtype=np.float32).reshape(4, 6).T, + np.arange(24, dtype=np.float32).reshape(4, 6)[:, ::2], + ], +) +def test_write_non_contiguous_array_raises(data, empty_temp_om_file): + assert not data.flags["C_CONTIGUOUS"] + + writer = omfiles.OmFileWriter(empty_temp_om_file) + with pytest.raises(RuntimeError) as exc_info: + writer.write_array(data, chunks=[1] * data.ndim, scale_factor=10000.0) + + assert "Array not contiguous" == exc_info.value.args[0] + + def test_write_hierarchical_file(empty_temp_om_file): # Create test data root_data = np.random.rand(10, 10).astype(np.float32)