Skip to content

Commit 6d62e66

Browse files
committed
wave: sampwidth for IEEE Float must be 4 or 8
This is also similar to what libsndfile does
1 parent 689ae2d commit 6d62e66

File tree

3 files changed

+71
-8
lines changed

3 files changed

+71
-8
lines changed

Doc/library/wave.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ Wave_write Objects
218218

219219
Set the sample width to *n* bytes.
220220

221+
For :data:`WAVE_FORMAT_IEEE_FLOAT`, only 4-byte (32-bit) and
222+
8-byte (64-bit) sample widths are supported.
223+
221224

222225
.. method:: getsampwidth()
223226

@@ -273,6 +276,9 @@ Wave_write Objects
273276
Supported values are :data:`WAVE_FORMAT_PCM` and
274277
:data:`WAVE_FORMAT_IEEE_FLOAT`.
275278

279+
When setting :data:`WAVE_FORMAT_IEEE_FLOAT`, the sample width must be
280+
4 or 8 bytes.
281+
276282

277283
.. method:: getformat()
278284

@@ -288,6 +294,8 @@ Wave_write Objects
288294
For backwards compatibility, a 6-item tuple without *format* is also
289295
accepted and defaults to :data:`WAVE_FORMAT_PCM`.
290296

297+
For ``format=WAVE_FORMAT_IEEE_FLOAT``, *sampwidth* must be 4 or 8.
298+
291299

292300
.. method:: getparams()
293301

Lib/test/test_wave.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,20 +189,31 @@ def test_setparams_7_tuple_uses_format(self):
189189
self.addCleanup(unlink, filename)
190190

191191
with wave.open(filename, 'wb') as w:
192-
w.setparams((1, 2, 22050, 0, 'NONE', 'not compressed',
192+
w.setparams((1, 4, 22050, 0, 'NONE', 'not compressed',
193193
wave.WAVE_FORMAT_IEEE_FLOAT))
194194
self.assertEqual(w.getformat(), wave.WAVE_FORMAT_IEEE_FLOAT)
195195

196+
def test_setparams_7_tuple_ieee_64bit_sampwidth(self):
197+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
198+
filename = fp.name
199+
self.addCleanup(unlink, filename)
200+
201+
with wave.open(filename, 'wb') as w:
202+
w.setparams((1, 8, 22050, 0, 'NONE', 'not compressed',
203+
wave.WAVE_FORMAT_IEEE_FLOAT))
204+
self.assertEqual(w.getformat(), wave.WAVE_FORMAT_IEEE_FLOAT)
205+
self.assertEqual(w.getsampwidth(), 8)
206+
196207
def test_getparams_backward_compatible_shape(self):
197208
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
198209
filename = fp.name
199210
self.addCleanup(unlink, filename)
200211

201212
with wave.open(filename, 'wb') as w:
202-
w.setparams((1, 2, 22050, 0, 'NONE', 'not compressed',
213+
w.setparams((1, 4, 22050, 0, 'NONE', 'not compressed',
203214
wave.WAVE_FORMAT_IEEE_FLOAT))
204215
params = w.getparams()
205-
self.assertEqual(params, (1, 2, 22050, 0, 'NONE', 'not compressed'))
216+
self.assertEqual(params, (1, 4, 22050, 0, 'NONE', 'not compressed'))
206217

207218
def test_getformat_setformat(self):
208219
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
@@ -211,12 +222,51 @@ def test_getformat_setformat(self):
211222

212223
with wave.open(filename, 'wb') as w:
213224
w.setnchannels(1)
214-
w.setsampwidth(2)
225+
w.setsampwidth(4)
215226
w.setframerate(22050)
216227
self.assertEqual(w.getformat(), wave.WAVE_FORMAT_PCM)
217228
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
218229
self.assertEqual(w.getformat(), wave.WAVE_FORMAT_IEEE_FLOAT)
219230

231+
def test_setformat_ieee_requires_32_or_64_bit_sampwidth(self):
232+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
233+
filename = fp.name
234+
self.addCleanup(unlink, filename)
235+
236+
with wave.open(filename, 'wb') as w:
237+
w.setnchannels(1)
238+
w.setsampwidth(2)
239+
w.setframerate(22050)
240+
with self.assertRaisesRegex(wave.Error,
241+
'unsupported sample width for IEEE float format'):
242+
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
243+
244+
def test_setsampwidth_ieee_requires_32_or_64_bit(self):
245+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
246+
filename = fp.name
247+
self.addCleanup(unlink, filename)
248+
249+
with wave.open(filename, 'wb') as w:
250+
w.setnchannels(1)
251+
w.setframerate(22050)
252+
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
253+
with self.assertRaisesRegex(wave.Error,
254+
'unsupported sample width for IEEE float format'):
255+
w.setsampwidth(2)
256+
w.setsampwidth(4)
257+
258+
def test_setsampwidth_ieee_accepts_64_bit(self):
259+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
260+
filename = fp.name
261+
self.addCleanup(unlink, filename)
262+
263+
with wave.open(filename, 'wb') as w:
264+
w.setnchannels(1)
265+
w.setframerate(22050)
266+
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
267+
w.setsampwidth(8)
268+
self.assertEqual(w.getsampwidth(), 8)
269+
220270
def test_read_getformat(self):
221271
b = b'RIFF' + struct.pack('<L', 36) + b'WAVE'
222272
b += b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 1, 11025, 11025, 1, 8)
@@ -297,10 +347,10 @@ def test_ieee_float_has_fact_chunk(self):
297347

298348
with wave.open(filename, 'wb') as w:
299349
w.setnchannels(1)
300-
w.setsampwidth(2)
350+
w.setsampwidth(4)
301351
w.setframerate(22050)
302352
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
303-
w.writeframes(b'\x00\x00' * nframes)
353+
w.writeframes(b'\x00\x00\x00\x00' * nframes)
304354

305355
with open(filename, 'rb') as f:
306356
f.read(12)

Lib/wave.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,10 @@ def getnchannels(self):
506506
def setsampwidth(self, sampwidth):
507507
if self._datawritten:
508508
raise Error('cannot change parameters after starting to write')
509-
if sampwidth < 1 or sampwidth > 4:
509+
if self._format == WAVE_FORMAT_IEEE_FLOAT:
510+
if sampwidth not in (4, 8):
511+
raise Error('unsupported sample width for IEEE float format')
512+
elif sampwidth < 1 or sampwidth > 4:
510513
raise Error('bad sample width')
511514
self._sampwidth = sampwidth
512515

@@ -548,6 +551,8 @@ def setformat(self, format):
548551
raise Error('cannot change parameters after starting to write')
549552
if format not in (WAVE_FORMAT_IEEE_FLOAT, WAVE_FORMAT_PCM):
550553
raise Error('unsupported wave format')
554+
if format == WAVE_FORMAT_IEEE_FLOAT and self._sampwidth and self._sampwidth not in (4, 8):
555+
raise Error('unsupported sample width for IEEE float format')
551556
self._format = format
552557

553558
def getformat(self):
@@ -568,11 +573,11 @@ def setparams(self, params):
568573
else:
569574
nchannels, sampwidth, framerate, nframes, comptype, compname, format = params
570575
self.setnchannels(nchannels)
576+
self.setformat(format)
571577
self.setsampwidth(sampwidth)
572578
self.setframerate(framerate)
573579
self.setnframes(nframes)
574580
self.setcomptype(comptype, compname)
575-
self.setformat(format)
576581

577582
def getparams(self):
578583
if not self._nchannels or not self._sampwidth or not self._framerate:

0 commit comments

Comments
 (0)