From 8471fc2142ed476f902f58b50a527e2d66229718 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Mon, 6 Apr 2026 13:32:10 +0000 Subject: [PATCH 1/3] AVRO-4241: [Java] BinaryDecoder should verify available bytes before reading Add ensureAvailableBytes() pre-check in readString, readBytes, readArrayStart, arrayNext, readMapStart, and mapNext to verify the source has sufficient data before proceeding. Byte-array-backed sources return an exact remaining count. Stream-backed sources return buffered bytes plus InputStream.available(), which is reliable for the finite streams used by DataFileReader and DataFileStream. Includes regression tests and updated array/map limit tests. --- .../avro/generic/GenericDatumReader.java | 120 ++++++++++- .../org/apache/avro/io/BinaryDecoder.java | 64 ++++++ .../main/java/org/apache/avro/io/Decoder.java | 10 + .../org/apache/avro/io/ValidatingDecoder.java | 5 + .../avro/util/ByteBufferInputStream.java | 12 ++ .../avro/generic/TestGenericDatumReader.java | 186 ++++++++++++++++++ .../org/apache/avro/io/TestBinaryDecoder.java | 45 +++++ 7 files changed, 440 insertions(+), 2 deletions(-) diff --git a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java index ce0646a82b8..e09dc95d2dd 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java +++ b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java @@ -17,12 +17,16 @@ */ package org.apache.avro.generic; +import java.io.EOFException; import java.io.IOException; import java.lang.reflect.Constructor; import java.nio.ByteBuffer; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; @@ -291,6 +295,7 @@ protected Object 
readArray(Object old, Schema expected, ResolvingDecoder in) thr long l = in.readArrayStart(); long base = 0; if (l > 0) { + ensureAvailableCollectionBytes(in, l, expectedType); LogicalType logicalType = expectedType.getLogicalType(); Conversion conversion = getData().getConversionFor(logicalType); Object array = newArray(old, (int) l, expected); @@ -306,13 +311,25 @@ protected Object readArray(Object old, Schema expected, ResolvingDecoder in) thr } } base += l; - } while ((l = in.arrayNext()) > 0); + } while ((l = arrayNext(in, expectedType)) > 0); return pruneArray(array); } else { return pruneArray(newArray(old, 0, expected)); } } + /** + * Reads the next array block count and validates remaining bytes before the + * caller allocates storage. + */ + private long arrayNext(ResolvingDecoder in, Schema elementType) throws IOException { + long l = in.arrayNext(); + if (l > 0) { + ensureAvailableCollectionBytes(in, l, elementType); + } + return l; + } + private Object pruneArray(Object object) { if (object instanceof GenericArray) { ((GenericArray) object).prune(); @@ -348,6 +365,9 @@ protected Object readMap(Object old, Schema expected, ResolvingDecoder in) throw long l = in.readMapStart(); LogicalType logicalType = eValue.getLogicalType(); Conversion conversion = getData().getConversionFor(logicalType); + if (l > 0) { + ensureAvailableMapBytes(in, l, eValue); + } Object map = newMap(old, (int) l); if (l > 0) { do { @@ -361,11 +381,40 @@ protected Object readMap(Object old, Schema expected, ResolvingDecoder in) throw addToMap(map, readMapKey(null, expected, in), readWithoutConversion(null, eValue, in)); } } - } while ((l = in.mapNext()) > 0); + } while ((l = mapNext(in, eValue)) > 0); } return map; } + /** + * Reads the next map block count and validates remaining bytes before the + * caller allocates storage. 
+ */ + private long mapNext(ResolvingDecoder in, Schema valueType) throws IOException { + long l = in.mapNext(); + if (l > 0) { + ensureAvailableMapBytes(in, l, valueType); + } + return l; + } + + /** + * Validates remaining bytes for a map block. Each map entry has a string key + * (at least 1 byte for the length varint) plus a value, so the minimum bytes + * per entry is {@code 1 + minBytesPerElement(valueSchema)}. + */ + private static void ensureAvailableMapBytes(Decoder decoder, long count, Schema valueSchema) throws EOFException { + // Map keys are always strings: at least 1 byte for the length varint + int minBytesPerEntry = 1 + minBytesPerElement(valueSchema); + if (count > 0) { + int remaining = decoder.remainingBytes(); + if (remaining >= 0 && count * (long) minBytesPerEntry > remaining) { + throw new EOFException("Map claims " + count + " entries with at least " + minBytesPerEntry + + " bytes each, but only " + remaining + " bytes are available"); + } + } + } + /** * Called by the default implementation of {@link #readMap} to read a key value. * The default implementation returns delegates to @@ -384,6 +433,73 @@ protected void addToMap(Object map, Object key, Object value) { ((Map) map).put(key, value); } + /** + * Returns the minimum number of bytes required to encode a single value of the + * given schema in Avro binary format. Used to validate that the decoder has + * enough data remaining before allocating collection backing arrays. + *

+ * Returns 0 for types whose binary encoding is empty ({@code null}, zero-length + * {@code fixed}, records with only zero-byte fields). Returns a positive value + * for all other types. + */ + static int minBytesPerElement(Schema schema) { + return minBytesPerElement(schema, Collections.newSetFromMap(new IdentityHashMap<>())); + } + + private static int minBytesPerElement(Schema schema, Set visited) { + switch (schema.getType()) { + case NULL: + return 0; + case FIXED: + return schema.getFixedSize(); + case FLOAT: + return 4; + case DOUBLE: + return 8; + case RECORD: + if (!visited.add(schema)) { + return 0; // break recursion for self-referencing schemas + } + long sum = 0; + for (Schema.Field f : schema.getFields()) { + sum += minBytesPerElement(f.schema(), visited); + if (sum >= Integer.MAX_VALUE) { + sum = Integer.MAX_VALUE; + break; + } + } + visited.remove(schema); + return (int) sum; + case UNION: + // The branch index varint is always at least 1 byte + return 1; + default: + // BOOLEAN, INT, LONG, ENUM, STRING, BYTES, ARRAY, MAP are all >= 1 byte + return 1; + } + } + + /** + * Validates that the decoder has enough remaining bytes to hold {@code count} + * elements of the given schema, assuming each element requires at least + * {@link #minBytesPerElement} bytes. Throws {@link EOFException} if the decoder + * reports fewer remaining bytes than required. + *

+ * This check prevents out-of-memory errors from pre-allocating huge backing + * arrays when the source data is truncated or malicious. + */ + private static void ensureAvailableCollectionBytes(Decoder decoder, long count, Schema elementSchema) + throws EOFException { + int minBytes = minBytesPerElement(elementSchema); + if (minBytes > 0 && count > 0) { + int remaining = decoder.remainingBytes(); + if (remaining >= 0 && count * (long) minBytes > remaining) { + throw new EOFException("Collection claims " + count + " elements with at least " + minBytes + + " bytes each, but only " + remaining + " bytes are available"); + } + } + } + /** * Called to read a fixed value. May be overridden for alternate fixed * representations. By default, returns {@link GenericFixed}. diff --git a/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java b/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java index 827b2fea3c7..22d86ca6504 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java +++ b/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java @@ -20,8 +20,10 @@ import org.apache.avro.AvroRuntimeException; import org.apache.avro.InvalidNumberEncodingException; import org.apache.avro.SystemLimitException; +import org.apache.avro.util.ByteBufferInputStream; import org.apache.avro.util.Utf8; +import java.io.ByteArrayInputStream; import java.io.EOFException; import java.io.IOException; import java.io.InputStream; @@ -295,6 +297,7 @@ public double readDouble() throws IOException { @Override public Utf8 readString(Utf8 old) throws IOException { int length = SystemLimitException.checkMaxStringLength(readLong()); + ensureAvailableBytes(length); Utf8 result = (old != null ? 
old : new Utf8()); result.setByteLength(length); if (0 != length) { @@ -318,6 +321,7 @@ public void skipString() throws IOException { @Override public ByteBuffer readBytes(ByteBuffer old) throws IOException { int length = SystemLimitException.checkMaxBytesLength(readLong()); + ensureAvailableBytes(length); final ByteBuffer result; if (old != null && length <= old.capacity()) { result = old; @@ -508,6 +512,21 @@ public boolean isEnd() throws IOException { return (0 == read); } + /** + * Returns the total number of bytes remaining that can be read from this + * decoder (including any buffered bytes), or {@code -1} if the total is + * unknown. + *

+ * Byte-array-backed decoders return an exact count. InputStream-backed decoders + * return an exact count only when the wrapped stream can report one. + *

+ * {@link DirectBinaryDecoder} always returns {@code -1}. + */ + @Override + public int remainingBytes() { + return source != null ? source.remainingBytes() : -1; + } + /** * Ensures that buf[pos + num - 1] is not out of the buffer array bounds. * However, buf[pos + num -1] may be >= limit if there is not enough data left @@ -530,6 +549,27 @@ private void ensureBounds(int num) throws IOException { } } + /** + * Validates that the source has at least {@code length} bytes remaining before + * proceeding. Throws early if the declared length is inconsistent with the + * available data. + *

+ * This check is only applied when the decoder knows the exact remaining byte + * count. + * + * @param length the number of bytes expected to be available + * @throws EOFException if the source is known to have fewer bytes remaining + */ + private void ensureAvailableBytes(int length) throws EOFException { + if (source != null && length > 0) { + int remaining = source.remainingBytes(); + if (remaining >= 0 && length > remaining) { + throw new EOFException( + "Attempted to read " + length + " bytes, but only " + remaining + " bytes are available"); + } + } + } + /** * Returns an {@link java.io.InputStream} that is aware of any buffering that * may occur in this BinaryDecoder. Readers that need to interleave decoding @@ -664,6 +704,12 @@ protected ByteSource() { abstract boolean isEof(); + /** + * Returns the total number of bytes remaining that can be read from this source + * (including any buffered bytes), or {@code -1} if the total is unknown. + */ + protected abstract int remainingBytes(); + protected void attach(int bufferSize, BinaryDecoder decoder) { decoder.buf = new byte[bufferSize]; decoder.pos = 0; @@ -910,6 +956,19 @@ public boolean isEof() { return isEof; } + @Override + protected int remainingBytes() { + int buffered = ba.getLim() - ba.getPos(); + try { + if (in.getClass() == ByteArrayInputStream.class || in.getClass() == ByteBufferInputStream.class) { + return buffered + in.available(); + } + } catch (IOException e) { + return -1; + } + return -1; + } + @Override public void close() throws IOException { in.close(); @@ -1028,5 +1087,10 @@ public boolean isEof() { int remaining = ba.getLim() - ba.getPos(); return (remaining == 0); } + + @Override + protected int remainingBytes() { + return ba.getLim() - ba.getPos(); + } } } diff --git a/lang/java/avro/src/main/java/org/apache/avro/io/Decoder.java b/lang/java/avro/src/main/java/org/apache/avro/io/Decoder.java index 11fc28d762e..80640a61aa0 100644 --- 
a/lang/java/avro/src/main/java/org/apache/avro/io/Decoder.java +++ b/lang/java/avro/src/main/java/org/apache/avro/io/Decoder.java @@ -299,4 +299,14 @@ public void readFixed(byte[] bytes) throws IOException { * type of the next value to be read */ public abstract int readIndex() throws IOException; + + /** + * Returns the total number of bytes remaining that can be read from this + * decoder, or {@code -1} if the total is unknown. Implementations that can + * determine remaining capacity (for example, byte-array-backed decoders) should + * override this method. The default returns {@code -1}. + */ + public int remainingBytes() { + return -1; + } } diff --git a/lang/java/avro/src/main/java/org/apache/avro/io/ValidatingDecoder.java b/lang/java/avro/src/main/java/org/apache/avro/io/ValidatingDecoder.java index dbee4458575..26f79a16ff2 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/io/ValidatingDecoder.java +++ b/lang/java/avro/src/main/java/org/apache/avro/io/ValidatingDecoder.java @@ -246,4 +246,9 @@ public int readIndex() throws IOException { public Symbol doAction(Symbol input, Symbol top) throws IOException { return null; } + + @Override + public int remainingBytes() { + return in != null ? 
in.remainingBytes() : -1; + } } diff --git a/lang/java/avro/src/main/java/org/apache/avro/util/ByteBufferInputStream.java b/lang/java/avro/src/main/java/org/apache/avro/util/ByteBufferInputStream.java index 6abb62015dc..375abc23fbf 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/util/ByteBufferInputStream.java +++ b/lang/java/avro/src/main/java/org/apache/avro/util/ByteBufferInputStream.java @@ -65,6 +65,18 @@ public int read(byte[] b, int off, int len) throws IOException { } } + @Override + public int available() throws IOException { + long remaining = 0; + for (int i = current; i < buffers.size(); i++) { + remaining += buffers.get(i).remaining(); + if (remaining >= Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } + } + return (int) remaining; + } + /** * Read a buffer from the input without copying, if possible. */ diff --git a/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java b/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java index f74dab95b0f..5586b828999 100644 --- a/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java +++ b/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java @@ -18,15 +18,24 @@ package org.apache.avro.generic; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import java.io.ByteArrayOutputStream; +import java.io.EOFException; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Random; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.avro.Schema; +import org.apache.avro.io.BinaryDecoder; +import org.apache.avro.io.BinaryEncoder; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.EncoderFactory; import org.junit.jupiter.api.Test; public class TestGenericDatumReader { @@ -117,4 
+126,181 @@ private void sleep() { } } } + + // --- minBytesPerElement tests --- + + @Test + void testMinBytesPerElementPrimitives() { + assertEquals(0, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.NULL))); + assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.BOOLEAN))); + assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.INT))); + assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.LONG))); + assertEquals(4, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.FLOAT))); + assertEquals(8, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.DOUBLE))); + assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.STRING))); + assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.BYTES))); + } + + @Test + void testMinBytesPerElementFixed() { + assertEquals(0, GenericDatumReader.minBytesPerElement(Schema.createFixed("ZeroFixed", null, "test", 0))); + assertEquals(5, GenericDatumReader.minBytesPerElement(Schema.createFixed("FiveFixed", null, "test", 5))); + assertEquals(16, GenericDatumReader.minBytesPerElement(Schema.createFixed("SixteenFixed", null, "test", 16))); + } + + @Test + void testMinBytesPerElementUnion() { + // Union always >= 1 byte (branch index varint) + Schema nullableInt = Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT)); + assertEquals(1, GenericDatumReader.minBytesPerElement(nullableInt)); + } + + @Test + void testMinBytesPerElementRecord() { + // Empty record = 0 bytes + Schema emptyRecord = Schema.createRecord("Empty", null, "test", false); + emptyRecord.setFields(Collections.emptyList()); + assertEquals(0, GenericDatumReader.minBytesPerElement(emptyRecord)); + + // Record with a single non-null field >= 1 byte + Schema recWithInt = Schema.createRecord("WithInt", null, "test", false); + recWithInt.setFields(Collections.singletonList(new 
Schema.Field("x", Schema.create(Schema.Type.INT)))); + assertEquals(1, GenericDatumReader.minBytesPerElement(recWithInt)); + + // Record with only null fields = 0 bytes + Schema recWithNull = Schema.createRecord("WithNull", null, "test", false); + recWithNull.setFields(Collections.singletonList(new Schema.Field("n", Schema.create(Schema.Type.NULL)))); + assertEquals(0, GenericDatumReader.minBytesPerElement(recWithNull)); + + Schema recWithMultipleFields = Schema.createRecord("WithMultipleFields", null, "test", false); + recWithMultipleFields.setFields(Arrays.asList(new Schema.Field("f", Schema.create(Schema.Type.FLOAT)), + new Schema.Field("d", Schema.create(Schema.Type.DOUBLE)))); + assertEquals(12, GenericDatumReader.minBytesPerElement(recWithMultipleFields)); + } + + @Test + void testMinBytesPerElementNestedCollections() { + // Array and map types are >= 1 byte (count varint) + assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.createArray(Schema.create(Schema.Type.INT)))); + assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.createMap(Schema.create(Schema.Type.INT)))); + } + + // --- Collection byte validation end-to-end tests --- + + /** + * Encodes the given longs as Avro varints into a byte array. + */ + private static byte[] encodeVarints(long... values) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + BinaryEncoder enc = EncoderFactory.get().directBinaryEncoder(baos, null); + for (long v : values) { + enc.writeLong(v); + } + enc.flush(); + return baos.toByteArray(); + } + + /** + * Verify that reading an array of ints with a huge count but no element data + * throws EOFException from the schema-aware byte check. + */ + @Test + void arrayOfIntsRejectsHugeCount() throws Exception { + Schema schema = Schema.createArray(Schema.create(Schema.Type.INT)); + GenericDatumReader reader = new GenericDatumReader<>(schema); + + // Binary: varint(10_000_000) for block count, varint(0) for terminator. 
+ // No actual element data -- the reader should reject before allocating. + byte[] data = encodeVarints(10_000_000L, 0L); + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(data, null); + assertThrows(EOFException.class, () -> reader.read(null, decoder)); + } + + /** + * Verify that reading an array of nulls with a large count SUCCEEDS because + * null elements are 0 bytes each, so the byte check is correctly skipped. + */ + @Test + void arrayOfNullsAcceptsLargeCount() throws Exception { + Schema schema = Schema.createArray(Schema.create(Schema.Type.NULL)); + GenericDatumReader reader = new GenericDatumReader<>(schema); + + // Binary: varint(1000) for block count, varint(0) for terminator. + // 1000 null elements = 0 bytes of element data. + byte[] data = encodeVarints(1000L, 0L); + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(data, null); + GenericData.Array result = (GenericData.Array) reader.read(null, decoder); + assertEquals(1000, result.size()); + } + + /** + * Verify that reading a map of string->int with a huge count throws + * EOFException. Each map entry needs at least 2 bytes (1 for key length varint + * + 1 for int value). + */ + @Test + void mapOfStringToIntRejectsHugeCount() throws Exception { + Schema schema = Schema.createMap(Schema.create(Schema.Type.INT)); + GenericDatumReader reader = new GenericDatumReader<>(schema); + + byte[] data = encodeVarints(10_000_000L, 0L); + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(data, null); + assertThrows(EOFException.class, () -> reader.read(null, decoder)); + } + + /** + * Verify that reading a map of string->null with a huge count also throws + * EOFException because map keys are always strings (at least 1 byte each). 
+ */ + @Test + void mapOfStringToNullRejectsHugeCount() throws Exception { + Schema schema = Schema.createMap(Schema.create(Schema.Type.NULL)); + GenericDatumReader reader = new GenericDatumReader<>(schema); + + byte[] data = encodeVarints(10_000_000L, 0L); + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(data, null); + assertThrows(EOFException.class, () -> reader.read(null, decoder)); + } + + /** + * Verify that reading an array of zero-length fixed elements with a large count + * SUCCEEDS because zero-length fixed elements are 0 bytes each. + */ + @Test + void arrayOfZeroLengthFixedAcceptsLargeCount() throws Exception { + Schema fixedSchema = Schema.createFixed("Empty", null, "test", 0); + Schema schema = Schema.createArray(fixedSchema); + GenericDatumReader reader = new GenericDatumReader<>(schema); + + byte[] data = encodeVarints(500L, 0L); + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(data, null); + GenericData.Array result = (GenericData.Array) reader.read(null, decoder); + assertEquals(500, result.size()); + } + + @Test + void arrayOfRecordsRejectsHugeCountUsingFullRecordSize() throws Exception { + Schema recordSchema = Schema.createRecord("Element", null, "test", false); + recordSchema.setFields(Arrays.asList(new Schema.Field("f", Schema.create(Schema.Type.FLOAT)), + new Schema.Field("d", Schema.create(Schema.Type.DOUBLE)))); + Schema schema = Schema.createArray(recordSchema); + GenericDatumReader reader = new GenericDatumReader<>(schema); + + byte[] data = encodeVarints(2L, 0L); + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(data, null); + assertThrows(EOFException.class, () -> reader.read(null, decoder)); + } + + @Test + void mapOfRecordsRejectsHugeCountUsingFullRecordSize() throws Exception { + Schema recordSchema = Schema.createRecord("MapValue", null, "test", false); + recordSchema.setFields(Arrays.asList(new Schema.Field("f", Schema.create(Schema.Type.FLOAT)), + new Schema.Field("d", 
Schema.create(Schema.Type.DOUBLE)))); + Schema schema = Schema.createMap(recordSchema); + GenericDatumReader reader = new GenericDatumReader<>(schema); + + byte[] data = encodeVarints(1L, 0L); + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(data, null); + assertThrows(EOFException.class, () -> reader.read(null, decoder)); + } } diff --git a/lang/java/avro/src/test/java/org/apache/avro/io/TestBinaryDecoder.java b/lang/java/avro/src/test/java/org/apache/avro/io/TestBinaryDecoder.java index 491151b849e..92d2b921a55 100644 --- a/lang/java/avro/src/test/java/org/apache/avro/io/TestBinaryDecoder.java +++ b/lang/java/avro/src/test/java/org/apache/avro/io/TestBinaryDecoder.java @@ -417,6 +417,51 @@ public void testStringMaxCustom(boolean useDirect) throws IOException { } } + /** + * Verify that a byte-array-backed decoder rejects a string whose varint length + * exceeds the remaining bytes, throwing {@link EOFException} before + * allocating the buffer. + */ + @Test + public void testStringLengthExceedsAvailableBytes() throws IOException { + // Encode a varint claiming 10_000_000 bytes of string data, but supply none. + // The byte-array-backed decoder knows it has only a few bytes left after + // the varint, so ensureAvailableBytes must throw EOFException. + BinaryDecoder bd = newDecoder(false, 10_000_000L); + Assertions.assertThrows(EOFException.class, () -> bd.readString(null)); + } + + /** + * Same as {@link #testStringLengthExceedsAvailableBytes()} but for + * {@link BinaryDecoder#readBytes(ByteBuffer)}. 
+ */ + @Test + public void testBytesLengthExceedsAvailableBytes() throws IOException { + BinaryDecoder bd = newDecoder(false, 10_000_000L); + Assertions.assertThrows(EOFException.class, () -> bd.readBytes(null)); + } + + @Test + public void testStringLengthDoesNotTrustUnknownAvailable() throws IOException { + byte[] encoded; + try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + BinaryEncoder encoder = EncoderFactory.get().binaryEncoder(baos, null); + encoder.writeString("hello"); + encoder.flush(); + encoded = baos.toByteArray(); + } + + InputStream in = new ByteArrayInputStream(encoded) { + @Override + public synchronized int available() { + return 0; + } + }; + + BinaryDecoder decoder = factory.binaryDecoder(in, null); + Assertions.assertEquals("hello", decoder.readString(null).toString()); + } + @ParameterizedTest @ValueSource(booleans = { true, false }) public void testBytesNegativeLength(boolean useDirect) throws IOException { From 3b6e1aeeb6e84896d6e85423e73cbf919132f1cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Thu, 30 Apr 2026 20:43:17 +0200 Subject: [PATCH 2/3] AVRO-4241: [Java] Simplify no-op validation code and fix possible int overflow per PR comments --- .../avro/generic/GenericDatumReader.java | 32 +++++++++---------- .../org/apache/avro/io/BinaryDecoder.java | 3 +- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java index e09dc95d2dd..178ca50ccad 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java +++ b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java @@ -324,9 +324,7 @@ protected Object readArray(Object old, Schema expected, ResolvingDecoder in) thr */ private long arrayNext(ResolvingDecoder in, Schema elementType) throws IOException { long l = in.arrayNext(); - if (l > 0) 
{ - ensureAvailableCollectionBytes(in, l, elementType); - } + ensureAvailableCollectionBytes(in, l, elementType); return l; } @@ -365,9 +363,7 @@ protected Object readMap(Object old, Schema expected, ResolvingDecoder in) throw long l = in.readMapStart(); LogicalType logicalType = eValue.getLogicalType(); Conversion conversion = getData().getConversionFor(logicalType); - if (l > 0) { - ensureAvailableMapBytes(in, l, eValue); - } + ensureAvailableMapBytes(in, l, eValue); Object map = newMap(old, (int) l); if (l > 0) { do { @@ -392,9 +388,7 @@ protected Object readMap(Object old, Schema expected, ResolvingDecoder in) throw */ private long mapNext(ResolvingDecoder in, Schema valueType) throws IOException { long l = in.mapNext(); - if (l > 0) { - ensureAvailableMapBytes(in, l, valueType); - } + ensureAvailableMapBytes(in, l, valueType); return l; } @@ -404,14 +398,15 @@ private long mapNext(ResolvingDecoder in, Schema valueType) throws IOException { * per entry is {@code 1 + minBytesPerElement(valueSchema)}. 
*/ private static void ensureAvailableMapBytes(Decoder decoder, long count, Schema valueSchema) throws EOFException { + if (count <= 0) { + return; + } // Map keys are always strings: at least 1 byte for the length varint - int minBytesPerEntry = 1 + minBytesPerElement(valueSchema); - if (count > 0) { - int remaining = decoder.remainingBytes(); - if (remaining >= 0 && count * (long) minBytesPerEntry > remaining) { - throw new EOFException("Map claims " + count + " entries with at least " + minBytesPerEntry - + " bytes each, but only " + remaining + " bytes are available"); - } + long minBytesPerEntry = 1L + minBytesPerElement(valueSchema); + int remaining = decoder.remainingBytes(); + if (remaining >= 0 && count * minBytesPerEntry > remaining) { + throw new EOFException("Map claims " + count + " entries with at least " + minBytesPerEntry + + " bytes each, but only " + remaining + " bytes are available"); } } @@ -490,8 +485,11 @@ private static int minBytesPerElement(Schema schema, Set visited) { */ private static void ensureAvailableCollectionBytes(Decoder decoder, long count, Schema elementSchema) throws EOFException { + if (count <= 0) { + return; + } int minBytes = minBytesPerElement(elementSchema); - if (minBytes > 0 && count > 0) { + if (minBytes > 0) { int remaining = decoder.remainingBytes(); if (remaining >= 0 && count * (long) minBytes > remaining) { throw new EOFException("Collection claims " + count + " elements with at least " + minBytes diff --git a/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java b/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java index 22d86ca6504..77fc8490764 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java +++ b/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java @@ -961,7 +961,8 @@ protected int remainingBytes() { int buffered = ba.getLim() - ba.getPos(); try { if (in.getClass() == ByteArrayInputStream.class || in.getClass() == 
ByteBufferInputStream.class) { - return buffered + in.available(); + long total = (long) buffered + in.available(); + return (int) Math.min(total, Integer.MAX_VALUE); } } catch (IOException e) { return -1; From e0a3bd78301528a9018d349218faf8ec21091436 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Thu, 30 Apr 2026 20:59:07 +0200 Subject: [PATCH 3/3] AVRO-4241: Test the overhead of {@code minBytesPerElement} computation on GenericDatumReader Benchmark to measure the overhead of {@code minBytesPerElement} computation during array/map decoding via {@link GenericDatumReader}. Tests complex, wide, and recursive schema structures to verify that the per-block schema traversal cost is acceptable. --- .../test/generic/MinBytesPerElementTest.java | 287 ++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 lang/java/perf/src/main/java/org/apache/avro/perf/test/generic/MinBytesPerElementTest.java diff --git a/lang/java/perf/src/main/java/org/apache/avro/perf/test/generic/MinBytesPerElementTest.java b/lang/java/perf/src/main/java/org/apache/avro/perf/test/generic/MinBytesPerElementTest.java new file mode 100644 index 00000000000..62c540e4415 --- /dev/null +++ b/lang/java/perf/src/main/java/org/apache/avro/perf/test/generic/MinBytesPerElementTest.java @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.avro.perf.test.generic; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Random; + +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.io.Decoder; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.Encoder; +import org.apache.avro.io.EncoderFactory; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.util.concurrent.TimeUnit; + +/** + * Benchmarks to measure the overhead of {@code minBytesPerElement} computation + * during array/map decoding via {@link GenericDatumReader}. 
Tests complex, + * wide, and recursive schema structures to verify that the per-block schema + * traversal cost is acceptable. + */ +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.SECONDS) +@Warmup(iterations = 3, time = 2) +@Measurement(iterations = 5, time = 3) +@Fork(1) +@State(Scope.Thread) +public class MinBytesPerElementTest { + + /** + * Schema complexity levels to benchmark: - WIDE: Record with 50 fields of + * various types - DEEP: Deeply nested records (10 levels) - RECURSIVE: + * Self-referencing record (linked-list style) + */ + @Param({ "WIDE", "DEEP", "RECURSIVE" }) + public String schemaType; + + private Schema arraySchema; + private Schema mapSchema; + private byte[] encodedArrayData; + private byte[] encodedMapData; + private DatumReader arrayReader; + private DatumReader mapReader; + private Schema arrayWrapperSchema; + private Schema mapWrapperSchema; + + @Setup(Level.Trial) + public void setup() throws IOException { + Schema elementSchema = buildElementSchema(schemaType); + + // Wrap in array and map schemas for testing collection decoding paths + arrayWrapperSchema = Schema.createRecord("ArrayWrapper", null, "test", false); + arrayWrapperSchema.setFields(List.of(new Schema.Field("items", Schema.createArray(elementSchema), null, null))); + + mapWrapperSchema = Schema.createRecord("MapWrapper", null, "test", false); + mapWrapperSchema.setFields(List.of(new Schema.Field("entries", Schema.createMap(elementSchema), null, null))); + + arrayReader = new GenericDatumReader<>(arrayWrapperSchema); + mapReader = new GenericDatumReader<>(mapWrapperSchema); + + // Encode test data: array with 1000 elements + encodedArrayData = encodeArrayData(arrayWrapperSchema, elementSchema, 1000); + // Encode test data: map with 1000 entries + encodedMapData = encodeMapData(mapWrapperSchema, elementSchema, 1000); + } + + @Benchmark + public void decodeArrayOfComplexRecords(Blackhole bh) throws IOException { + Decoder decoder = 
DecoderFactory.get().binaryDecoder(encodedArrayData, null); + GenericRecord result = arrayReader.read(null, decoder); + bh.consume(result); + } + + @Benchmark + public void decodeMapOfComplexRecords(Blackhole bh) throws IOException { + Decoder decoder = DecoderFactory.get().binaryDecoder(encodedMapData, null); + GenericRecord result = mapReader.read(null, decoder); + bh.consume(result); + } + + private static Schema buildElementSchema(String type) { + switch (type) { + case "WIDE": + return buildWideSchema(); + case "DEEP": + return buildDeepSchema(); + case "RECURSIVE": + return buildRecursiveSchema(); + default: + throw new IllegalArgumentException("Unknown schema type: " + type); + } + } + + /** + * Wide record: 50 fields of mixed types (int, long, double, float, string, + * boolean, bytes, nested record). + */ + private static Schema buildWideSchema() { + Schema innerRecord = Schema.createRecord("Inner", null, "test", false); + innerRecord.setFields(List.of(new Schema.Field("x", Schema.create(Schema.Type.INT), null, null), + new Schema.Field("y", Schema.create(Schema.Type.DOUBLE), null, null))); + + List fields = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + fields.add(new Schema.Field("int_" + i, Schema.create(Schema.Type.INT), null, null)); + } + for (int i = 0; i < 10; i++) { + fields.add(new Schema.Field("long_" + i, Schema.create(Schema.Type.LONG), null, null)); + } + for (int i = 0; i < 10; i++) { + fields.add(new Schema.Field("double_" + i, Schema.create(Schema.Type.DOUBLE), null, null)); + } + for (int i = 0; i < 10; i++) { + fields.add(new Schema.Field("str_" + i, Schema.create(Schema.Type.STRING), null, null)); + } + for (int i = 0; i < 5; i++) { + fields.add(new Schema.Field("bool_" + i, Schema.create(Schema.Type.BOOLEAN), null, null)); + } + for (int i = 0; i < 5; i++) { + fields.add(new Schema.Field("rec_" + i, innerRecord, null, null)); + } + + Schema wide = Schema.createRecord("WideRecord", null, "test", false); + wide.setFields(fields); 
+ return wide; + } + + /** + * Deeply nested: 10 levels of records, each containing an int and the next + * level. + */ + private static Schema buildDeepSchema() { + Schema current = Schema.createRecord("Level10", null, "test", false); + current.setFields(List.of(new Schema.Field("value", Schema.create(Schema.Type.INT), null, null))); + + for (int i = 9; i >= 1; i--) { + Schema next = Schema.createRecord("Level" + i, null, "test", false); + next.setFields(List.of(new Schema.Field("value", Schema.create(Schema.Type.INT), null, null), + new Schema.Field("child", current, null, null))); + current = next; + } + return current; + } + + /** + * Recursive: self-referencing record (linked list). The "next" field is a union + * of null and the record itself. + */ + private static Schema buildRecursiveSchema() { + Schema recursive = Schema.createRecord("LinkedNode", null, "test", false); + Schema nullSchema = Schema.create(Schema.Type.NULL); + Schema union = Schema.createUnion(List.of(nullSchema, recursive)); + recursive.setFields(List.of(new Schema.Field("value", Schema.create(Schema.Type.INT), null, null), + new Schema.Field("next", union, null, null))); + return recursive; + } + + private static byte[] encodeArrayData(Schema wrapperSchema, Schema elementSchema, int count) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DatumWriter writer = new GenericDatumWriter<>(wrapperSchema); + Encoder encoder = EncoderFactory.get().binaryEncoder(baos, null); + + GenericRecord wrapper = new GenericData.Record(wrapperSchema); + List items = new ArrayList<>(count); + Random r = new Random(42); + for (int i = 0; i < count; i++) { + items.add(buildRecord(elementSchema, r, 0)); + } + wrapper.put("items", items); + writer.write(wrapper, encoder); + encoder.flush(); + return baos.toByteArray(); + } + + private static byte[] encodeMapData(Schema wrapperSchema, Schema elementSchema, int count) throws IOException { + ByteArrayOutputStream baos = new 
ByteArrayOutputStream(); + DatumWriter writer = new GenericDatumWriter<>(wrapperSchema); + Encoder encoder = EncoderFactory.get().binaryEncoder(baos, null); + + GenericRecord wrapper = new GenericData.Record(wrapperSchema); + HashMap map = new HashMap<>(); + Random r = new Random(42); + for (int i = 0; i < count; i++) { + map.put("key_" + i, buildRecord(elementSchema, r, 0)); + } + wrapper.put("entries", map); + writer.write(wrapper, encoder); + encoder.flush(); + return baos.toByteArray(); + } + + private static GenericRecord buildRecord(Schema schema, Random r, int depth) { + GenericRecord rec = new GenericData.Record(schema); + for (Schema.Field field : schema.getFields()) { + rec.put(field.pos(), buildValue(field.schema(), r, depth)); + } + return rec; + } + + private static Object buildValue(Schema schema, Random r, int depth) { + switch (schema.getType()) { + case INT: + return r.nextInt(); + case LONG: + return r.nextLong(); + case FLOAT: + return r.nextFloat(); + case DOUBLE: + return r.nextDouble(); + case BOOLEAN: + return r.nextBoolean(); + case STRING: + return "s" + r.nextInt(1000); + case BYTES: + byte[] b = new byte[4]; + r.nextBytes(b); + return java.nio.ByteBuffer.wrap(b); + case RECORD: + return buildRecord(schema, r, depth + 1); + case UNION: + // For recursive schemas, limit depth + List types = schema.getTypes(); + if (depth > 3) { + // Pick the null branch to terminate recursion + for (int i = 0; i < types.size(); i++) { + if (types.get(i).getType() == Schema.Type.NULL) { + return null; + } + } + } + // Pick non-null branch + for (int i = 0; i < types.size(); i++) { + if (types.get(i).getType() != Schema.Type.NULL) { + return buildValue(types.get(i), r, depth); + } + } + return null; + case ARRAY: + return new ArrayList<>(); + case MAP: + return new HashMap<>(); + case NULL: + return null; + default: + return null; + } + } +}