diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 967cdc21..d9fdeace 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -374,7 +374,15 @@ func NewDecryptReaderAt(key []byte, src io.ReaderAt, size int64) (*DecryptReader finalChunkOff := finalChunkIndex * encChunkSize finalChunkSize := size - finalChunkOff finalChunk := make([]byte, finalChunkSize) - if _, err := src.ReadAt(finalChunk, finalChunkOff); err != nil { + nn, err := src.ReadAt(finalChunk, finalChunkOff) + if err == io.EOF { + if int64(nn) != finalChunkSize { + err = io.ErrUnexpectedEOF + } else { + err = nil + } + } + if err != nil { return nil, fmt.Errorf("failed to read final chunk: %w", err) } nonce := nonceForChunk(finalChunkIndex) diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index 7d0f7b37..e24a32e5 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -411,6 +411,58 @@ func TestDecryptReaderAt(t *testing.T) { checkRead("cross boundary 1-2", int64(2*cs-50), 100, 100, false, true) } +// eofReaderAt wraps an io.ReaderAt and forces it to return io.EOF if the read +// successfully reads up to the exact end of the file. +type eofReaderAt struct { + r io.ReaderAt + size int64 +} + +func (e *eofReaderAt) ReadAt(p []byte, off int64) (int, error) { + n, err := e.r.ReadAt(p, off) + if err == nil && off+int64(n) == e.size { + err = io.EOF + } + return n, err +} + +func TestDecryptReaderAtEOF(t *testing.T) { + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + plaintext := make([]byte, 100) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + ciphertext := buf.Bytes() + + // Simulate an io.ReaderAt that returns io.EOF exactly at the end + er := &eofReaderAt{ + r: bytes.NewReader(ciphertext), + size: int64(len(ciphertext)), + } + + // This should succeed now, clearing the io.EOF. + _, err = stream.NewDecryptReaderAt(key, er, int64(len(ciphertext))) + if err != nil { + t.Fatalf("NewDecryptReaderAt failed on EOF: %v", err) + } +} + func TestDecryptReaderAtEmpty(t *testing.T) { key := make([]byte, chacha20poly1305.KeySize) if _, err := rand.Read(key); err != nil {