45 changes: 43 additions & 2 deletions python/sparknlp/reader/reader2image.py
@@ -86,6 +86,22 @@ class Reader2Image(
typeConverter=TypeConverters.toString
)

useEncodedImageBytes = Param(
Params._dummy(),
"useEncodedImageBytes",
"If true, use the original encoded image bytes (e.g., JPEG, PNG). "
"If false, decode the image into raw pixel data.",
typeConverter=TypeConverters.toBoolean
)

outputPromptColumn = Param(
Params._dummy(),
"outputPromptColumn",
"If true, outputs an additional 'prompt' column containing "
"the text prompt as a Spark NLP Annotation.",
typeConverter=TypeConverters.toBoolean
)

@keyword_only
def __init__(self):
super(Reader2Image, self).__init__(classname="com.johnsnowlabs.reader.Reader2Image")
@@ -97,7 +113,9 @@ def __init__(self):
promptTemplate="qwen2vl-chat",
readAsImage=True,
customPromptTemplate="",
ignoreExceptions=True
ignoreExceptions=True,
useEncodedImageBytes=False,
outputPromptColumn=False
)

@keyword_only
@@ -133,4 +151,27 @@ def setCustomPromptTemplate(self, value: str):
value : str
Custom prompt template string.
"""
return self._set(customPromptTemplate=value)

def setUseEncodedImageBytes(self, value: bool):
"""Sets whether to use encoded image bytes or decoded pixels.

Parameters
----------
value : bool
If True, keeps the image bytes in their encoded (compressed) form.
If False, decodes the image into a pixel matrix representation.
"""
return self._set(useEncodedImageBytes=value)


def setOutputPromptColumn(self, value: bool):
"""Enables or disables creation of a prompt column.

Parameters
----------
value : bool
If True, adds an additional 'prompt' column to the output DataFrame
containing the text prompt as a Spark NLP Annotation.
"""
return self._set(outputPromptColumn=value)
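
For context, a minimal usage sketch of the two new Python setters added above. The content path, the `setContentPath`/`setOutputCol` calls, and the empty bootstrap DataFrame are illustrative assumptions about the surrounding reader API, not part of this diff.

```python
from pyspark.ml import Pipeline
import sparknlp
from sparknlp.reader.reader2image import Reader2Image

spark = sparknlp.start()

# Only setUseEncodedImageBytes and setOutputPromptColumn are introduced by this PR;
# the other setters and the path are assumed for illustration.
reader2image = (
    Reader2Image()
    .setContentPath("path/to/images")   # hypothetical image directory
    .setOutputCol("image")
    .setUseEncodedImageBytes(True)      # keep the original encoded JPEG/PNG bytes
    .setOutputPromptColumn(True)        # also emit a 'prompt' column with the text prompt
)

empty_df = spark.createDataFrame([], "text STRING")
result_df = Pipeline(stages=[reader2image]).fit(empty_df).transform(empty_df)
result_df.select("image", "prompt").show(truncate=False)
```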
src/main/scala/com/johnsnowlabs/partition/util/PartitionHelper.scala
@@ -15,7 +15,7 @@
*/
package com.johnsnowlabs.partition.util

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.{DataFrame, SparkSession}

import java.nio.charset.Charset
import java.nio.file.Files
96 changes: 81 additions & 15 deletions src/main/scala/com/johnsnowlabs/reader/Reader2Image.scala
@@ -18,16 +18,12 @@ package com.johnsnowlabs.reader
import com.johnsnowlabs.nlp.AnnotatorType.IMAGE
import com.johnsnowlabs.nlp.annotators.cv.util.io.ImageIOUtils
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.nlp.{AnnotationImage, HasOutputAnnotationCol, HasOutputAnnotatorType}
import com.johnsnowlabs.partition.util.PartitionHelper.{
datasetWithBinaryFile,
datasetWithTextFile,
isStringContent
}
import com.johnsnowlabs.nlp.{Annotation, AnnotationImage, AnnotatorType, HasOutputAnnotationCol, HasOutputAnnotatorType}
import com.johnsnowlabs.partition.util.PartitionHelper.{datasetWithBinaryFile, datasetWithTextFile, isStringContent}
import com.johnsnowlabs.partition.{HasBinaryReaderProperties, Partition}
import com.johnsnowlabs.reader.util.{ImageParser, ImagePromptTemplate}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
@@ -113,6 +109,21 @@ class Reader2Image(override val uid: String)

def setCustomPromptTemplate(value: String): this.type = set(customPromptTemplate, value)

val useEncodedImageBytes: Param[Boolean] =
new Param[Boolean](
this,
"useEncodedImageBytes",
"If true, use the original encoded image bytes (e.g. JPEG, PNG). " +
"If false, decode the image into pixel data.")

def setUseEncodedImageBytes(value: Boolean): this.type = set(useEncodedImageBytes, value)

val outputPromptColumn: BooleanParam =
new BooleanParam(this, "outputPromptColumn",
"If true, outputs an additional 'prompt' column containing the text prompt as a Spark NLP Annotation.")

def setOutputPromptColumn(value: Boolean): this.type = set(outputPromptColumn, value)

setDefault(
contentType -> "",
outputFormat -> "image",
@@ -121,7 +132,10 @@ class Reader2Image(override val uid: String)
promptTemplate -> "qwen2vl-chat",
readAsImage -> true,
customPromptTemplate -> "",
ignoreExceptions -> true)
ignoreExceptions -> true,
useEncodedImageBytes -> false,
outputPromptColumn -> false
)

override def transform(dataset: Dataset[_]): DataFrame = {
validateRequiredParameters()
@@ -133,7 +147,8 @@ class Reader2Image(override val uid: String)
} else {
partitionContent(partition, $(contentPath), isStringContent($(contentType)), dataset)
}
if (!structuredDf.isEmpty) {

val resultDf = if (!structuredDf.isEmpty) {
val annotatedDf = structuredDf
.withColumn(
getOutputCol,
@@ -143,6 +158,14 @@ class Reader2Image(override val uid: String)
} else {
structuredDf
}

if ($(outputPromptColumn)) {
resultDf.withColumn(
"prompt",
wrapPromptColumnMetadata(buildPromptAnnotationUdf(element_at(col(getOutputCol), 1)("text")))
)
} else resultDf

}

override def partitionContent(
@@ -293,6 +316,12 @@ class Reader2Image(override val uid: String)
val imageFields = ImageIOUtils.bufferedImageToImageFields(decodedContent, origin)

if (imageFields.isDefined) {
val resultBytes: Array[Byte] = if ($(useEncodedImageBytes)) {
binaryContentOpt.getOrElse(Array.emptyByteArray)
} else {
imageFields.get.data
}

Some(
AnnotationImage(
IMAGE,
@@ -301,7 +330,7 @@ class Reader2Image(override val uid: String)
imageFields.get.width,
imageFields.get.nChannels,
imageFields.get.mode,
imageFields.get.data,
resultBytes,
metadata,
buildPrompt))
} else {
@@ -332,6 +361,18 @@ class Reader2Image(override val uid: String)
}
}

private val buildPromptAnnotationUdf = udf((text: String) => {
if (text == null) Seq.empty[Annotation]
else Seq(Annotation(
annotatorType = "document",
begin = 0,
end = text.length - 1,
result = text,
metadata = Map("source" -> "Reader2Image"),
embeddings = Array.emptyFloatArray
))
})

def afterAnnotate(dataset: DataFrame): DataFrame = {
if ($(explodeDocs)) {
dataset
@@ -348,15 +389,30 @@ class Reader2Image(override val uid: String)
override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

override def transformSchema(schema: StructType): StructType = {
val metadataBuilder: MetadataBuilder = new MetadataBuilder()
metadataBuilder.putString("annotatorType", outputAnnotatorType)
val outputFields = schema.fields :+
val imageMetadataBuilder: MetadataBuilder = new MetadataBuilder()
imageMetadataBuilder.putString("annotatorType", outputAnnotatorType)

val baseStruct = schema.add(
StructField(
getOutputCol,
ArrayType(AnnotationImage.dataType),
nullable = false,
metadataBuilder.build)
StructType(outputFields)
imageMetadataBuilder.build)
)

if ($(outputPromptColumn)) {
val promptMetadataBuilder = new MetadataBuilder()
promptMetadataBuilder.putString("annotatorType", AnnotatorType.DOCUMENT)

baseStruct.add(
StructField(
"prompt",
ArrayType(Annotation.dataType),
nullable = true,
promptMetadataBuilder.build
)
)
} else baseStruct
}

override val outputAnnotatorType: AnnotatorType = IMAGE
@@ -371,6 +427,16 @@ class Reader2Image(override val uid: String)
col.as(getOutputCol, columnMetadata)
}

private lazy val promptColumnMetadata: Metadata = {
val metadataBuilder = new MetadataBuilder()
metadataBuilder.putString("annotatorType", AnnotatorType.DOCUMENT)
metadataBuilder.build
}

private def wrapPromptColumnMetadata(col: Column): Column = {
col.as("prompt", promptColumnMetadata)
}

protected def validateRequiredParameters(): Unit = {
require(
$(contentPath) != null && $(contentPath).trim.nonEmpty,
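
To make the new output shape concrete, here is a hedged sketch (continuing the Python example above and assuming `result_df` comes from it) of how the added 'prompt' column and the image annotation could be inspected. Field names follow the `Annotation` and `AnnotationImage` fields referenced in the Scala code; the selection itself is illustrative.

```python
from pyspark.sql.functions import col, element_at

# Each 'prompt' entry is a document Annotation built by buildPromptAnnotationUdf:
# 'result' holds the prompt text and metadata contains {"source": "Reader2Image"}.
result_df.select(
    element_at(col("prompt"), 1)["result"].alias("prompt_text"),
    element_at(col("prompt"), 1)["metadata"].alias("prompt_metadata"),
).show(truncate=False)

# The image annotation's byte payload is either the original encoded file
# (useEncodedImageBytes=True) or decoded pixel data (False); origin, width and
# height always come from the decoded image fields.
result_df.select(
    element_at(col("image"), 1)["origin"].alias("origin"),
    element_at(col("image"), 1)["width"].alias("width"),
    element_at(col("image"), 1)["height"].alias("height"),
).show(truncate=False)
```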