diff --git a/python/sparknlp/reader/reader2image.py b/python/sparknlp/reader/reader2image.py
index 61fc64387c39aa..fb72e750e06061 100644
--- a/python/sparknlp/reader/reader2image.py
+++ b/python/sparknlp/reader/reader2image.py
@@ -86,6 +86,22 @@ class Reader2Image(
         typeConverter=TypeConverters.toString
     )
 
+    useEncodedImageBytes = Param(
+        Params._dummy(),
+        "useEncodedImageBytes",
+        "If true, use the original encoded image bytes (e.g., JPEG, PNG). "
+        "If false, decode the image into raw pixel data.",
+        typeConverter=TypeConverters.toBoolean
+    )
+
+    outputPromptColumn = Param(
+        Params._dummy(),
+        "outputPromptColumn",
+        "If true, outputs an additional 'prompt' column containing "
+        "the text prompt as a Spark NLP Annotation.",
+        typeConverter=TypeConverters.toBoolean
+    )
+
     @keyword_only
     def __init__(self):
         super(Reader2Image, self).__init__(classname="com.johnsnowlabs.reader.Reader2Image")
@@ -97,7 +113,9 @@ def __init__(self):
             promptTemplate="qwen2vl-chat",
             readAsImage=True,
             customPromptTemplate="",
-            ignoreExceptions=True
+            ignoreExceptions=True,
+            useEncodedImageBytes=False,
+            outputPromptColumn=False
         )
 
     @keyword_only
@@ -133,4 +151,26 @@ def setCustomPromptTemplate(self, value: str):
         value : str
             Custom prompt template string.
         """
-        return self._set(customPromptTemplate=value)
\ No newline at end of file
+        return self._set(customPromptTemplate=value)
+
+    def setUseEncodedImageBytes(self, value: bool):
+        """Sets whether to use encoded image bytes or decoded pixels.
+
+        Parameters
+        ----------
+        value : bool
+            If True, keeps the image bytes in their encoded (compressed) form.
+            If False, decodes the image into a pixel matrix representation.
+        """
+        return self._set(useEncodedImageBytes=value)
+
+    def setOutputPromptColumn(self, value: bool):
+        """Enables or disables creation of a prompt column.
+
+        Parameters
+        ----------
+        value : bool
+            If True, adds an additional 'prompt' column to the output DataFrame
+            containing the text prompt as a Spark NLP Annotation.
+        """
+        return self._set(outputPromptColumn=value)
\ No newline at end of file
diff --git a/src/main/scala/com/johnsnowlabs/partition/util/PartitionHelper.scala b/src/main/scala/com/johnsnowlabs/partition/util/PartitionHelper.scala
index efb8badecc73e3..6f3604c8a1cef3 100644
--- a/src/main/scala/com/johnsnowlabs/partition/util/PartitionHelper.scala
+++ b/src/main/scala/com/johnsnowlabs/partition/util/PartitionHelper.scala
@@ -15,7 +15,7 @@
  */
 package com.johnsnowlabs.partition.util
 
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.{DataFrame, SparkSession}
 
 import java.nio.charset.Charset
 import java.nio.file.Files
diff --git a/src/main/scala/com/johnsnowlabs/reader/Reader2Image.scala b/src/main/scala/com/johnsnowlabs/reader/Reader2Image.scala
index 1cf431abbdbac9..9562c15b42ed26 100644
--- a/src/main/scala/com/johnsnowlabs/reader/Reader2Image.scala
+++ b/src/main/scala/com/johnsnowlabs/reader/Reader2Image.scala
@@ -18,16 +18,12 @@ package com.johnsnowlabs.reader
 import com.johnsnowlabs.nlp.AnnotatorType.IMAGE
 import com.johnsnowlabs.nlp.annotators.cv.util.io.ImageIOUtils
 import com.johnsnowlabs.nlp.util.io.ResourceHelper
-import com.johnsnowlabs.nlp.{AnnotationImage, HasOutputAnnotationCol, HasOutputAnnotatorType}
-import com.johnsnowlabs.partition.util.PartitionHelper.{
-  datasetWithBinaryFile,
-  datasetWithTextFile,
-  isStringContent
-}
+import com.johnsnowlabs.nlp.{Annotation, AnnotationImage, AnnotatorType, HasOutputAnnotationCol, HasOutputAnnotatorType}
+import com.johnsnowlabs.partition.util.PartitionHelper.{datasetWithBinaryFile, datasetWithTextFile, isStringContent}
 import com.johnsnowlabs.partition.{HasBinaryReaderProperties, Partition}
 import com.johnsnowlabs.reader.util.{ImageParser, ImagePromptTemplate}
 import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap}
 import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
@@ -113,6 +109,21 @@ class Reader2Image(override val uid: String)
 
   def setCustomPromptTemplate(value: String): this.type = set(promptTemplate, value)
 
+  val useEncodedImageBytes: BooleanParam =
+    new BooleanParam(
+      this,
+      "useEncodedImageBytes",
+      "If true, use the original encoded image bytes (e.g. JPEG, PNG). " +
" + + "If false, decode the image into pixel data.") + + def setUseEncodedImageBytes(value: Boolean): this.type = set(useEncodedImageBytes, value) + + val outputPromptColumn: BooleanParam = + new BooleanParam(this, "outputPromptColumn", + "If true, outputs an additional 'prompt' column containing the text prompt as a Spark NLP Annotation.") + + def setOutputPromptColumn(value: Boolean): this.type = set(outputPromptColumn, value) + setDefault( contentType -> "", outputFormat -> "image", @@ -121,7 +132,10 @@ class Reader2Image(override val uid: String) promptTemplate -> "qwen2vl-chat", readAsImage -> true, customPromptTemplate -> "", - ignoreExceptions -> true) + ignoreExceptions -> true, + useEncodedImageBytes -> false, + outputPromptColumn -> false + ) override def transform(dataset: Dataset[_]): DataFrame = { validateRequiredParameters() @@ -133,7 +147,8 @@ class Reader2Image(override val uid: String) } else { partitionContent(partition, $(contentPath), isStringContent($(contentType)), dataset) } - if (!structuredDf.isEmpty) { + + val resultDf = if (!structuredDf.isEmpty) { val annotatedDf = structuredDf .withColumn( getOutputCol, @@ -143,6 +158,14 @@ class Reader2Image(override val uid: String) } else { structuredDf } + + if ($(outputPromptColumn)) { + resultDf.withColumn( + "prompt", + wrapPromptColumnMetadata(buildPromptAnnotationUdf(element_at(col(getOutputCol), 1)("text"))) + ) + } else resultDf + } override def partitionContent( @@ -293,6 +316,12 @@ class Reader2Image(override val uid: String) val imageFields = ImageIOUtils.bufferedImageToImageFields(decodedContent, origin) if (imageFields.isDefined) { + val resultBytes: Array[Byte] = if ($(useEncodedImageBytes)) { + binaryContentOpt.getOrElse(Array.emptyByteArray) + } else { + imageFields.get.data + } + Some( AnnotationImage( IMAGE, @@ -301,7 +330,7 @@ class Reader2Image(override val uid: String) imageFields.get.width, imageFields.get.nChannels, imageFields.get.mode, - imageFields.get.data, + resultBytes, metadata, buildPrompt)) } else { @@ -332,6 +361,18 @@ class Reader2Image(override val uid: String) } } + private val buildPromptAnnotationUdf = udf((text: String) => { + if (text == null) Seq.empty[Annotation] + else Seq(Annotation( + annotatorType = "document", + begin = 0, + end = text.length - 1, + result = text, + metadata = Map("source" -> "Reader2Image"), + embeddings = Array.emptyFloatArray + )) + }) + def afterAnnotate(dataset: DataFrame): DataFrame = { if ($(explodeDocs)) { dataset @@ -348,15 +389,30 @@ class Reader2Image(override val uid: String) override def copy(extra: ParamMap): Transformer = defaultCopy(extra) override def transformSchema(schema: StructType): StructType = { - val metadataBuilder: MetadataBuilder = new MetadataBuilder() - metadataBuilder.putString("annotatorType", outputAnnotatorType) - val outputFields = schema.fields :+ + val imageMetadataBuilder: MetadataBuilder = new MetadataBuilder() + imageMetadataBuilder.putString("annotatorType", outputAnnotatorType) + + val baseStruct = schema.add( StructField( getOutputCol, ArrayType(AnnotationImage.dataType), nullable = false, - metadataBuilder.build) - StructType(outputFields) + imageMetadataBuilder.build) + ) + + if ($(outputPromptColumn)) { + val promptMetadataBuilder = new MetadataBuilder() + promptMetadataBuilder.putString("annotatorType", AnnotatorType.DOCUMENT) + + baseStruct.add( + StructField( + "prompt", + ArrayType(Annotation.dataType), + nullable = true, + promptMetadataBuilder.build + ) + ) + } else baseStruct } override val 
@@ -371,6 +427,16 @@ class Reader2Image(override val uid: String)
     col.as(getOutputCol, columnMetadata)
   }
 
+  private lazy val promptColumnMetadata: Metadata = {
+    val metadataBuilder = new MetadataBuilder()
+    metadataBuilder.putString("annotatorType", AnnotatorType.DOCUMENT)
+    metadataBuilder.build
+  }
+
+  private def wrapPromptColumnMetadata(col: Column): Column = {
+    col.as("prompt", promptColumnMetadata)
+  }
+
   protected def validateRequiredParameters(): Unit = {
     require(
       $(contentPath) != null && $(contentPath).trim.nonEmpty,
diff --git a/src/test/scala/com/johnsnowlabs/reader/Reader2ImageTest.scala b/src/test/scala/com/johnsnowlabs/reader/Reader2ImageTest.scala
index efee327acffdc1..e333ea3a61c079 100644
--- a/src/test/scala/com/johnsnowlabs/reader/Reader2ImageTest.scala
+++ b/src/test/scala/com/johnsnowlabs/reader/Reader2ImageTest.scala
@@ -17,6 +17,7 @@ package com.johnsnowlabs.reader
 
 import com.johnsnowlabs.nlp.annotators.SparkSessionTest
 import com.johnsnowlabs.nlp.annotators.cv.{Qwen2VLTransformer, SmolVLMTransformer}
+import com.johnsnowlabs.nlp.annotators.seq2seq.AutoGGUFVisionModel
 import com.johnsnowlabs.nlp.{AnnotatorType, AssertAnnotations}
 import com.johnsnowlabs.tags.{FastTest, SlowTest}
 import org.apache.spark.ml.Pipeline
@@ -408,30 +409,19 @@ class Reader2ImageTest extends AnyFlatSpec with SparkSessionTest {
   }
 
-  it should "set custom prompt" taggedAs SlowTest in {
-    val customPrompt = "<|im_start|>{prompt}<|im_end|><|im_start|>assistant"
+  it should "answer questions about images in email attachments" taggedAs SlowTest in {
     val reader2Doc = new Reader2Image()
       .setContentPath(emailDirectory)
       .setOutputCol("image")
       .setUserMessage("Describe the image with 3 to 4 words.")
-      .setPromptTemplate("custom")
-      .setCustomPromptTemplate(customPrompt)
-
-    val pipeline = new Pipeline().setStages(Array(reader2Doc))
-
-    val pipelineModel = pipeline.fit(emptyDataSet)
-    val imagesDf = pipelineModel.transform(emptyDataSet)
-    imagesDf.show()
-    imagesDf.select("image.text").show(truncate = false)
-    imagesDf.printSchema()
 
     val visualQAClassifier = Qwen2VLTransformer
       .pretrained()
       .setInputCols("image")
       .setOutputCol("answer")
 
-    val vlmPipeline = new Pipeline().setStages(Array(visualQAClassifier))
-    val resultDf = vlmPipeline.fit(imagesDf).transform(imagesDf)
+    val vlmPipeline = new Pipeline().setStages(Array(reader2Doc, visualQAClassifier))
+    val resultDf = vlmPipeline.fit(emptyDataSet).transform(emptyDataSet)
 
     resultDf.select("image.origin", "answer.result").show(truncate = false)
 
     assert(!resultDf.isEmpty)
@@ -450,10 +440,6 @@ class Reader2ImageTest extends AnyFlatSpec with SparkSessionTest {
     val pipelineModel = pipeline.fit(emptyDataSet)
     val imagesDf = pipelineModel.transform(emptyDataSet)
 
-    imagesDf.show()
-    imagesDf.select("image.text").show(truncate = false)
-    imagesDf.printSchema()
-
     val visualQAClassifier = Qwen2VLTransformer
       .pretrained()
       .setInputCols("image")
@@ -546,6 +532,98 @@ class Reader2ImageTest extends AnyFlatSpec with SparkSessionTest {
     assert(!resultDf.isEmpty)
   }
 
+  it should "integrate images output with VLM models" taggedAs SlowTest in {
+    val sourceFile = "SwitzerlandAlps.jpg"
+    val reader2Image = new Reader2Image()
+      .setContentPath(s"$imageDirectory/$sourceFile")
+      .setContentType("raw/image")
+      .setOutputCol("image")
+
+    val pipeline = new Pipeline().setStages(Array(reader2Image))
+    val pipelineModel = pipeline.fit(emptyDataSet)
+    val imagesDf = pipelineModel.transform(emptyDataSet)
+    imagesDf.show()
+
+    val visualQAClassifier = Qwen2VLTransformer
+      .pretrained()
+      .setInputCols("image")
+      .setOutputCol("answer")
+
+    val vlmPipeline = new Pipeline().setStages(Array(visualQAClassifier))
+    val resultDf = vlmPipeline.fit(imagesDf).transform(imagesDf)
+
+    resultDf.select("image.origin", "answer.result").show(truncate = false)
+
+    assert(!resultDf.isEmpty)
+  }
+
+  it should "work with AutoGGUFVisionModel using a prompt output column and PDF files" taggedAs SlowTest in {
+    val sourceFile = "pdf-with-2images.pdf"
+    val reader2Image = new Reader2Image()
+      .setContentPath(s"$pdfDirectory/$sourceFile")
+      .setContentType("application/pdf")
+      .setOutputCol("image")
+      .setUseEncodedImageBytes(true)
+      .setUserMessage("Describe in a short and easy to understand sentence what you see in the image.")
+      .setOutputPromptColumn(true)
+
+    val autoGgufModel: AutoGGUFVisionModel = AutoGGUFVisionModel
+      .pretrained()
+      .setInputCols("prompt", "image")
+      .setOutputCol("completions")
+      .setBatchSize(2)
+      .setNGpuLayers(99)
+      .setNCtx(4096)
+      .setMinKeep(0)
+      .setMinP(0.05f)
+      .setNPredict(40)
+      .setPenalizeNl(true)
+      .setRepeatPenalty(1.18f)
+      .setTemperature(0.05f)
+      .setTopK(40)
+      .setTopP(0.95f)
+
+    val pipeline = new Pipeline().setStages(Array(reader2Image, autoGgufModel))
+    val pipelineModel = pipeline.fit(emptyDataSet)
+    val completionDf = pipelineModel.transform(emptyDataSet)
+
+    completionDf.select("fileName", "completions.result").show(truncate = false)
+  }
+
+  it should "work with AutoGGUFVisionModel using a prompt output column and image files" taggedAs SlowTest in {
+    val imageFile = "SwitzerlandAlps.jpg"
+    val imagePath = s"$imageDirectory/$imageFile"
+    val reader2Image = new Reader2Image()
+      .setContentPath(imagePath)
+      .setContentType("image/raw")
+      .setOutputCol("image")
+      .setUseEncodedImageBytes(true)
+      .setUserMessage("Describe in a short and easy to understand sentence what you see in the image.")
+      .setOutputPromptColumn(true)
+
+    val autoGgufModel: AutoGGUFVisionModel = AutoGGUFVisionModel
+      .pretrained()
+      .setInputCols("prompt", "image")
+      .setOutputCol("completions")
+      .setBatchSize(2)
+      .setNGpuLayers(99)
+      .setNCtx(4096)
+      .setMinKeep(0)
+      .setMinP(0.05f)
+      .setNPredict(40)
+      .setPenalizeNl(true)
+      .setRepeatPenalty(1.18f)
+      .setTemperature(0.05f)
+      .setTopK(40)
+      .setTopP(0.95f)
+
+    val pipeline = new Pipeline().setStages(Array(reader2Image, autoGgufModel))
+    val pipelineModel = pipeline.fit(emptyDataSet)
+    val completionDf = pipelineModel.transform(emptyDataSet)
+
+    completionDf.select("fileName", "completions.result").show(truncate = false)
+  }
+
   def getSupportedFiles(dirPath: String): Seq[String] = {
     val supportedExtensions = Seq(".html", ".htm", ".md", "doc", "docx")
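
For reference, a minimal PySpark sketch of how the two new options compose end to end. This is not part of the diff: the image path and session setup are illustrative, the Reader2Image import is assumed to mirror the existing reader module layout, and it relies on the Scala tests' observation that AutoGGUFVisionModel is fed encoded image bytes plus the new 'prompt' column.

from pyspark.ml import Pipeline

import sparknlp
from sparknlp.annotator import AutoGGUFVisionModel
from sparknlp.reader import Reader2Image  # assumed export, following reader2image.py above

spark = sparknlp.start()
# Reader2Image reads from contentPath, so the input DataFrame is just a placeholder.
empty_df = spark.createDataFrame([[""]]).toDF("text")

reader2image = (
    Reader2Image()
    .setContentPath("file:///tmp/images/SwitzerlandAlps.jpg")  # illustrative path
    .setContentType("raw/image")
    .setOutputCol("image")
    # Keep the original JPEG/PNG bytes instead of decoded pixels,
    # as the AutoGGUFVisionModel tests above do.
    .setUseEncodedImageBytes(True)
    .setUserMessage("Describe what you see in the image.")
    # Emit the extra 'prompt' DOCUMENT column consumed by the vision model.
    .setOutputPromptColumn(True)
)

vision_model = (
    AutoGGUFVisionModel.pretrained()
    .setInputCols(["prompt", "image"])
    .setOutputCol("completions")
)

pipeline = Pipeline(stages=[reader2image, vision_model])
result = pipeline.fit(empty_df).transform(empty_df)
result.select("completions.result").show(truncate=False)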