diff --git a/dataset/pom.xml b/dataset/pom.xml index 6bca75bdd0..2d582268a6 100644 --- a/dataset/pom.xml +++ b/dataset/pom.xml @@ -33,7 +33,7 @@ under the License. ../../../cpp/release-build/ 1.16.0 - 1.12.0 + 1.12.1 diff --git a/docs/source/jdbc.rst b/docs/source/jdbc.rst index c0477cb06d..a4c95dbf00 100644 --- a/docs/source/jdbc.rst +++ b/docs/source/jdbc.rst @@ -213,7 +213,8 @@ Type Mapping ------------ The Arrow to JDBC type mapping can be obtained at runtime via -a method on ColumnBinder. +a method on ColumnBinder. The Flight SQL JDBC driver follows the same +mapping, with additional support for the UUID extension type noted below. +----------------------------+----------------------------+-------+ | Arrow Type | JDBC Type | Notes | @@ -232,6 +233,8 @@ a method on ColumnBinder. +----------------------------+----------------------------+-------+ | FixedSizeBinary | BINARY (setBytes) | | +----------------------------+----------------------------+-------+ +| Uuid (extension) | OTHER (setObject) | \(3) | ++----------------------------+----------------------------+-------+ | Float32 | REAL (setFloat) | | +----------------------------+----------------------------+-------+ | Int8 | TINYINT (setByte) | | @@ -276,3 +279,6 @@ a method on ColumnBinder. `_, which will lead to the driver using the "default timezone" (that of the Java VM). +* \(3) For the Flight SQL JDBC driver, the Arrow UUID extension type + (``arrow.uuid``) maps to JDBC ``OTHER`` and is surfaced as + ``java.util.UUID`` values. diff --git a/flight/flight-sql-jdbc-core/pom.xml b/flight/flight-sql-jdbc-core/pom.xml index 965e071e72..8801ad8178 100644 --- a/flight/flight-sql-jdbc-core/pom.xml +++ b/flight/flight-sql-jdbc-core/pom.xml @@ -151,7 +151,7 @@ under the License. com.github.ben-manes.caffeine caffeine - 3.2.0 + 3.2.3 diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadata.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadata.java index 7185ddfe01..502270e1cd 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadata.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadata.java @@ -75,10 +75,12 @@ import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.extension.UuidType; import org.apache.arrow.vector.ipc.ReadChannel; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ExtensionTypeRegistry; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.Text; @@ -164,6 +166,9 @@ public class ArrowDatabaseMetadata extends AvaticaDatabaseMetaData { LONGNVARCHAR, SqlSupportsConvert.SQL_CONVERT_LONGVARCHAR_VALUE); sqlTypesToFlightEnumConvertTypes.put(DATE, SqlSupportsConvert.SQL_CONVERT_DATE_VALUE); sqlTypesToFlightEnumConvertTypes.put(TIMESTAMP, SqlSupportsConvert.SQL_CONVERT_TIMESTAMP_VALUE); + + // Register the UUID extension type so it is always available for the driver + ExtensionTypeRegistry.register(UuidType.INSTANCE); } ArrowDatabaseMetadata(final AvaticaConnection connection) { diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactory.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactory.java index bbfe88a78a..8362eb7627 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactory.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactory.java @@ -19,6 +19,7 @@ import java.util.function.IntSupplier; import org.apache.arrow.driver.jdbc.accessor.impl.ArrowFlightJdbcNullVectorAccessor; import org.apache.arrow.driver.jdbc.accessor.impl.binary.ArrowFlightJdbcBinaryVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.binary.ArrowFlightJdbcUuidVectorAccessor; import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorAccessor; import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDurationVectorAccessor; import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcIntervalVectorAccessor; @@ -65,6 +66,7 @@ import org.apache.arrow.vector.UInt2Vector; import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.UuidVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; @@ -138,6 +140,9 @@ public static ArrowFlightJdbcAccessor createAccessor( } else if (vector instanceof LargeVarBinaryVector) { return new ArrowFlightJdbcBinaryVectorAccessor( (LargeVarBinaryVector) vector, getCurrentRow, setCursorWasNull); + } else if (vector instanceof UuidVector) { + return new ArrowFlightJdbcUuidVectorAccessor( + (UuidVector) vector, getCurrentRow, setCursorWasNull); } else if (vector instanceof FixedSizeBinaryVector) { return new ArrowFlightJdbcBinaryVectorAccessor( (FixedSizeBinaryVector) vector, getCurrentRow, setCursorWasNull); diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcUuidVectorAccessor.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcUuidVectorAccessor.java new file mode 100644 index 0000000000..4bdbcbb63d --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcUuidVectorAccessor.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.accessor.impl.binary; + +import java.util.UUID; +import java.util.function.IntSupplier; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.UuidVector; +import org.apache.arrow.vector.util.UuidUtility; + +/** + * Accessor for the Arrow UUID extension type ({@link UuidVector}). + * + *

This accessor provides JDBC-compatible access to UUID values stored in Arrow's canonical UUID + * extension type ('arrow.uuid'). It follows PostgreSQL JDBC driver conventions: + * + *

+ */ +public class ArrowFlightJdbcUuidVectorAccessor extends ArrowFlightJdbcAccessor { + + private final UuidVector vector; + + /** + * Creates a new accessor for a UUID vector. + * + * @param vector the UUID vector to access + * @param currentRowSupplier supplier for the current row index + * @param setCursorWasNull consumer to set the wasNull flag + */ + public ArrowFlightJdbcUuidVectorAccessor( + UuidVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + public Object getObject() { + UUID uuid = vector.getObject(getCurrentRow()); + this.wasNull = uuid == null; + this.wasNullConsumer.setWasNull(this.wasNull); + return uuid; + } + + @Override + public Class getObjectClass() { + return UUID.class; + } + + @Override + public String getString() { + UUID uuid = (UUID) getObject(); + if (uuid == null) { + return null; + } + return uuid.toString(); + } + + @Override + public byte[] getBytes() { + UUID uuid = (UUID) getObject(); + if (uuid == null) { + return null; + } + return UuidUtility.getBytesFromUUID(uuid); + } +} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/UuidAvaticaParameterConverter.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/UuidAvaticaParameterConverter.java new file mode 100644 index 0000000000..b2157890cf --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/UuidAvaticaParameterConverter.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.converter.impl; + +import static org.apache.arrow.driver.jdbc.utils.SqlTypes.getSqlTypeIdFromArrowType; +import static org.apache.arrow.driver.jdbc.utils.SqlTypes.getSqlTypeNameFromArrowType; + +import java.nio.ByteBuffer; +import java.util.UUID; +import org.apache.arrow.driver.jdbc.converter.AvaticaParameterConverter; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.UuidVector; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.util.UuidUtility; +import org.apache.calcite.avatica.AvaticaParameter; +import org.apache.calcite.avatica.remote.TypedValue; +import org.apache.calcite.avatica.util.ByteString; + +/** + * AvaticaParameterConverter for UUID Arrow extension type. + * + *

Handles conversion of UUID values from JDBC parameters to Arrow's UUID extension type. Accepts + * both {@link UUID} objects and String representations of UUIDs. + */ +public class UuidAvaticaParameterConverter implements AvaticaParameterConverter { + + public UuidAvaticaParameterConverter() {} + + @Override + public boolean bindParameter(FieldVector vector, TypedValue typedValue, int index) { + if (!(vector instanceof UuidVector)) { + return false; + } + + UuidVector uuidVector = (UuidVector) vector; + Object value = typedValue.toJdbc(null); + + if (value == null) { + uuidVector.setNull(index); + return true; + } + + UUID uuid; + if (value instanceof UUID) { + uuid = (UUID) value; + } else if (value instanceof String) { + uuid = UUID.fromString((String) value); + } else if (value instanceof byte[]) { + byte[] bytes = (byte[]) value; + if (bytes.length != 16) { + throw new IllegalArgumentException("UUID byte array must be 16 bytes, got " + bytes.length); + } + uuid = uuidFromBytes(bytes); + } else if (value instanceof ByteString) { + byte[] bytes = ((ByteString) value).getBytes(); + if (bytes.length != 16) { + throw new IllegalArgumentException("UUID byte array must be 16 bytes, got " + bytes.length); + } + uuid = uuidFromBytes(bytes); + } else { + throw new IllegalArgumentException( + "Cannot convert " + value.getClass().getName() + " to UUID"); + } + + uuidVector.setSafe(index, UuidUtility.getBytesFromUUID(uuid)); + return true; + } + + @Override + public AvaticaParameter createParameter(Field field) { + final String name = field.getName(); + final int jdbcType = getSqlTypeIdFromArrowType(field.getType()); + final String typeName = getSqlTypeNameFromArrowType(field.getType()); + final String className = UUID.class.getCanonicalName(); + return new AvaticaParameter(false, 0, 0, jdbcType, typeName, className, name); + } + + private static UUID uuidFromBytes(byte[] bytes) { + final long mostSignificantBits; + final long leastSignificantBits; + ByteBuffer bb = ByteBuffer.wrap(bytes); + // Reads the first eight bytes + mostSignificantBits = bb.getLong(); + // Reads the first eight bytes at this buffer's current + leastSignificantBits = bb.getLong(); + + return new UUID(mostSignificantBits, leastSignificantBits); + } +} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java index 8c98ee4077..8f40d6698e 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java @@ -41,10 +41,14 @@ import org.apache.arrow.driver.jdbc.converter.impl.UnionAvaticaParameterConverter; import org.apache.arrow.driver.jdbc.converter.impl.Utf8AvaticaParameterConverter; import org.apache.arrow.driver.jdbc.converter.impl.Utf8ViewAvaticaParameterConverter; +import org.apache.arrow.driver.jdbc.converter.impl.UuidAvaticaParameterConverter; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.extension.UuidType; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeVisitor; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; import org.apache.calcite.avatica.remote.TypedValue; import org.checkerframework.checker.nullness.qual.Nullable; @@ -290,5 +294,15 @@ public Boolean visit(ArrowType.RunEndEncoded type) { throw new UnsupportedOperationException( "No Avatica parameter binder implemented for type " + type); } + + @Override + public Boolean visit(ExtensionType type) { + if (type instanceof UuidType) { + return new UuidAvaticaParameterConverter().bindParameter(vector, typedValue, index); + } + + // fallback to default implementation + return ArrowTypeVisitor.super.visit(type); + } } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java index 5dd4c69c73..dd51ee5361 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java @@ -43,8 +43,12 @@ import org.apache.arrow.driver.jdbc.converter.impl.UnionAvaticaParameterConverter; import org.apache.arrow.driver.jdbc.converter.impl.Utf8AvaticaParameterConverter; import org.apache.arrow.driver.jdbc.converter.impl.Utf8ViewAvaticaParameterConverter; +import org.apache.arrow.driver.jdbc.converter.impl.UuidAvaticaParameterConverter; import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; +import org.apache.arrow.vector.extension.UuidType; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeVisitor; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.calcite.avatica.AvaticaParameter; import org.apache.calcite.avatica.ColumnMetaData; @@ -294,5 +298,15 @@ public AvaticaParameter visit(ArrowType.RunEndEncoded type) { throw new UnsupportedOperationException( "No Avatica parameter binder implemented for type " + type); } + + @Override + public AvaticaParameter visit(ExtensionType type) { + if (type instanceof UuidType) { + return new UuidAvaticaParameterConverter().createParameter(field); + } + + // fallback to default implementation + return ArrowTypeVisitor.super.visit(type); + } } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/SqlTypes.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/SqlTypes.java index 5ba3957f8b..7982d5bc73 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/SqlTypes.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/SqlTypes.java @@ -20,11 +20,13 @@ import java.sql.Types; import java.util.HashMap; import java.util.Map; +import org.apache.arrow.vector.extension.UuidType; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; /** SQL Types utility functions. */ public class SqlTypes { + private static final Map typeIdToName = new HashMap<>(); static { @@ -110,6 +112,9 @@ public static int getSqlTypeIdFromArrowType(ArrowType arrowType) { case BinaryView: return Types.VARBINARY; case FixedSizeBinary: + if (arrowType instanceof UuidType) { + return Types.OTHER; + } return Types.BINARY; case LargeBinary: return Types.LONGVARBINARY; diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java index 569b5495fe..3a5a39be3d 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java @@ -22,8 +22,10 @@ import static org.hamcrest.CoreMatchers.allOf; import static org.hamcrest.CoreMatchers.anyOf; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.*; @@ -31,7 +33,9 @@ import java.nio.charset.StandardCharsets; import java.sql.Connection; import java.sql.DriverManager; +import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.SQLTimeoutException; import java.sql.Statement; @@ -42,6 +46,7 @@ import java.util.List; import java.util.Random; import java.util.Set; +import java.util.UUID; import java.util.concurrent.CountDownLatch; import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; import org.apache.arrow.driver.jdbc.utils.FallbackFlightSqlProducer; @@ -61,6 +66,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.UuidUtility; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -795,4 +801,174 @@ public void testResultSetAppMetadata() throws Exception { "foo".getBytes(StandardCharsets.UTF_8)); } } + + @Test + public void testSelectQueryWithUuidColumn() throws SQLException { + // Expectations + final int expectedRowCount = 4; + final UUID[] expectedUuids = + new UUID[] { + CoreMockedSqlProducers.UUID_1, + CoreMockedSqlProducers.UUID_2, + CoreMockedSqlProducers.UUID_3, + null + }; + + final Integer[] expectedIds = new Integer[] {1, 2, 3, 4}; + + final List actualUuids = new ArrayList<>(expectedRowCount); + final List actualIds = new ArrayList<>(expectedRowCount); + + // Query + int actualRowCount = 0; + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.UUID_SQL_CMD)) { + for (; resultSet.next(); actualRowCount++) { + actualIds.add((Integer) resultSet.getObject("id")); + actualUuids.add((UUID) resultSet.getObject("uuid_col")); + } + } + + // Assertions + int finalActualRowCount = actualRowCount; + assertAll( + "UUID ResultSet values are as expected", + () -> assertThat(finalActualRowCount, is(equalTo(expectedRowCount))), + () -> assertThat(actualIds.toArray(new Integer[0]), is(expectedIds)), + () -> assertThat(actualUuids.toArray(new UUID[0]), is(expectedUuids))); + } + + @Test + public void testGetObjectReturnsUuid() throws SQLException { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.UUID_SQL_CMD)) { + resultSet.next(); + Object result = resultSet.getObject("uuid_col"); + assertThat(result, instanceOf(UUID.class)); + assertThat(result, is(CoreMockedSqlProducers.UUID_1)); + } + } + + @Test + public void testGetObjectByIndexReturnsUuid() throws SQLException { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.UUID_SQL_CMD)) { + resultSet.next(); + Object result = resultSet.getObject(2); + assertThat(result, instanceOf(UUID.class)); + assertThat(result, is(CoreMockedSqlProducers.UUID_1)); + } + } + + @Test + public void testGetStringReturnsHyphenatedFormat() throws SQLException { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.UUID_SQL_CMD)) { + resultSet.next(); + String result = resultSet.getString("uuid_col"); + assertThat(result, is(CoreMockedSqlProducers.UUID_1.toString())); + } + } + + @Test + public void testGetBytesReturns16ByteArray() throws SQLException { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.UUID_SQL_CMD)) { + resultSet.next(); + byte[] result = resultSet.getBytes("uuid_col"); + assertThat(result.length, is(16)); + assertThat(result, is(UuidUtility.getBytesFromUUID(CoreMockedSqlProducers.UUID_1))); + } + } + + @Test + public void testNullUuidHandling() throws SQLException { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.UUID_SQL_CMD)) { + // Skip to row 4 which has NULL UUID + resultSet.next(); // row 1 + resultSet.next(); // row 2 + resultSet.next(); // row 3 + resultSet.next(); // row 4 (NULL UUID) + + Object objResult = resultSet.getObject("uuid_col"); + assertThat(objResult, nullValue()); + assertThat(resultSet.wasNull(), is(true)); + + String strResult = resultSet.getString("uuid_col"); + assertThat(strResult, nullValue()); + assertThat(resultSet.wasNull(), is(true)); + + byte[] bytesResult = resultSet.getBytes("uuid_col"); + assertThat(bytesResult, nullValue()); + assertThat(resultSet.wasNull(), is(true)); + } + } + + @Test + public void testMultipleUuidRows() throws SQLException { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.UUID_SQL_CMD)) { + resultSet.next(); + assertThat(resultSet.getObject("uuid_col"), is(CoreMockedSqlProducers.UUID_1)); + + resultSet.next(); + assertThat(resultSet.getObject("uuid_col"), is(CoreMockedSqlProducers.UUID_2)); + + resultSet.next(); + assertThat(resultSet.getObject("uuid_col"), is(CoreMockedSqlProducers.UUID_3)); + + resultSet.next(); + assertThat(resultSet.getObject("uuid_col"), nullValue()); + } + } + + @Test + public void testUuidExtensionTypeInSchema() throws SQLException { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.UUID_SQL_CMD)) { + ResultSetMetaData metaData = resultSet.getMetaData(); + + assertThat(metaData.getColumnCount(), is(2)); + assertThat(metaData.getColumnName(1), is("id")); + assertThat(metaData.getColumnName(2), is("uuid_col")); + + assertThat(metaData.getColumnType(2), is(java.sql.Types.OTHER)); + } + } + + @Test + public void testPreparedStatementWithUuidParameter() throws SQLException { + try (PreparedStatement pstmt = + connection.prepareStatement(CoreMockedSqlProducers.UUID_PREPARED_SELECT_SQL_CMD)) { + pstmt.setObject(1, CoreMockedSqlProducers.UUID_1); + try (ResultSet rs = pstmt.executeQuery()) { + rs.next(); + assertThat(rs.getObject("uuid_col"), is(CoreMockedSqlProducers.UUID_1)); + } + } + } + + @Test + public void testPreparedStatementWithUuidStringParameter() throws SQLException { + try (PreparedStatement pstmt = + connection.prepareStatement(CoreMockedSqlProducers.UUID_PREPARED_SELECT_SQL_CMD)) { + pstmt.setString(1, CoreMockedSqlProducers.UUID_1.toString()); + try (ResultSet rs = pstmt.executeQuery()) { + rs.next(); + assertThat(rs.getObject("uuid_col"), is(CoreMockedSqlProducers.UUID_1)); + } + } + } + + @Test + public void testPreparedStatementUpdateWithUuid() throws SQLException { + try (PreparedStatement pstmt = + connection.prepareStatement(CoreMockedSqlProducers.UUID_PREPARED_UPDATE_SQL_CMD)) { + pstmt.setObject(1, CoreMockedSqlProducers.UUID_3); + pstmt.setInt(2, 1); + int updated = pstmt.executeUpdate(); + assertThat(updated, is(1)); + } + } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactoryTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactoryTest.java index 8b39041f0c..1fbd2f86a9 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactoryTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactoryTest.java @@ -16,10 +16,12 @@ */ package org.apache.arrow.driver.jdbc.accessor; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.function.IntSupplier; import org.apache.arrow.driver.jdbc.accessor.impl.binary.ArrowFlightJdbcBinaryVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.binary.ArrowFlightJdbcUuidVectorAccessor; import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorAccessor; import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDurationVectorAccessor; import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcIntervalVectorAccessor; @@ -497,4 +499,15 @@ public void createAccessorForMapVector() { assertTrue(accessor instanceof ArrowFlightJdbcMapVectorAccessor); } } + + @Test + public void createAccessorForUuidVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createUuidVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor( + valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertInstanceOf(ArrowFlightJdbcUuidVectorAccessor.class, accessor); + } + } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcUuidVectorAccessorTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcUuidVectorAccessorTest.java new file mode 100644 index 0000000000..b7f341240c --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcUuidVectorAccessorTest.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.accessor.impl.binary; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; + +import java.util.UUID; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestExtension; +import org.apache.arrow.vector.UuidVector; +import org.apache.arrow.vector.util.UuidUtility; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +/** + * Tests for {@link ArrowFlightJdbcUuidVectorAccessor}. + * + *

Verifies that the accessor correctly handles UUID values from Arrow's UUID extension type, + * following PostgreSQL JDBC driver conventions. + */ +public class ArrowFlightJdbcUuidVectorAccessorTest { + + @RegisterExtension + public static RootAllocatorTestExtension rootAllocatorTestExtension = + new RootAllocatorTestExtension(); + + private static final UUID UUID_1 = UUID.fromString("550e8400-e29b-41d4-a716-446655440000"); + private static final UUID UUID_2 = UUID.fromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8"); + private static final UUID UUID_3 = UUID.fromString("f47ac10b-58cc-4372-a567-0e02b2c3d479"); + + private UuidVector vector; + private ArrowFlightJdbcUuidVectorAccessor accessor; + private boolean wasNullCalled; + private boolean wasNullValue; + + @BeforeEach + public void setUp() { + vector = rootAllocatorTestExtension.createUuidVector(); + wasNullCalled = false; + wasNullValue = false; + ArrowFlightJdbcAccessorFactory.WasNullConsumer wasNullConsumer = + (wasNull) -> { + wasNullCalled = true; + wasNullValue = wasNull; + }; + accessor = new ArrowFlightJdbcUuidVectorAccessor(vector, () -> 0, wasNullConsumer); + } + + @AfterEach + public void tearDown() { + vector.close(); + } + + @Test + public void testGetObjectReturnsUuid() { + accessor = new ArrowFlightJdbcUuidVectorAccessor(vector, () -> 0, (wasNull) -> {}); + Object result = accessor.getObject(); + assertThat(result, is(UUID_1)); + assertThat(accessor.wasNull(), is(false)); + } + + @Test + public void testGetObjectReturnsCorrectUuidForEachRow() { + accessor = new ArrowFlightJdbcUuidVectorAccessor(vector, () -> 0, (wasNull) -> {}); + assertThat(accessor.getObject(), is(UUID_1)); + + accessor = new ArrowFlightJdbcUuidVectorAccessor(vector, () -> 1, (wasNull) -> {}); + assertThat(accessor.getObject(), is(UUID_2)); + + accessor = new ArrowFlightJdbcUuidVectorAccessor(vector, () -> 2, (wasNull) -> {}); + assertThat(accessor.getObject(), is(UUID_3)); + } + + @Test + public void testGetObjectReturnsNullForNullValue() { + vector.reset(); + vector.allocateNew(1); + vector.setNull(0); + vector.setValueCount(1); + + accessor = new ArrowFlightJdbcUuidVectorAccessor(vector, () -> 0, (wasNull) -> {}); + Object result = accessor.getObject(); + assertThat(result, nullValue()); + assertThat(accessor.wasNull(), is(true)); + } + + @Test + public void testGetObjectClassReturnsUuidClass() { + assertThat(accessor.getObjectClass(), equalTo(UUID.class)); + } + + @Test + public void testGetStringReturnsHyphenatedFormat() { + accessor = new ArrowFlightJdbcUuidVectorAccessor(vector, () -> 0, (wasNull) -> {}); + String result = accessor.getString(); + assertThat(result, is("550e8400-e29b-41d4-a716-446655440000")); + assertThat(accessor.wasNull(), is(false)); + } + + @Test + public void testGetStringReturnsNullForNullValue() { + vector.reset(); + vector.allocateNew(1); + vector.setNull(0); + vector.setValueCount(1); + + accessor = new ArrowFlightJdbcUuidVectorAccessor(vector, () -> 0, (wasNull) -> {}); + String result = accessor.getString(); + assertThat(result, nullValue()); + assertThat(accessor.wasNull(), is(true)); + } + + @Test + public void testGetBytesReturns16ByteArray() { + accessor = new ArrowFlightJdbcUuidVectorAccessor(vector, () -> 0, (wasNull) -> {}); + byte[] result = accessor.getBytes(); + assertThat(result.length, is(16)); + assertThat(result, is(UuidUtility.getBytesFromUUID(UUID_1))); + assertThat(accessor.wasNull(), is(false)); + } + + @Test + public void testGetBytesReturnsNullForNullValue() { + vector.reset(); + vector.allocateNew(1); + vector.setNull(0); + vector.setValueCount(1); + + accessor = new ArrowFlightJdbcUuidVectorAccessor(vector, () -> 0, (wasNull) -> {}); + byte[] result = accessor.getBytes(); + assertThat(result, nullValue()); + assertThat(accessor.wasNull(), is(true)); + } + + @Test + public void testWasNullConsumerIsCalled() { + accessor = + new ArrowFlightJdbcUuidVectorAccessor( + vector, + () -> 0, + (wasNull) -> { + wasNullCalled = true; + wasNullValue = wasNull; + }); + accessor.getObject(); + assertThat(wasNullCalled, is(true)); + assertThat(wasNullValue, is(false)); + } + + @Test + public void testWasNullConsumerIsCalledWithTrueForNull() { + vector.reset(); + vector.allocateNew(1); + vector.setNull(0); + vector.setValueCount(1); + + accessor = + new ArrowFlightJdbcUuidVectorAccessor( + vector, + () -> 0, + (wasNull) -> { + wasNullCalled = true; + wasNullValue = wasNull; + }); + accessor.getObject(); + assertThat(wasNullCalled, is(true)); + assertThat(wasNullValue, is(true)); + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/converter/impl/UuidAvaticaParameterConverterTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/converter/impl/UuidAvaticaParameterConverterTest.java new file mode 100644 index 0000000000..07751f0abc --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/converter/impl/UuidAvaticaParameterConverterTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.converter.impl; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.sql.Types; +import java.util.UUID; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestExtension; +import org.apache.arrow.vector.UuidVector; +import org.apache.arrow.vector.extension.UuidType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.UuidUtility; +import org.apache.calcite.avatica.AvaticaParameter; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.remote.TypedValue; +import org.apache.calcite.avatica.util.ByteString; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +/** + * Tests for {@link UuidAvaticaParameterConverter}. + * + *

Verifies that the converter correctly handles UUID parameter binding from JDBC to Arrow's UUID + * extension type. + */ +public class UuidAvaticaParameterConverterTest { + + @RegisterExtension + public static RootAllocatorTestExtension rootAllocatorTestExtension = + new RootAllocatorTestExtension(); + + private static final UUID TEST_UUID = UUID.fromString("550e8400-e29b-41d4-a716-446655440000"); + + private UuidVector vector; + private UuidAvaticaParameterConverter converter; + + @BeforeEach + public void setUp() { + vector = new UuidVector("uuid_param", rootAllocatorTestExtension.getRootAllocator()); + vector.allocateNew(5); + converter = new UuidAvaticaParameterConverter(); + } + + @AfterEach + public void tearDown() { + vector.close(); + } + + @Test + public void testBindParameterWithUuidObject() { + TypedValue typedValue = TypedValue.ofLocal(ColumnMetaData.Rep.OBJECT, TEST_UUID); + + boolean result = converter.bindParameter(vector, typedValue, 0); + + assertTrue(result); + assertThat(vector.getObject(0), is(TEST_UUID)); + } + + @Test + public void testBindParameterWithUuidString() { + String uuidString = "550e8400-e29b-41d4-a716-446655440000"; + TypedValue typedValue = TypedValue.ofLocal(ColumnMetaData.Rep.STRING, uuidString); + + boolean result = converter.bindParameter(vector, typedValue, 0); + + assertTrue(result); + assertThat(vector.getObject(0), is(TEST_UUID)); + } + + @Test + public void testBindParameterWithByteArray() { + byte[] uuidBytes = UuidUtility.getBytesFromUUID(TEST_UUID); + ByteString byteString = new ByteString(uuidBytes); + TypedValue typedValue = TypedValue.ofLocal(ColumnMetaData.Rep.BYTE_STRING, byteString); + + boolean result = converter.bindParameter(vector, typedValue, 0); + + assertTrue(result); + assertThat(vector.getObject(0), is(TEST_UUID)); + } + + @Test + public void testBindParameterWithNullValue() { + TypedValue typedValue = TypedValue.ofLocal(ColumnMetaData.Rep.OBJECT, null); + + boolean result = converter.bindParameter(vector, typedValue, 0); + + assertTrue(result); + assertTrue(vector.isNull(0)); + assertThat(vector.getObject(0), nullValue()); + } + + @Test + public void testBindParameterWithInvalidByteArrayLength() { + byte[] invalidBytes = new byte[8]; // Should be 16 bytes + ByteString byteString = new ByteString(invalidBytes); + TypedValue typedValue = TypedValue.ofLocal(ColumnMetaData.Rep.BYTE_STRING, byteString); + + assertThrows( + IllegalArgumentException.class, () -> converter.bindParameter(vector, typedValue, 0)); + } + + @Test + public void testBindParameterWithInvalidType() { + TypedValue typedValue = TypedValue.ofLocal(ColumnMetaData.Rep.INTEGER, 12345); + + assertThrows( + IllegalArgumentException.class, () -> converter.bindParameter(vector, typedValue, 0)); + } + + @Test + public void testBindParameterMultipleValues() { + UUID uuid1 = UUID.fromString("550e8400-e29b-41d4-a716-446655440000"); + UUID uuid2 = UUID.fromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8"); + UUID uuid3 = UUID.fromString("f47ac10b-58cc-4372-a567-0e02b2c3d479"); + + converter.bindParameter(vector, TypedValue.ofLocal(ColumnMetaData.Rep.OBJECT, uuid1), 0); + converter.bindParameter(vector, TypedValue.ofLocal(ColumnMetaData.Rep.OBJECT, uuid2), 1); + converter.bindParameter(vector, TypedValue.ofLocal(ColumnMetaData.Rep.OBJECT, uuid3), 2); + + assertThat(vector.getObject(0), is(uuid1)); + assertThat(vector.getObject(1), is(uuid2)); + assertThat(vector.getObject(2), is(uuid3)); + } + + @Test + public void testCreateParameter() { + Field uuidField = new Field("uuid_col", new FieldType(true, UuidType.INSTANCE, null), null); + + AvaticaParameter parameter = converter.createParameter(uuidField); + + assertThat(parameter.name, is("uuid_col")); + assertThat(parameter.parameterType, is(Types.OTHER)); + assertThat(parameter.typeName, is("OTHER")); + assertThat(parameter.className, equalTo(UUID.class.getCanonicalName())); + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/CoreMockedSqlProducers.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/CoreMockedSqlProducers.java index 8197d7d95f..7c17755693 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/CoreMockedSqlProducers.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/CoreMockedSqlProducers.java @@ -28,8 +28,10 @@ import java.sql.SQLException; import java.sql.Timestamp; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.UUID; import java.util.function.Consumer; import java.util.stream.IntStream; import org.apache.arrow.flight.FlightProducer.ServerStreamListener; @@ -40,10 +42,13 @@ import org.apache.arrow.vector.DateDayVector; import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.TimeStampMilliVector; import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UuidVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.extension.UuidType; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; @@ -52,6 +57,7 @@ import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.Text; +import org.apache.arrow.vector.util.UuidUtility; /** Standard {@link MockFlightSqlProducer} instances for tests. */ // TODO Remove this once all tests are refactor to use only the queries they need. @@ -62,6 +68,22 @@ public final class CoreMockedSqlProducers { public static final String LEGACY_CANCELLATION_SQL_CMD = "SELECT * FROM TAKES_FOREVER"; public static final String LEGACY_REGULAR_WITH_EMPTY_SQL_CMD = "SELECT * FROM TEST_EMPTIES"; + public static final String UUID_SQL_CMD = "SELECT * FROM UUID_TABLE"; + public static final String UUID_PREPARED_SELECT_SQL_CMD = + "SELECT * FROM UUID_TABLE WHERE uuid_col = ?"; + public static final String UUID_PREPARED_UPDATE_SQL_CMD = + "UPDATE UUID_TABLE SET uuid_col = ? WHERE id = ?"; + + public static final UUID UUID_1 = UUID.fromString("550e8400-e29b-41d4-a716-446655440000"); + public static final UUID UUID_2 = UUID.fromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8"); + public static final UUID UUID_3 = UUID.fromString("f47ac10b-58cc-4372-a567-0e02b2c3d479"); + + public static final Schema UUID_SCHEMA = + new Schema( + ImmutableList.of( + new Field("id", new FieldType(true, new ArrowType.Int(32, true), null), null), + new Field("uuid_col", new FieldType(true, UuidType.INSTANCE, null), null))); + private CoreMockedSqlProducers() { // Prevent instantiation. } @@ -78,9 +100,109 @@ public static MockFlightSqlProducer getLegacyProducer() { addLegacyMetadataSqlCmdSupport(producer); addLegacyCancellationSqlCmdSupport(producer); addQueryWithEmbeddedEmptyRoot(producer); + addUuidSqlCmdSupport(producer); + addUuidPreparedSelectSqlCmdSupport(producer); + addUuidPreparedUpdateSqlCmdSupport(producer); return producer; } + /** + * Gets a {@link MockFlightSqlProducer} configured with UUID test data. + * + * @return a new producer with UUID support. + */ + public static MockFlightSqlProducer getUuidProducer() { + final MockFlightSqlProducer producer = new MockFlightSqlProducer(); + addUuidSqlCmdSupport(producer); + return producer; + } + + private static void addUuidPreparedUpdateSqlCmdSupport(final MockFlightSqlProducer producer) { + final String query = "UPDATE UUID_TABLE SET uuid_col = ? WHERE id = ?"; + final Schema parameterSchema = + new Schema( + Arrays.asList( + new Field("", new FieldType(true, UuidType.INSTANCE, null), null), + Field.nullable("", new ArrowType.Int(32, true)))); + + producer.addUpdateQuery(query, 1); + producer.addExpectedParameters( + UUID_PREPARED_UPDATE_SQL_CMD, + parameterSchema, + Collections.singletonList(Arrays.asList(CoreMockedSqlProducers.UUID_3, 1))); + } + + private static void addUuidPreparedSelectSqlCmdSupport(final MockFlightSqlProducer producer) { + final Schema parameterSchema = + new Schema( + Collections.singletonList( + new Field("", new FieldType(true, UuidType.INSTANCE, null), null))); + + final Consumer uuidResultProvider = + listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(UUID_SCHEMA, allocator)) { + root.allocateNew(); + IntVector idVector = (IntVector) root.getVector("id"); + UuidVector uuidVector = (UuidVector) root.getVector("uuid_col"); + idVector.setSafe(0, 1); + uuidVector.setSafe(0, UuidUtility.getBytesFromUUID(CoreMockedSqlProducers.UUID_1)); + root.setRowCount(1); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + + producer.addSelectQuery( + UUID_PREPARED_SELECT_SQL_CMD, UUID_SCHEMA, Collections.singletonList(uuidResultProvider)); + producer.addExpectedParameters( + UUID_PREPARED_SELECT_SQL_CMD, + parameterSchema, + Collections.singletonList(Collections.singletonList(CoreMockedSqlProducers.UUID_1))); + } + + private static void addUuidSqlCmdSupport(final MockFlightSqlProducer producer) { + final Consumer uuidResultProvider = + listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(UUID_SCHEMA, allocator)) { + root.allocateNew(); + + IntVector idVector = (IntVector) root.getVector("id"); + UuidVector uuidVector = (UuidVector) root.getVector("uuid_col"); + + // Row 0: id=1, uuid=UUID_1 + idVector.setSafe(0, 1); + uuidVector.setSafe(0, UuidUtility.getBytesFromUUID(UUID_1)); + + // Row 1: id=2, uuid=UUID_2 + idVector.setSafe(1, 2); + uuidVector.setSafe(1, UuidUtility.getBytesFromUUID(UUID_2)); + + // Row 2: id=3, uuid=UUID_3 + idVector.setSafe(2, 3); + uuidVector.setSafe(2, UuidUtility.getBytesFromUUID(UUID_3)); + + // Row 3: id=4, uuid=NULL + idVector.setSafe(3, 4); + uuidVector.setNull(3); + + root.setRowCount(4); + listener.start(root); + listener.putNext(); + } finally { + listener.completed(); + } + }; + + producer.addSelectQuery( + UUID_SQL_CMD, UUID_SCHEMA, Collections.singletonList(uuidResultProvider)); + } + private static void addQueryWithEmbeddedEmptyRoot(final MockFlightSqlProducer producer) { final Schema querySchema = new Schema( diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestExtension.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestExtension.java index 347e92a16c..4b299d63e0 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestExtension.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestExtension.java @@ -19,6 +19,7 @@ import java.math.BigDecimal; import java.nio.charset.StandardCharsets; import java.util.Random; +import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.stream.IntStream; import org.apache.arrow.memory.BufferAllocator; @@ -53,6 +54,7 @@ import org.apache.arrow.vector.UInt2Vector; import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.UuidVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.LargeListVector; @@ -60,6 +62,7 @@ import org.apache.arrow.vector.complex.impl.UnionFixedSizeListWriter; import org.apache.arrow.vector.complex.impl.UnionLargeListWriter; import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.util.UuidUtility; import org.junit.jupiter.api.extension.AfterAllCallback; import org.junit.jupiter.api.extension.BeforeAllCallback; import org.junit.jupiter.api.extension.ExtensionContext; @@ -811,4 +814,23 @@ public FixedSizeListVector createFixedSizeListVector() { return valueVector; } + + /** + * Create a UuidVector to be used in the accessor tests. + * + * @return UuidVector + */ + public UuidVector createUuidVector() { + UuidVector valueVector = new UuidVector("", this.getRootAllocator()); + valueVector.allocateNew(3); + valueVector.setSafe( + 0, UuidUtility.getBytesFromUUID(UUID.fromString("550e8400-e29b-41d4-a716-446655440000"))); + valueVector.setSafe( + 1, UuidUtility.getBytesFromUUID(UUID.fromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8"))); + valueVector.setSafe( + 2, UuidUtility.getBytesFromUUID(UUID.fromString("f47ac10b-58cc-4372-a567-0e02b2c3d479"))); + valueVector.setValueCount(3); + + return valueVector; + } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/SqlTypesTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/SqlTypesTest.java index d69c549296..c4858d787d 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/SqlTypesTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/SqlTypesTest.java @@ -21,6 +21,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.sql.Types; +import org.apache.arrow.vector.extension.UuidType; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.IntervalUnit; @@ -85,6 +86,8 @@ public void testGetSqlTypeIdFromArrowType() { assertEquals(Types.JAVA_OBJECT, getSqlTypeIdFromArrowType(new ArrowType.Map(true))); assertEquals(Types.NULL, getSqlTypeIdFromArrowType(new ArrowType.Null())); + + assertEquals(Types.OTHER, getSqlTypeIdFromArrowType(UuidType.INSTANCE)); } @Test @@ -140,5 +143,7 @@ public void testGetSqlTypeNameFromArrowType() { assertEquals("JAVA_OBJECT", getSqlTypeNameFromArrowType(new ArrowType.Map(true))); assertEquals("NULL", getSqlTypeNameFromArrowType(new ArrowType.Null())); + + assertEquals("OTHER", getSqlTypeNameFromArrowType(UuidType.INSTANCE)); } } diff --git a/flight/flight-sql/pom.xml b/flight/flight-sql/pom.xml index 15d00e3e18..4175ff70d3 100644 --- a/flight/flight-sql/pom.xml +++ b/flight/flight-sql/pom.xml @@ -113,7 +113,7 @@ under the License. org.apache.commons commons-text - 1.13.1 + 1.15.0 test diff --git a/pom.xml b/pom.xml index 6d1181d1d3..d6fdfdd477 100644 --- a/pom.xml +++ b/pom.xml @@ -98,12 +98,12 @@ under the License. 2.0.17 33.4.8-jre 4.2.9.Final - 1.73.0 + 1.78.0 4.33.1 2.18.3 3.4.2 25.2.10 - 1.12.0 + 1.12.1 5.17.0 2 @@ -111,7 +111,7 @@ under the License. true 2.42.0 3.53.0 - 1.5.21 + 1.5.24 none -Xdoclint:none @@ -505,7 +505,7 @@ under the License. org.codehaus.mojo exec-maven-plugin - 3.5.0 + 3.6.3 org.codehaus.mojo diff --git a/vector/src/main/codegen/templates/AbstractFieldReader.java b/vector/src/main/codegen/templates/AbstractFieldReader.java index c7c5b4d78d..556fb576ce 100644 --- a/vector/src/main/codegen/templates/AbstractFieldReader.java +++ b/vector/src/main/codegen/templates/AbstractFieldReader.java @@ -109,10 +109,6 @@ public void copyAsField(String name, ${name}Writer writer) { - public void copyAsValue(StructWriter writer, ExtensionTypeWriterFactory writerFactory) { - fail("CopyAsValue StructWriter"); - } - public void read(ExtensionHolder holder) { fail("Extension"); } @@ -147,4 +143,5 @@ public int size() { private void fail(String name) { throw new IllegalArgumentException(String.format("You tried to read a [%s] type when you are using a field reader of type [%s].", name, this.getClass().getSimpleName())); } + } diff --git a/vector/src/main/codegen/templates/AbstractFieldWriter.java b/vector/src/main/codegen/templates/AbstractFieldWriter.java index ae5b97faef..4b4a17d932 100644 --- a/vector/src/main/codegen/templates/AbstractFieldWriter.java +++ b/vector/src/main/codegen/templates/AbstractFieldWriter.java @@ -107,14 +107,17 @@ public void endEntry() { throw new IllegalStateException(String.format("You tried to end a map entry when you are using a ValueWriter of type %s.", this.getClass().getSimpleName())); } + @Override public void write(ExtensionHolder var1) { - this.fail("ExtensionType"); + this.fail("Cannot write ExtensionHolder"); } + @Override public void writeExtension(Object var1) { - this.fail("ExtensionType"); + this.fail("Cannot write extension object"); } - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory var1) { - this.fail("ExtensionType"); + @Override + public void writeExtension(Object var1, ArrowType type) { + this.fail("Cannot write extension with type " + type); } <#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first /> diff --git a/vector/src/main/codegen/templates/ArrowType.java b/vector/src/main/codegen/templates/ArrowType.java index fd35c1cd2b..b428f09155 100644 --- a/vector/src/main/codegen/templates/ArrowType.java +++ b/vector/src/main/codegen/templates/ArrowType.java @@ -27,8 +27,10 @@ import org.apache.arrow.flatbuf.Type; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.types.*; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -331,6 +333,10 @@ public boolean equals(Object obj) { public T accept(ArrowTypeVisitor visitor) { return visitor.visit(this); } + + public FieldWriter getNewFieldWriter(ValueVector vector) { + throw new UnsupportedOperationException("WriterImpl not yet implemented."); + } } private static final int defaultDecimalBitWidth = 128; diff --git a/vector/src/main/codegen/templates/BaseReader.java b/vector/src/main/codegen/templates/BaseReader.java index 4c6f49ab9b..c52345af21 100644 --- a/vector/src/main/codegen/templates/BaseReader.java +++ b/vector/src/main/codegen/templates/BaseReader.java @@ -49,7 +49,6 @@ public interface RepeatedStructReader extends StructReader{ boolean next(); int size(); void copyAsValue(StructWriter writer); - void copyAsValue(StructWriter writer, ExtensionTypeWriterFactory writerFactory); } public interface ListReader extends BaseReader{ @@ -60,7 +59,6 @@ public interface RepeatedListReader extends ListReader{ boolean next(); int size(); void copyAsValue(ListWriter writer); - void copyAsValue(ListWriter writer, ExtensionTypeWriterFactory writerFactory); } public interface MapReader extends BaseReader{ @@ -71,7 +69,6 @@ public interface RepeatedMapReader extends MapReader{ boolean next(); int size(); void copyAsValue(MapWriter writer); - void copyAsValue(MapWriter writer, ExtensionTypeWriterFactory writerFactory); } public interface ScalarReader extends diff --git a/vector/src/main/codegen/templates/BaseWriter.java b/vector/src/main/codegen/templates/BaseWriter.java index 78da7fddc3..a4c98d7089 100644 --- a/vector/src/main/codegen/templates/BaseWriter.java +++ b/vector/src/main/codegen/templates/BaseWriter.java @@ -125,11 +125,12 @@ public interface ExtensionWriter extends BaseWriter { void writeExtension(Object value); /** - * Adds the given extension type factory. This factory allows configuring writer implementations for specific ExtensionTypeVector. + * Writes the given extension type value. * - * @param factory the extension type factory to add + * @param value the extension type value to write + * @param type of the extension */ - void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory); + void writeExtension(Object value, ArrowType type); } public interface ScalarWriter extends diff --git a/vector/src/main/codegen/templates/ComplexCopier.java b/vector/src/main/codegen/templates/ComplexCopier.java index 4df5478f48..6655f6c2a7 100644 --- a/vector/src/main/codegen/templates/ComplexCopier.java +++ b/vector/src/main/codegen/templates/ComplexCopier.java @@ -41,15 +41,8 @@ public class ComplexCopier { * @param input field to read from * @param output field to write to */ - public static void copy(FieldReader input, FieldWriter output) { - writeValue(input, output, null); - } - - public static void copy(FieldReader input, FieldWriter output, ExtensionTypeWriterFactory extensionTypeWriterFactory) { - writeValue(input, output, extensionTypeWriterFactory); - } + public static void copy(FieldReader reader, FieldWriter writer) { - private static void writeValue(FieldReader reader, FieldWriter writer, ExtensionTypeWriterFactory extensionTypeWriterFactory) { final MinorType mt = reader.getMinorType(); switch (mt) { @@ -65,7 +58,7 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension FieldReader childReader = reader.reader(); FieldWriter childWriter = getListWriterForReader(childReader, writer); if (childReader.isSet()) { - writeValue(childReader, childWriter, extensionTypeWriterFactory); + copy(childReader, childWriter); } else { childWriter.writeNull(); } @@ -83,8 +76,8 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension FieldReader structReader = reader.reader(); if (structReader.isSet()) { writer.startEntry(); - writeValue(mapReader.key(), getMapWriterForReader(mapReader.key(), writer.key()), extensionTypeWriterFactory); - writeValue(mapReader.value(), getMapWriterForReader(mapReader.value(), writer.value()), extensionTypeWriterFactory); + copy(mapReader.key(), getMapWriterForReader(mapReader.key(), writer.key())); + copy(mapReader.value(), getMapWriterForReader(mapReader.value(), writer.value())); writer.endEntry(); } else { writer.writeNull(); @@ -103,7 +96,7 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension if (childReader.getMinorType() != Types.MinorType.NULL) { FieldWriter childWriter = getStructWriterForReader(childReader, writer, name); if (childReader.isSet()) { - writeValue(childReader, childWriter, extensionTypeWriterFactory); + copy(childReader, childWriter); } else { childWriter.writeNull(); } @@ -115,14 +108,10 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension } break; case EXTENSIONTYPE: - if (extensionTypeWriterFactory == null) { - throw new IllegalArgumentException("Must provide ExtensionTypeWriterFactory"); - } if (reader.isSet()) { Object value = reader.readObject(); if (value != null) { - writer.addExtensionTypeWriterFactory(extensionTypeWriterFactory); - writer.writeExtension(value); + writer.writeExtension(value, reader.getField().getType()); } } else { writer.writeNull(); diff --git a/vector/src/main/codegen/templates/NullReader.java b/vector/src/main/codegen/templates/NullReader.java index 0529633478..88e6ea98ea 100644 --- a/vector/src/main/codegen/templates/NullReader.java +++ b/vector/src/main/codegen/templates/NullReader.java @@ -86,7 +86,6 @@ public void read(int arrayIndex, Nullable${name}Holder holder){ } - public void copyAsValue(StructWriter writer, ExtensionTypeWriterFactory writerFactory){} public void read(ExtensionHolder holder) { holder.isSet = 0; } diff --git a/vector/src/main/codegen/templates/PromotableWriter.java b/vector/src/main/codegen/templates/PromotableWriter.java index d22eb00b2c..11d34f72c9 100644 --- a/vector/src/main/codegen/templates/PromotableWriter.java +++ b/vector/src/main/codegen/templates/PromotableWriter.java @@ -286,7 +286,7 @@ protected void setWriter(ValueVector v) { writer = new UnionWriter((UnionVector) vector, nullableStructWriterFactory); break; case EXTENSIONTYPE: - writer = new UnionExtensionWriter((ExtensionTypeVector) vector); + writer = ((ExtensionType) vector.getField().getType()).getNewFieldWriter(vector); break; default: writer = type.getNewFieldWriter(vector); @@ -541,17 +541,13 @@ public void writeLargeVarChar(String value) { } @Override - public void writeExtension(Object value) { - getWriter(MinorType.EXTENSIONTYPE).writeExtension(value); + public void writeExtension(Object value, ArrowType arrowType) { + getWriter(MinorType.EXTENSIONTYPE, arrowType).writeExtension(value, arrowType); } @Override - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory) { - getWriter(MinorType.EXTENSIONTYPE).addExtensionTypeWriterFactory(factory); - } - - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory, ArrowType arrowType) { - getWriter(MinorType.EXTENSIONTYPE, arrowType).addExtensionTypeWriterFactory(factory); + public void write(ExtensionHolder holder) { + getWriter(MinorType.EXTENSIONTYPE, holder.type()).write(holder); } @Override diff --git a/vector/src/main/codegen/templates/UnionListWriter.java b/vector/src/main/codegen/templates/UnionListWriter.java index 3c41ac72b6..4b54739230 100644 --- a/vector/src/main/codegen/templates/UnionListWriter.java +++ b/vector/src/main/codegen/templates/UnionListWriter.java @@ -204,13 +204,13 @@ public MapWriter map(String name, boolean keysSorted) { @Override public ExtensionWriter extension(ArrowType arrowType) { - this.extensionType = arrowType; + extensionType = arrowType; return this; } + @Override public ExtensionWriter extension(String name, ArrowType arrowType) { - ExtensionWriter extensionWriter = writer.extension(name, arrowType); - return extensionWriter; + return writer.extension(name, arrowType); } <#if listName == "LargeList"> @@ -337,13 +337,13 @@ public void writeNull() { @Override public void writeExtension(Object value) { - writer.writeExtension(value); + writer.writeExtension(value, extensionType); writer.setPosition(writer.idx() + 1); } @Override - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory var1) { - writer.addExtensionTypeWriterFactory(var1, extensionType); + public void writeExtension(Object value, ArrowType type) { + writeExtension(value); } public void write(ExtensionHolder var1) { diff --git a/vector/src/main/codegen/templates/UnionReader.java b/vector/src/main/codegen/templates/UnionReader.java index 96ad3e1b9b..0edae7ade0 100644 --- a/vector/src/main/codegen/templates/UnionReader.java +++ b/vector/src/main/codegen/templates/UnionReader.java @@ -79,6 +79,10 @@ public void read(int index, UnionHolder holder) { } private FieldReader getReaderForIndex(int index) { + return getReaderForIndex(index, null); + } + + private FieldReader getReaderForIndex(int index, ArrowType type) { int typeValue = data.getTypeValue(index); FieldReader reader = (FieldReader) readers[typeValue]; if (reader != null) { @@ -105,11 +109,26 @@ private FieldReader getReaderForIndex(int index) { + case EXTENSIONTYPE: + if(type == null) { + throw new RuntimeException("Cannot get Extension reader without an ArrowType"); + } + return (FieldReader) getExtension(type); default: throw new UnsupportedOperationException("Unsupported type: " + MinorType.values()[typeValue]); } } + private ExtensionReader extensionReader; + + private ExtensionReader getExtension(ArrowType type) { + if (extensionReader == null) { + extensionReader = data.getExtension(type).getReader(); + extensionReader.setPosition(idx()); + } + return extensionReader; + } + private SingleStructReaderImpl structReader; private StructReader getStruct() { @@ -240,4 +259,8 @@ public FieldReader reader() { public boolean next() { return getReaderForIndex(idx()).next(); } + + public void read(ExtensionHolder holder){ + getReaderForIndex(idx(), holder.type()).read(holder); + } } diff --git a/vector/src/main/codegen/templates/UnionVector.java b/vector/src/main/codegen/templates/UnionVector.java index 67efdf60f7..c706591966 100644 --- a/vector/src/main/codegen/templates/UnionVector.java +++ b/vector/src/main/codegen/templates/UnionVector.java @@ -379,6 +379,22 @@ public MapVector getMap(String name, ArrowType arrowType) { return mapVector; } + private ExtensionTypeVector extensionVector; + + public ExtensionTypeVector getExtension(ArrowType arrowType) { + if (extensionVector == null) { + int vectorCount = internalStruct.size(); + extensionVector = addOrGet(null, MinorType.EXTENSIONTYPE, arrowType, ExtensionTypeVector.class); + if (internalStruct.size() > vectorCount) { + extensionVector.allocateNew(); + if (callBack != null) { + callBack.doWork(); + } + } + } + return extensionVector; + } + public int getTypeValue(int index) { return typeBuffer.getByte(index * TYPE_WIDTH); } @@ -725,6 +741,8 @@ public ValueVector getVectorByType(int typeId, ArrowType arrowType) { return getListView(); case MAP: return getMap(name, arrowType); + case EXTENSIONTYPE: + return getExtension(arrowType); default: throw new UnsupportedOperationException("Cannot support type: " + MinorType.values()[typeId]); } diff --git a/vector/src/main/codegen/templates/UnionWriter.java b/vector/src/main/codegen/templates/UnionWriter.java index 272edab17c..0db699fd8c 100644 --- a/vector/src/main/codegen/templates/UnionWriter.java +++ b/vector/src/main/codegen/templates/UnionWriter.java @@ -28,6 +28,8 @@ package org.apache.arrow.vector.complex.impl; <#include "/@includes/vv_imports.ftl" /> +import java.util.HashMap; + import org.apache.arrow.vector.complex.writer.BaseWriter; import org.apache.arrow.vector.types.Types.MinorType; @@ -213,8 +215,31 @@ public MapWriter asMap(ArrowType arrowType) { return getMapWriter(arrowType); } + private java.util.Map extensionWriters = new HashMap<>(); + private ExtensionWriter getExtensionWriter(ArrowType arrowType) { - throw new UnsupportedOperationException("ExtensionTypes are not supported yet."); + ExtensionWriter w = extensionWriters.get(arrowType); + if (w == null) { + w = ((ExtensionType) arrowType).getNewFieldWriter(data.getExtension(arrowType)); + w.setPosition(idx()); + extensionWriters.put(arrowType, w); + } + return w; + } + + public void writeExtension(Object value, ArrowType type) { + data.setType(idx(), MinorType.EXTENSIONTYPE); + ExtensionWriter w = getExtensionWriter(type); + w.setPosition(idx()); + w.writeExtension(value); + } + + @Override + public void write(ExtensionHolder holder) { + data.setType(idx(), MinorType.EXTENSIONTYPE); + ExtensionWriter w = getExtensionWriter(holder.type()); + w.setPosition(idx()); + w.write(holder); } BaseWriter getWriter(MinorType minorType) { diff --git a/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java b/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java index cc57cde29e..37dfa20616 100644 --- a/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java @@ -22,7 +22,6 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.ReferenceManager; import org.apache.arrow.util.Preconditions; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.util.DataSizeRoundingUtil; import org.apache.arrow.vector.util.TransferPair; @@ -261,18 +260,6 @@ public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { throw new UnsupportedOperationException(); } - @Override - public void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - - @Override - public void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - /** * Transfer the validity buffer from `validityBuffer` to the target vector's `validityBuffer`. * Start at `startIndex` and copy `length` number of elements. If the starting index is 8 byte diff --git a/vector/src/main/java/org/apache/arrow/vector/NullVector.java b/vector/src/main/java/org/apache/arrow/vector/NullVector.java index 0d6dab2837..6bfe540d23 100644 --- a/vector/src/main/java/org/apache/arrow/vector/NullVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/NullVector.java @@ -27,7 +27,6 @@ import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.compare.VectorVisitor; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.impl.NullReader; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; @@ -330,18 +329,6 @@ public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { throw new UnsupportedOperationException(); } - @Override - public void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - - @Override - public void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - @Override public String getName() { return this.getField().getName(); diff --git a/vector/src/main/java/org/apache/arrow/vector/ValueVector.java b/vector/src/main/java/org/apache/arrow/vector/ValueVector.java index e0628c2ee1..3a5058256c 100644 --- a/vector/src/main/java/org/apache/arrow/vector/ValueVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/ValueVector.java @@ -22,7 +22,6 @@ import org.apache.arrow.memory.OutOfMemoryException; import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.vector.compare.VectorVisitor; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; @@ -310,30 +309,6 @@ public interface ValueVector extends Closeable, Iterable { */ void copyFromSafe(int fromIndex, int thisIndex, ValueVector from); - /** - * Copy a cell value from a particular index in source vector to a particular position in this - * vector. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory); - - /** - * Same as {@link #copyFrom(int, int, ValueVector)} except that it handles the case when the - * capacity of the vector needs to be expanded before copy. - * - * @param fromIndex position to copy from in source vector - * @param thisIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory); - /** * Accept a generic {@link VectorVisitor} and return the result. * diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java index 429f9884bb..a6a71cf1a4 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java @@ -21,7 +21,6 @@ import org.apache.arrow.vector.DensityAwareVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.FixedSizeList; @@ -152,18 +151,6 @@ public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { throw new UnsupportedOperationException(); } - @Override - public void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - - @Override - public void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException(); - } - @Override public String getName() { return name; diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java index 48c8127e23..997b5a8b78 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java @@ -49,7 +49,6 @@ import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.impl.UnionLargeListReader; import org.apache.arrow.vector.complex.impl.UnionLargeListWriter; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -483,42 +482,12 @@ public void copyFromSafe(int inIndex, int outIndex, ValueVector from) { */ @Override public void copyFrom(int inIndex, int outIndex, ValueVector from) { - copyFrom(inIndex, outIndex, from, null); - } - - /** - * Copy a cell value from a particular index in source vector to a particular position in this - * vector. - * - * @param inIndex position to copy from in source vector - * @param outIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - @Override - public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { Preconditions.checkArgument(this.getMinorType() == from.getMinorType()); FieldReader in = from.getReader(); in.setPosition(inIndex); UnionLargeListWriter out = getWriter(); out.setPosition(outIndex); - ComplexCopier.copy(in, out, writerFactory); - } - - /** - * Same as {@link #copyFrom(int, int, ValueVector)} except that it handles the case when the - * capacity of the vector needs to be expanded before copy. - * - * @param inIndex position to copy from in source vector - * @param outIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - @Override - public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - copyFrom(inIndex, outIndex, from, writerFactory); + ComplexCopier.copy(in, out); } /** diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java index 992a664449..2da7eb057e 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java @@ -41,7 +41,6 @@ import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.impl.UnionLargeListViewReader; import org.apache.arrow.vector.complex.impl.UnionLargeListViewWriter; import org.apache.arrow.vector.complex.impl.UnionListReader; @@ -347,20 +346,6 @@ public void copyFrom(int inIndex, int outIndex, ValueVector from) { "LargeListViewVector does not support copyFrom operation yet."); } - @Override - public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException( - "LargeListViewVector does not support copyFromSafe operation yet."); - } - - @Override - public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - throw new UnsupportedOperationException( - "LargeListViewVector does not support copyFrom operation yet."); - } - @Override public FieldVector getDataVector() { return vector; diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index 89549257c4..93a313ef4f 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -42,7 +42,6 @@ import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.impl.UnionListReader; import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -401,42 +400,12 @@ public void copyFromSafe(int inIndex, int outIndex, ValueVector from) { */ @Override public void copyFrom(int inIndex, int outIndex, ValueVector from) { - copyFrom(inIndex, outIndex, from, null); - } - - /** - * Same as {@link #copyFrom(int, int, ValueVector)} except that it handles the case when the - * capacity of the vector needs to be expanded before copy. - * - * @param inIndex position to copy from in source vector - * @param outIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - @Override - public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - copyFrom(inIndex, outIndex, from, writerFactory); - } - - /** - * Copy a cell value from a particular index in source vector to a particular position in this - * vector. - * - * @param inIndex position to copy from in source vector - * @param outIndex position to copy to in this vector - * @param from source vector - * @param writerFactory the extension type writer factory to use for copying extension type values - */ - @Override - public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { Preconditions.checkArgument(this.getMinorType() == from.getMinorType()); FieldReader in = from.getReader(); in.setPosition(inIndex); FieldWriter out = getWriter(); out.setPosition(outIndex); - ComplexCopier.copy(in, out, writerFactory); + ComplexCopier.copy(in, out); } /** diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java index 2784240429..8711db5e0f 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java @@ -42,7 +42,6 @@ import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; import org.apache.arrow.vector.complex.impl.UnionListViewReader; import org.apache.arrow.vector.complex.impl.UnionListViewWriter; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -339,12 +338,6 @@ public void copyFromSafe(int inIndex, int outIndex, ValueVector from) { copyFrom(inIndex, outIndex, from); } - @Override - public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { - copyFrom(inIndex, outIndex, from, writerFactory); - } - @Override public OUT accept(VectorVisitor visitor, IN value) { return visitor.visit(this, value); @@ -352,18 +345,12 @@ public OUT accept(VectorVisitor visitor, IN value) { @Override public void copyFrom(int inIndex, int outIndex, ValueVector from) { - copyFrom(inIndex, outIndex, from, null); - } - - @Override - public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { Preconditions.checkArgument(this.getMinorType() == from.getMinorType()); FieldReader in = from.getReader(); in.setPosition(inIndex); FieldWriter out = getWriter(); out.setPosition(outIndex); - ComplexCopier.copy(in, out, writerFactory); + ComplexCopier.copy(in, out); } @Override diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java index bf074ecb90..b2e95663f7 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java @@ -115,14 +115,4 @@ public void copyAsValue(ListWriter writer) { public void copyAsValue(MapWriter writer) { ComplexCopier.copy(this, (FieldWriter) writer); } - - @Override - public void copyAsValue(ListWriter writer, ExtensionTypeWriterFactory writerFactory) { - ComplexCopier.copy(this, (FieldWriter) writer, writerFactory); - } - - @Override - public void copyAsValue(MapWriter writer, ExtensionTypeWriterFactory writerFactory) { - ComplexCopier.copy(this, (FieldWriter) writer, writerFactory); - } } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeWriterFactory.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeWriterFactory.java deleted file mode 100644 index a01d591555..0000000000 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeWriterFactory.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.vector.complex.impl; - -import org.apache.arrow.vector.ExtensionTypeVector; -import org.apache.arrow.vector.complex.writer.FieldWriter; - -/** - * A factory interface for creating instances of {@link AbstractExtensionTypeWriter}. This factory - * allows configuring writer implementations for specific {@link ExtensionTypeVector}. - * - * @param the type of writer implementation for a specific {@link ExtensionTypeVector}. - */ -public interface ExtensionTypeWriterFactory { - - /** - * Returns an instance of the writer implementation for the given {@link ExtensionTypeVector}. - * - * @param vector the {@link ExtensionTypeVector} for which the writer implementation is to be - * returned. - * @return an instance of the writer implementation for the given {@link ExtensionTypeVector}. - */ - T getWriterImpl(ExtensionTypeVector vector); -} diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java index 4219069cba..93796aa77e 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java @@ -60,11 +60,6 @@ public void writeExtension(Object var1) { } @Override - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory) { - this.writer = factory.getWriterImpl(vector); - this.writer.setPosition(idx()); - } - public void write(ExtensionHolder holder) { this.writer.write(holder); } @@ -79,6 +74,7 @@ public void setPosition(int index) { @Override public void writeNull() { - this.writer.writeNull(); + this.vector.setNull(getPosition()); + this.vector.setValueCount(getPosition() + 1); } } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java index a9104cb0d2..be236c3166 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java @@ -105,8 +105,4 @@ public boolean next() { public void copyAsValue(UnionLargeListWriter writer) { ComplexCopier.copy(this, (FieldWriter) writer); } - - public void copyAsValue(UnionLargeListWriter writer, ExtensionTypeWriterFactory writerFactory) { - ComplexCopier.copy(this, (FieldWriter) writer, writerFactory); - } } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UuidWriterFactory.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UuidWriterFactory.java deleted file mode 100644 index 35988129cb..0000000000 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UuidWriterFactory.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.vector.complex.impl; - -import org.apache.arrow.vector.ExtensionTypeVector; -import org.apache.arrow.vector.UuidVector; - -/** - * Factory for creating {@link UuidWriterImpl} instances. - * - *

This factory is used to create writers for UUID extension type vectors. - * - * @see UuidWriterImpl - * @see org.apache.arrow.vector.extension.UuidType - */ -public class UuidWriterFactory implements ExtensionTypeWriterFactory { - - /** - * Creates a writer implementation for the given extension type vector. - * - * @param extensionTypeVector the vector to create a writer for - * @return a {@link UuidWriterImpl} if the vector is a {@link UuidVector}, null otherwise - */ - @Override - public AbstractFieldWriter getWriterImpl(ExtensionTypeVector extensionTypeVector) { - if (extensionTypeVector instanceof UuidVector) { - return new UuidWriterImpl((UuidVector) extensionTypeVector); - } - return null; - } -} diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UuidWriterImpl.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UuidWriterImpl.java index 8a78add11c..ee3c79d5e3 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UuidWriterImpl.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UuidWriterImpl.java @@ -21,6 +21,7 @@ import org.apache.arrow.vector.holders.ExtensionHolder; import org.apache.arrow.vector.holders.NullableUuidHolder; import org.apache.arrow.vector.holders.UuidHolder; +import org.apache.arrow.vector.types.pojo.ArrowType; /** * Writer implementation for {@link UuidVector}. @@ -56,6 +57,11 @@ public void writeExtension(Object value) { vector.setValueCount(getPosition() + 1); } + @Override + public void writeExtension(Object value, ArrowType type) { + writeExtension(value); + } + @Override public void write(ExtensionHolder holder) { if (holder instanceof UuidHolder) { diff --git a/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueType.java b/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueType.java index ca56214fda..780a4ee659 100644 --- a/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueType.java +++ b/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueType.java @@ -54,10 +54,12 @@ import org.apache.arrow.vector.TimeStampNanoVector; import org.apache.arrow.vector.TimeStampSecTZVector; import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.ViewVarBinaryVector; import org.apache.arrow.vector.ViewVarCharVector; +import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ExtensionTypeRegistry; @@ -177,6 +179,11 @@ public int hashCode() { return Objects.hash(super.hashCode(), storageType, typeName, vendorName); } + @Override + public FieldWriter getNewFieldWriter(ValueVector vector) { + throw new UnsupportedOperationException("WriterImpl not yet implemented."); + } + @Override public String toString() { return "OpaqueType(" diff --git a/vector/src/main/java/org/apache/arrow/vector/extension/UuidType.java b/vector/src/main/java/org/apache/arrow/vector/extension/UuidType.java index cd29f930e1..c249c6eda9 100644 --- a/vector/src/main/java/org/apache/arrow/vector/extension/UuidType.java +++ b/vector/src/main/java/org/apache/arrow/vector/extension/UuidType.java @@ -20,6 +20,9 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.FixedSizeBinaryVector; import org.apache.arrow.vector.UuidVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.impl.UuidWriterImpl; +import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; import org.apache.arrow.vector.types.pojo.ExtensionTypeRegistry; @@ -108,4 +111,9 @@ public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocato return new UuidVector( name, fieldType, allocator, new FixedSizeBinaryVector(name, allocator, UUID_BYTE_WIDTH)); } + + @Override + public FieldWriter getNewFieldWriter(ValueVector vector) { + return new UuidWriterImpl((UuidVector) vector); + } } diff --git a/vector/src/main/java/org/apache/arrow/vector/holders/ExtensionHolder.java b/vector/src/main/java/org/apache/arrow/vector/holders/ExtensionHolder.java index fc7ed85878..4d3f767aef 100644 --- a/vector/src/main/java/org/apache/arrow/vector/holders/ExtensionHolder.java +++ b/vector/src/main/java/org/apache/arrow/vector/holders/ExtensionHolder.java @@ -16,7 +16,11 @@ */ package org.apache.arrow.vector.holders; +import org.apache.arrow.vector.types.pojo.ArrowType; + /** Base {@link ValueHolder} class for a {@link org.apache.arrow.vector.ExtensionTypeVector}. */ public abstract class ExtensionHolder implements ValueHolder { public int isSet; + + public abstract ArrowType type(); } diff --git a/vector/src/main/java/org/apache/arrow/vector/holders/NullableUuidHolder.java b/vector/src/main/java/org/apache/arrow/vector/holders/NullableUuidHolder.java index e5398d82cf..7fa50ca761 100644 --- a/vector/src/main/java/org/apache/arrow/vector/holders/NullableUuidHolder.java +++ b/vector/src/main/java/org/apache/arrow/vector/holders/NullableUuidHolder.java @@ -17,6 +17,8 @@ package org.apache.arrow.vector.holders; import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.extension.UuidType; +import org.apache.arrow.vector.types.pojo.ArrowType; /** * Value holder for nullable UUID values. @@ -32,4 +34,9 @@ public class NullableUuidHolder extends ExtensionHolder { /** Buffer containing 16-byte UUID data. */ public ArrowBuf buffer; + + @Override + public ArrowType type() { + return UuidType.INSTANCE; + } } diff --git a/vector/src/main/java/org/apache/arrow/vector/holders/UuidHolder.java b/vector/src/main/java/org/apache/arrow/vector/holders/UuidHolder.java index 484e05c24b..8a0a66e435 100644 --- a/vector/src/main/java/org/apache/arrow/vector/holders/UuidHolder.java +++ b/vector/src/main/java/org/apache/arrow/vector/holders/UuidHolder.java @@ -17,6 +17,8 @@ package org.apache.arrow.vector.holders; import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.extension.UuidType; +import org.apache.arrow.vector.types.pojo.ArrowType; /** * Value holder for non-nullable UUID values. @@ -35,4 +37,9 @@ public class UuidHolder extends ExtensionHolder { public UuidHolder() { this.isSet = 1; } + + @Override + public ArrowType type() { + return UuidType.INSTANCE; + } } diff --git a/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java b/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java index d5cbf925b2..759c84651d 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java @@ -26,18 +26,24 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.UUID; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.complex.BaseRepeatedValueVector; import org.apache.arrow.vector.complex.LargeListVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionLargeListReader; import org.apache.arrow.vector.complex.impl.UnionLargeListWriter; import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; +import org.apache.arrow.vector.extension.UuidType; +import org.apache.arrow.vector.holders.UuidHolder; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.TransferPair; +import org.apache.arrow.vector.util.UuidUtility; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -1021,6 +1027,79 @@ public void testGetTransferPairWithField() throws Exception { } } + @Test + public void testCopyValueSafeForExtensionType() throws Exception { + try (LargeListVector inVector = LargeListVector.empty("input", allocator); + LargeListVector outVector = LargeListVector.empty("output", allocator)) { + UnionLargeListWriter writer = inVector.getWriter(); + writer.allocate(); + + // Create first list with UUIDs + writer.setPosition(0); + UUID u1 = UUID.randomUUID(); + UUID u2 = UUID.randomUUID(); + writer.startList(); + ExtensionWriter extensionWriter = writer.extension(UuidType.INSTANCE); + extensionWriter.writeExtension(u1); + extensionWriter.writeExtension(u2); + writer.endList(); + + // Create second list with UUIDs + writer.setPosition(1); + UUID u3 = UUID.randomUUID(); + UUID u4 = UUID.randomUUID(); + writer.startList(); + extensionWriter = writer.extension(UuidType.INSTANCE); + extensionWriter.writeExtension(u3); + extensionWriter.writeExtension(u4); + extensionWriter.writeNull(); + + writer.endList(); + writer.setValueCount(2); + + // Use copyFromSafe with ExtensionTypeWriterFactory + // This internally calls TransferImpl.copyValueSafe with ExtensionTypeWriterFactory + outVector.allocateNew(); + TransferPair tp = inVector.makeTransferPair(outVector); + tp.copyValueSafe(0, 0); + tp.copyValueSafe(1, 1); + outVector.setValueCount(2); + + // Verify first list + UnionLargeListReader reader = outVector.getReader(); + reader.setPosition(0); + assertTrue(reader.isSet(), "first list shouldn't be null"); + reader.next(); + FieldReader uuidReader = reader.reader(); + UuidHolder holder = new UuidHolder(); + uuidReader.read(holder); + UUID actualUuid = UuidUtility.uuidFromArrowBuf(holder.buffer, 0); + assertEquals(u1, actualUuid); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + actualUuid = UuidUtility.uuidFromArrowBuf(holder.buffer, 0); + assertEquals(u2, actualUuid); + + // Verify second list + reader.setPosition(1); + assertTrue(reader.isSet(), "second list shouldn't be null"); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + actualUuid = UuidUtility.uuidFromArrowBuf(holder.buffer, 0); + assertEquals(u3, actualUuid); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + actualUuid = UuidUtility.uuidFromArrowBuf(holder.buffer, 0); + assertEquals(u4, actualUuid); + reader.next(); + uuidReader = reader.reader(); + assertFalse(uuidReader.isSet(), "third element should be null"); + } + } + private void writeIntValues(UnionLargeListWriter writer, int[] values) { writer.startList(); for (int v : values) { diff --git a/vector/src/test/java/org/apache/arrow/vector/TestListVector.java b/vector/src/test/java/org/apache/arrow/vector/TestListVector.java index df3a609f53..e96ac3027c 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestListVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestListVector.java @@ -35,7 +35,6 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListReader; import org.apache.arrow.vector.complex.impl.UnionListWriter; -import org.apache.arrow.vector.complex.impl.UuidWriterFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; import org.apache.arrow.vector.extension.UuidType; @@ -1217,7 +1216,6 @@ public void testListVectorWithExtensionType() throws Exception { UUID u2 = UUID.randomUUID(); writer.startList(); ExtensionWriter extensionWriter = writer.extension(UuidType.INSTANCE); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); extensionWriter.writeExtension(u1); extensionWriter.writeExtension(u2); writer.endList(); @@ -1245,7 +1243,6 @@ public void testListVectorReaderForExtensionType() throws Exception { UUID u2 = UUID.randomUUID(); writer.startList(); ExtensionWriter extensionWriter = writer.extension(UuidType.INSTANCE); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); extensionWriter.writeExtension(u1); extensionWriter.writeExtension(u2); writer.endList(); @@ -1279,23 +1276,78 @@ public void testCopyFromForExtensionType() throws Exception { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); writer.startList(); + + writer.extension(UuidType.INSTANCE).writeExtension(u1); + writer.writeExtension(u2); + writer.writeNull(); + writer.endList(); + + writer.setValueCount(3); + + // copy values from input to output + outVector.allocateNew(); + outVector.copyFrom(0, 0, inVector); + outVector.setValueCount(3); + + UnionListReader reader = outVector.getReader(); + assertTrue(reader.isSet(), "shouldn't be null"); + reader.setPosition(0); + reader.next(); + FieldReader uuidReader = reader.reader(); + UuidHolder holder = new UuidHolder(); + uuidReader.read(holder); + UUID actualUuid = UuidUtility.uuidFromArrowBuf(holder.buffer, 0); + assertEquals(u1, actualUuid); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + actualUuid = UuidUtility.uuidFromArrowBuf(holder.buffer, 0); + assertEquals(u2, actualUuid); + } + } + + @Test + public void testCopyValueSafeForExtensionType() throws Exception { + try (ListVector inVector = ListVector.empty("input", allocator); + ListVector outVector = ListVector.empty("output", allocator)) { + UnionListWriter writer = inVector.getWriter(); + writer.allocate(); + + // Create first list with UUIDs + writer.setPosition(0); + UUID u1 = UUID.randomUUID(); + UUID u2 = UUID.randomUUID(); + writer.startList(); ExtensionWriter extensionWriter = writer.extension(UuidType.INSTANCE); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); extensionWriter.writeExtension(u1); extensionWriter.writeExtension(u2); - extensionWriter.writeNull(); writer.endList(); - writer.setValueCount(1); + // Create second list with UUIDs + writer.setPosition(1); + UUID u3 = UUID.randomUUID(); + UUID u4 = UUID.randomUUID(); + writer.startList(); + extensionWriter = writer.extension(UuidType.INSTANCE); + extensionWriter.writeExtension(u3); + extensionWriter.writeExtension(u4); + extensionWriter.writeNull(); - // copy values from input to output + writer.endList(); + writer.setValueCount(2); + + // Use TransferPair with ExtensionTypeWriterFactory + // This tests the new makeTransferPair API with writerFactory parameter outVector.allocateNew(); - outVector.copyFrom(0, 0, inVector, new UuidWriterFactory()); - outVector.setValueCount(1); + TransferPair transferPair = inVector.makeTransferPair(outVector); + transferPair.copyValueSafe(0, 0); + transferPair.copyValueSafe(1, 1); + outVector.setValueCount(2); + // Verify first list UnionListReader reader = outVector.getReader(); - assertTrue(reader.isSet(), "shouldn't be null"); reader.setPosition(0); + assertTrue(reader.isSet(), "first list shouldn't be null"); reader.next(); FieldReader uuidReader = reader.reader(); UuidHolder holder = new UuidHolder(); @@ -1307,6 +1359,23 @@ public void testCopyFromForExtensionType() throws Exception { uuidReader.read(holder); actualUuid = UuidUtility.uuidFromArrowBuf(holder.buffer, 0); assertEquals(u2, actualUuid); + + // Verify second list + reader.setPosition(1); + assertTrue(reader.isSet(), "second list shouldn't be null"); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + actualUuid = UuidUtility.uuidFromArrowBuf(holder.buffer, 0); + assertEquals(u3, actualUuid); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + actualUuid = UuidUtility.uuidFromArrowBuf(holder.buffer, 0); + assertEquals(u4, actualUuid); + reader.next(); + uuidReader = reader.reader(); + assertFalse(uuidReader.isSet(), "third element should be null"); } } diff --git a/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java b/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java index d9d2ca50dc..bfac1237a4 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java @@ -35,7 +35,6 @@ import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.impl.UnionMapReader; import org.apache.arrow.vector.complex.impl.UnionMapWriter; -import org.apache.arrow.vector.complex.impl.UuidWriterFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.ListWriter; @@ -1285,14 +1284,12 @@ public void testMapVectorWithExtensionType() throws Exception { writer.startEntry(); writer.key().bigInt().writeBigInt(0); ExtensionWriter extensionWriter = writer.value().extension(UuidType.INSTANCE); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u1); + extensionWriter.writeExtension(u1, UuidType.INSTANCE); writer.endEntry(); writer.startEntry(); writer.key().bigInt().writeBigInt(1); extensionWriter = writer.value().extension(UuidType.INSTANCE); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u2); + extensionWriter.writeExtension(u2, UuidType.INSTANCE); writer.endEntry(); writer.endMap(); @@ -1327,20 +1324,17 @@ public void testCopyFromForExtensionType() throws Exception { writer.startEntry(); writer.key().bigInt().writeBigInt(0); ExtensionWriter extensionWriter = writer.value().extension(UuidType.INSTANCE); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u1); + extensionWriter.writeExtension(u1, UuidType.INSTANCE); writer.endEntry(); writer.startEntry(); writer.key().bigInt().writeBigInt(1); - extensionWriter = writer.value().extension(UuidType.INSTANCE); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u2); + extensionWriter.writeExtension(u2, UuidType.INSTANCE); writer.endEntry(); writer.endMap(); writer.setValueCount(1); outVector.allocateNew(); - outVector.copyFrom(0, 0, inVector, new UuidWriterFactory()); + outVector.copyFrom(0, 0, inVector); outVector.setValueCount(1); UnionMapReader mapReader = outVector.getReader(); @@ -1576,4 +1570,103 @@ public void testFixedSizeBinaryFirstInitialization() { assertArrayEquals(new byte[] {32, 21}, (byte[]) resultStruct.get(MapVector.VALUE_NAME)); } } + + @Test + public void testMapWithUuidKeyAndListUuidValue() throws Exception { + try (final MapVector mapVector = MapVector.empty("map", allocator, false)) { + mapVector.allocateNew(); + UnionMapWriter writer = mapVector.getWriter(); + + // Create test UUIDs + UUID key1 = UUID.randomUUID(); + UUID key2 = UUID.randomUUID(); + UUID value1a = UUID.randomUUID(); + UUID value1b = UUID.randomUUID(); + UUID value2a = UUID.randomUUID(); + UUID value2b = UUID.randomUUID(); + UUID value2c = UUID.randomUUID(); + + // Write first map entry: {key1 -> [value1a, value1b]} + writer.setPosition(0); + writer.startMap(); + + writer.startEntry(); + ExtensionWriter keyWriter = writer.key().extension(UuidType.INSTANCE); + keyWriter.writeExtension(key1, UuidType.INSTANCE); + ListWriter valueWriter = writer.value().list(); + valueWriter.startList(); + ExtensionWriter listItemWriter = valueWriter.extension(UuidType.INSTANCE); + listItemWriter.writeExtension(value1a, UuidType.INSTANCE); + listItemWriter = valueWriter.extension(UuidType.INSTANCE); + listItemWriter.writeExtension(value1b, UuidType.INSTANCE); + valueWriter.endList(); + writer.endEntry(); + + writer.startEntry(); + keyWriter = writer.key().extension(UuidType.INSTANCE); + keyWriter.writeExtension(key2, UuidType.INSTANCE); + valueWriter = writer.value().list(); + valueWriter.startList(); + listItemWriter = valueWriter.extension(UuidType.INSTANCE); + listItemWriter.writeExtension(value2a, UuidType.INSTANCE); + listItemWriter = valueWriter.extension(UuidType.INSTANCE); + listItemWriter.writeExtension(value2b, UuidType.INSTANCE); + listItemWriter = valueWriter.extension(UuidType.INSTANCE); + listItemWriter.writeExtension(value2c, UuidType.INSTANCE); + valueWriter.endList(); + writer.endEntry(); + + writer.endMap(); + writer.setValueCount(1); + + // Read and verify the data + UnionMapReader mapReader = mapVector.getReader(); + mapReader.setPosition(0); + + // Read first entry + mapReader.next(); + FieldReader keyReader = mapReader.key(); + UuidHolder keyHolder = new UuidHolder(); + keyReader.read(keyHolder); + UUID actualKey = UuidUtility.uuidFromArrowBuf(keyHolder.buffer, 0); + assertEquals(key1, actualKey); + + FieldReader valueReader = mapReader.value(); + assertTrue(valueReader.isSet()); + List listValue = (List) valueReader.readObject(); + assertEquals(2, listValue.size()); + + // Verify first list item - readObject() returns UUID objects for extension types + UUID actualValue1a = (UUID) listValue.get(0); + assertEquals(value1a, actualValue1a); + + // Verify second list item + UUID actualValue1b = (UUID) listValue.get(1); + assertEquals(value1b, actualValue1b); + + // Read second entry + mapReader.next(); + keyReader = mapReader.key(); + keyReader.read(keyHolder); + actualKey = UuidUtility.uuidFromArrowBuf(keyHolder.buffer, 0); + assertEquals(key2, actualKey); + + valueReader = mapReader.value(); + assertTrue(valueReader.isSet()); + listValue = (List) valueReader.readObject(); + assertEquals(3, listValue.size()); + + // Verify first list item - readObject() returns UUID objects for extension types + UUID actualValue2a = (UUID) listValue.get(0); + assertEquals(value2a, actualValue2a); + + // Verify second list item + UUID actualValue2b = (UUID) listValue.get(1); + assertEquals(value2b, actualValue2b); + + // Verify third list item + UUID actualValue2c = (UUID) listValue.get(2); + assertEquals(value2c, actualValue2c); + } + } } diff --git a/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java b/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java index 21ebeebc86..8c8a45f588 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java @@ -160,17 +160,23 @@ public void testGetPrimitiveVectors() { UnionVector unionVector = vector.addOrGetUnion("union"); unionVector.addVector(new BigIntVector("bigInt", allocator)); unionVector.addVector(new SmallIntVector("smallInt", allocator)); + unionVector.addVector(new UuidVector("uuid", allocator)); // add varchar vector vector.addOrGet( "varchar", FieldType.nullable(MinorType.VARCHAR.getType()), VarCharVector.class); + // add extension vector + vector.addOrGet("extension", FieldType.nullable(UuidType.INSTANCE), UuidVector.class); + List primitiveVectors = vector.getPrimitiveVectors(); - assertEquals(4, primitiveVectors.size()); + assertEquals(6, primitiveVectors.size()); assertEquals(MinorType.INT, primitiveVectors.get(0).getMinorType()); assertEquals(MinorType.BIGINT, primitiveVectors.get(1).getMinorType()); assertEquals(MinorType.SMALLINT, primitiveVectors.get(2).getMinorType()); - assertEquals(MinorType.VARCHAR, primitiveVectors.get(3).getMinorType()); + assertEquals(MinorType.EXTENSIONTYPE, primitiveVectors.get(3).getMinorType()); + assertEquals(MinorType.VARCHAR, primitiveVectors.get(4).getMinorType()); + assertEquals(MinorType.EXTENSIONTYPE, primitiveVectors.get(5).getMinorType()); } } diff --git a/vector/src/test/java/org/apache/arrow/vector/TestUuidVector.java b/vector/src/test/java/org/apache/arrow/vector/TestUuidVector.java index 3d70238ece..a3690461cf 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestUuidVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestUuidVector.java @@ -33,6 +33,7 @@ import org.apache.arrow.vector.holders.ExtensionHolder; import org.apache.arrow.vector.holders.NullableUuidHolder; import org.apache.arrow.vector.holders.UuidHolder; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.util.UuidUtility; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -358,7 +359,13 @@ void testReaderReadWithUnsupportedHolder() throws Exception { reader.setPosition(0); // Create a mock unsupported holder - ExtensionHolder unsupportedHolder = new ExtensionHolder() {}; + ExtensionHolder unsupportedHolder = + new ExtensionHolder() { + @Override + public ArrowType type() { + return null; + } + }; IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> reader.read(unsupportedHolder)); @@ -377,7 +384,13 @@ void testReaderReadWithArrayIndexUnsupportedHolder() throws Exception { UuidReaderImpl reader = (UuidReaderImpl) vector.getReader(); // Create a mock unsupported holder - ExtensionHolder unsupportedHolder = new ExtensionHolder() {}; + ExtensionHolder unsupportedHolder = + new ExtensionHolder() { + @Override + public ArrowType type() { + return null; + } + }; IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> reader.read(0, unsupportedHolder)); diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java index 73c1cd3b74..b2a8cf9ba4 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java @@ -861,7 +861,6 @@ public void testCopyListVectorWithExtensionType() { listWriter.setPosition(i); listWriter.startList(); ExtensionWriter extensionWriter = listWriter.extension(UuidType.INSTANCE); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); extensionWriter.writeExtension(UUID.randomUUID()); extensionWriter.writeExtension(UUID.randomUUID()); listWriter.endList(); @@ -874,7 +873,7 @@ public void testCopyListVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { in.setPosition(i); out.setPosition(i); - ComplexCopier.copy(in, out, new UuidWriterFactory()); + ComplexCopier.copy(in, out); } to.setValueCount(COUNT); @@ -897,11 +896,9 @@ public void testCopyMapVectorWithExtensionType() { mapWriter.startMap(); mapWriter.startEntry(); ExtensionWriter extensionKeyWriter = mapWriter.key().extension(UuidType.INSTANCE); - extensionKeyWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionKeyWriter.writeExtension(UUID.randomUUID()); + extensionKeyWriter.writeExtension(UUID.randomUUID(), UuidType.INSTANCE); ExtensionWriter extensionValueWriter = mapWriter.value().extension(UuidType.INSTANCE); - extensionValueWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionValueWriter.writeExtension(UUID.randomUUID()); + extensionValueWriter.writeExtension(UUID.randomUUID(), UuidType.INSTANCE); mapWriter.endEntry(); mapWriter.endMap(); } @@ -914,7 +911,7 @@ public void testCopyMapVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { in.setPosition(i); out.setPosition(i); - ComplexCopier.copy(in, out, new UuidWriterFactory()); + ComplexCopier.copy(in, out); } to.setValueCount(COUNT); @@ -934,12 +931,10 @@ public void testCopyStructVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { structWriter.setPosition(i); structWriter.start(); - ExtensionWriter extensionWriter1 = structWriter.extension("timestamp1", UuidType.INSTANCE); - extensionWriter1.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter1.writeExtension(UUID.randomUUID()); - ExtensionWriter extensionWriter2 = structWriter.extension("timestamp2", UuidType.INSTANCE); - extensionWriter2.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter2.writeExtension(UUID.randomUUID()); + ExtensionWriter extensionWriter1 = structWriter.extension("uuid1", UuidType.INSTANCE); + extensionWriter1.writeExtension(UUID.randomUUID(), UuidType.INSTANCE); + ExtensionWriter extensionWriter2 = structWriter.extension("uuid2", UuidType.INSTANCE); + extensionWriter2.writeExtension(UUID.randomUUID(), UuidType.INSTANCE); structWriter.end(); } @@ -951,7 +946,7 @@ public void testCopyStructVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { in.setPosition(i); out.setPosition(i); - ComplexCopier.copy(in, out, new UuidWriterFactory()); + ComplexCopier.copy(in, out); } to.setValueCount(COUNT); diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index c71717a027..5b6d65d6ba 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -31,6 +31,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DirtyRootAllocator; +import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.LargeVarBinaryVector; import org.apache.arrow.vector.LargeVarCharVector; import org.apache.arrow.vector.UuidVector; @@ -49,6 +50,7 @@ import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; import org.apache.arrow.vector.holders.UnionHolder; +import org.apache.arrow.vector.holders.UuidHolder; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -57,6 +59,7 @@ import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.Text; +import org.apache.arrow.vector.util.UuidUtility; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -100,7 +103,6 @@ public void testPromoteToUnion() throws Exception { writer.integer("A").writeInt(10); // we don't write anything in 3 - writer.setPosition(4); writer.integer("A").writeInt(100); @@ -130,9 +132,21 @@ public void testPromoteToUnion() throws Exception { binHolder.buffer = buf; writer.fixedSizeBinary("A", 4).write(binHolder); + writer.setPosition(9); + UUID uuid = UUID.randomUUID(); + writer.extension("A", UuidType.INSTANCE).writeExtension(uuid, UuidType.INSTANCE); + writer.end(); + + writer.setPosition(10); + UUID uuid2 = UUID.randomUUID(); + UuidHolder uuidHolder = new UuidHolder(); + uuidHolder.buffer = allocator.buffer(UuidType.UUID_BYTE_WIDTH); + uuidHolder.buffer.setBytes(0, UuidUtility.getBytesFromUUID(uuid2)); + writer.extension("A", UuidType.INSTANCE).write(uuidHolder); writer.end(); + allocator.releaseBytes(UuidType.UUID_BYTE_WIDTH); - container.setValueCount(9); + container.setValueCount(11); final UnionVector uv = v.getChild("A", UnionVector.class); @@ -169,6 +183,12 @@ public void testPromoteToUnion() throws Exception { .order(ByteOrder.nativeOrder()) .getInt()); + assertFalse(uv.isNull(9), "9 shouldn't be null"); + assertEquals(uuid, uv.getObject(9)); + + assertFalse(uv.isNull(10), "10 shouldn't be null"); + assertEquals(uuid2, uv.getObject(10)); + container.clear(); container.allocateNew(); @@ -791,12 +811,11 @@ public void testExtensionType() throws Exception { UUID u2 = UUID.randomUUID(); container.allocateNew(); container.setValueCount(1); - writer.addExtensionTypeWriterFactory(new UuidWriterFactory()); writer.setPosition(0); - writer.writeExtension(u1); + writer.writeExtension(u1, UuidType.INSTANCE); writer.setPosition(1); - writer.writeExtension(u2); + writer.writeExtension(u2, UuidType.INSTANCE); container.setValueCount(2); @@ -817,16 +836,15 @@ public void testExtensionTypeForList() throws Exception { UUID u2 = UUID.randomUUID(); container.allocateNew(); container.setValueCount(1); - writer.addExtensionTypeWriterFactory(new UuidWriterFactory()); writer.setPosition(0); - writer.writeExtension(u1); + writer.writeExtension(u1, UuidType.INSTANCE); writer.setPosition(1); - writer.writeExtension(u2); + writer.writeExtension(u2, UuidType.INSTANCE); container.setValueCount(2); - UuidVector uuidVector = (UuidVector) container.getDataVector(); + FieldVector uuidVector = container.getDataVector(); assertEquals(u1, uuidVector.getObject(0)); assertEquals(u2, uuidVector.getObject(1)); } diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java index 3a8f3f8e6a..b131bf07e2 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java @@ -66,7 +66,6 @@ import org.apache.arrow.vector.complex.impl.UnionMapReader; import org.apache.arrow.vector.complex.impl.UnionReader; import org.apache.arrow.vector.complex.impl.UnionWriter; -import org.apache.arrow.vector.complex.impl.UuidWriterFactory; import org.apache.arrow.vector.complex.reader.BaseReader.StructReader; import org.apache.arrow.vector.complex.reader.BigIntReader; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -87,6 +86,7 @@ import org.apache.arrow.vector.holders.NullableFixedSizeBinaryHolder; import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; import org.apache.arrow.vector.holders.NullableTimeStampNanoTZHolder; +import org.apache.arrow.vector.holders.NullableUuidHolder; import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; import org.apache.arrow.vector.holders.UuidHolder; import org.apache.arrow.vector.types.TimeUnit; @@ -1106,6 +1106,13 @@ public void simpleUnion() throws Exception { new UnionVector("union", allocator, /* field type */ null, /* call-back */ null); UnionWriter unionWriter = new UnionWriter(vector); unionWriter.allocate(); + + UUID uuid = UUID.randomUUID(); + ByteBuffer bb = ByteBuffer.allocate(16); + bb.putLong(uuid.getMostSignificantBits()); + bb.putLong(uuid.getLeastSignificantBits()); + byte[] uuidByte = bb.array(); + for (int i = 0; i < COUNT; i++) { unionWriter.setPosition(i); if (i % 5 == 0) { @@ -1128,6 +1135,12 @@ public void simpleUnion() throws Exception { holder.buffer = buf; unionWriter.write(holder); bufs.add(buf); + } else if (i % 5 == 4) { + UuidHolder holder = new UuidHolder(); + holder.buffer = allocator.buffer(UuidType.UUID_BYTE_WIDTH); + holder.buffer.setBytes(0, uuidByte); + unionWriter.write(holder); + allocator.releaseBytes(UuidType.UUID_BYTE_WIDTH); } else { unionWriter.writeFloat4((float) i); } @@ -1153,6 +1166,10 @@ public void simpleUnion() throws Exception { unionReader.read(holder); assertEquals(i, holder.buffer.getInt(0)); assertEquals(4, holder.byteWidth); + } else if (i % 5 == 4) { + NullableUuidHolder holder = new NullableUuidHolder(); + unionReader.read(holder); + assertEquals(UuidUtility.uuidFromArrowBuf(holder.buffer, 0), uuid); } else { assertEquals((float) i, unionReader.readFloat(), 1e-12); } @@ -2512,8 +2529,7 @@ public void extensionWriterReader() throws Exception { { ExtensionWriter extensionWriter = rootWriter.extension("uuid1", UuidType.INSTANCE); extensionWriter.setPosition(0); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u1); + extensionWriter.writeExtension(u1, UuidType.INSTANCE); } // read StructReader rootReader = new SingleStructReaderImpl(parent).reader("root"); diff --git a/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java b/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java index 2ac4045aa2..ae5ac0726c 100644 --- a/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java +++ b/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java @@ -44,10 +44,12 @@ import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.UuidVector; import org.apache.arrow.vector.ValueIterableVector; +import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.compare.Range; import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.extension.UuidType; import org.apache.arrow.vector.ipc.ArrowFileReader; import org.apache.arrow.vector.ipc.ArrowFileWriter; @@ -333,6 +335,11 @@ public String serialize() { public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocator allocator) { return new LocationVector(name, allocator); } + + @Override + public FieldWriter getNewFieldWriter(ValueVector vector) { + throw new UnsupportedOperationException("Not yet implemented."); + } } public static class LocationVector extends ExtensionTypeVector