Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions ravendb/documents/operations/ai/chunking_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ class ChunkingMethod(Enum):
HTML_STRIP = "HtmlStrip"


# Methods that support overlap tokens
# Methods that support overlap tokens. Mirrors the server (TextChunker.cs) and the C#/Node clients:
# only the two *paragraph* methods consume OverlapTokens; every other method ignores it.
METHODS_SUPPORTING_OVERLAP_TOKENS = {
ChunkingMethod.PLAIN_TEXT_SPLIT,
ChunkingMethod.PLAIN_TEXT_SPLIT_LINES,
ChunkingMethod.PLAIN_TEXT_SPLIT_PARAGRAPHS,
ChunkingMethod.MARK_DOWN_SPLIT_PARAGRAPHS,
}


Expand Down Expand Up @@ -44,8 +44,8 @@ def __init__(
def from_json(cls, json_dict: Dict[str, Any]) -> "ChunkingOptions":
return cls(
chunking_method=ChunkingMethod(json_dict["ChunkingMethod"]),
max_tokens_per_chunk=json_dict.get("MaxTokensPerChunk", None),
overlap_tokens=json_dict.get("OverlapTokens", None),
max_tokens_per_chunk=json_dict.get("MaxTokensPerChunk", 512),
overlap_tokens=json_dict.get("OverlapTokens", 0),
context_prefix=json_dict.get("ContextPrefix", None),
)

Expand Down Expand Up @@ -84,10 +84,8 @@ def validate(self, source: str, errors: List[str]) -> None:
errors.append(f"{source}: OverlapTokens cannot be greater than MaxTokensPerChunk.")

if self.overlap_tokens > 0 and self.chunking_method not in METHODS_SUPPORTING_OVERLAP_TOKENS:
errors.append(
f"{source}: OverlapTokens is only supported for PlainTextSplit, "
f"PlainTextSplitLines, and PlainTextSplitParagraphs chunking methods."
)
supported = ", ".join(sorted(method.value for method in METHODS_SUPPORTING_OVERLAP_TOKENS))
errors.append(f"{source}: OverlapTokens is only supported for the following chunking methods: {supported}.")

@staticmethod
def are_equal(left: Optional["ChunkingOptions"], right: Optional["ChunkingOptions"]) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,14 @@ def validate(

if not has_paths and not has_transformation:
errors.append("Either EmbeddingsPathConfigurations or EmbeddingsTransformation must be provided")
elif has_paths and has_transformation:
errors.append("Cannot specify both EmbeddingsPathConfigurations and EmbeddingsTransformation")

# Validate each path's chunking options (mirrors the server / C# client; both may be set).
if self.embeddings_path_configurations:
for path_configuration in self.embeddings_path_configurations:
if path_configuration.chunking_options is not None:
path_configuration.chunking_options.validate(path_configuration.path, errors)
else:
errors.append(f"Path '{path_configuration.path}': ChunkingOptions must be provided.")

# Validate transformation if provided
if has_transformation:
Expand Down
8 changes: 6 additions & 2 deletions ravendb/documents/session/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,7 +1571,9 @@ def _order_by_distance_wkt(

round_factor_parameter_name = None if round_factor == 0 else self.__add_query_parameter(round_factor)
self._order_by_tokens.append(
OrderByToken.create_distance_ascending_wkt(field_name, shape_wkt, round_factor_parameter_name, nulls)
OrderByToken.create_distance_ascending_wkt(
field_name, self.__add_query_parameter(shape_wkt), round_factor_parameter_name, nulls
)
)

def _order_by_distance_descending(
Expand Down Expand Up @@ -1627,7 +1629,9 @@ def _order_by_distance_descending_wkt(
round_factor_parameter_name = None if round_factor == 0 else self.__add_query_parameter(round_factor)

self._order_by_tokens.append(
OrderByToken.create_distance_descending_wkt(field_name, shape_wkt, round_factor_parameter_name, nulls)
OrderByToken.create_distance_descending_wkt(
field_name, self.__add_query_parameter(shape_wkt), round_factor_parameter_name, nulls
)
)

def _init_sync(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def create_distance_ascending_wkt(
nulls: NullsOrdering = NullsOrdering.DEFAULT,
) -> OrderByToken:
return cls(
f"spatial.distance({field_name}), "
f"spatial.distance({field_name}, "
f"spatial.wkt(${wkt_parameter_name})"
f"{'' if round_factor_parameter_name is None else ', $' + round_factor_parameter_name})",
False,
Expand Down
30 changes: 30 additions & 0 deletions ravendb/tests/embeddings_generation_tests/test_chunking_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,36 @@ def test_no_chunking_marker_bypasses_budget_validation(self):
options.validate("source", errors)
self.assertEqual([], errors)

def test_overlap_allowed_only_on_paragraph_methods(self):
# Matches the server / C# / Node clients: overlap is consumed only by the two paragraph methods.
for method in (ChunkingMethod.PLAIN_TEXT_SPLIT_PARAGRAPHS, ChunkingMethod.MARK_DOWN_SPLIT_PARAGRAPHS):
errors = []
ChunkingOptions(method, 100, 10).validate("source", errors)
self.assertEqual([], errors, method)

def test_overlap_rejected_on_non_paragraph_methods(self):
for method in (
ChunkingMethod.PLAIN_TEXT_SPLIT,
ChunkingMethod.PLAIN_TEXT_SPLIT_LINES,
ChunkingMethod.MARK_DOWN_SPLIT_LINES,
ChunkingMethod.HTML_STRIP,
):
errors = []
ChunkingOptions(method, 100, 10).validate("source", errors)
self.assertEqual(1, len(errors), method)
self.assertIn("OverlapTokens", errors[0])

def test_from_json_missing_int_keys_use_csharp_defaults(self):
# Missing MaxTokensPerChunk/OverlapTokens must default to 512/0 (matching C#'s non-nullable
# int initializers), not None. None would crash validate() with a TypeError.
options = ChunkingOptions.from_json({"ChunkingMethod": "HtmlStrip"})
self.assertEqual(512, options.max_tokens_per_chunk)
self.assertEqual(0, options.overlap_tokens)

errors = []
options.validate("source", errors) # must not raise TypeError on the defaults
self.assertEqual([], errors)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import unittest

from ravendb.documents.operations.ai.chunking_options import ChunkingOptions, ChunkingMethod
from ravendb.documents.operations.ai.embedding_path_configuration import EmbeddingPathConfiguration
from ravendb.documents.operations.ai.embeddings_generation_configuration import EmbeddingsGenerationConfiguration
from ravendb.documents.operations.ai.embeddings_transformation import EmbeddingsTransformation


def _config(**overrides) -> EmbeddingsGenerationConfiguration:
# A base config that passes every check except the paths/transformation logic under test.
base = dict(
name="emb",
identifier="emb",
collection="Docs",
connection_string_name="cs",
chunking_options_for_querying=ChunkingOptions(ChunkingMethod.HTML_STRIP, 100, 0),
)
base.update(overrides)
return EmbeddingsGenerationConfiguration(**base)


def _validate(config) -> list:
return config.validate(validate_name=False, validate_identifier=False)


class TestEmbeddingsGenerationConfigurationValidate(unittest.TestCase):
def test_paths_and_transformation_together_are_allowed(self):
# C# has no mutual-exclusivity rule; setting both must NOT be rejected client-side.
config = _config(
embeddings_path_configurations=[
EmbeddingPathConfiguration("Name", ChunkingOptions(ChunkingMethod.HTML_STRIP, 100, 0))
],
embeddings_transformation=EmbeddingsTransformation(script="embeddings.generate(this.Name)"),
)
errors = _validate(config)
self.assertNotIn("Cannot specify both EmbeddingsPathConfigurations and EmbeddingsTransformation", errors)
self.assertEqual([], errors)

def test_path_missing_chunking_options_is_rejected(self):
# Mirrors C#: each path must carry ChunkingOptions.
config = _config(embeddings_path_configurations=[EmbeddingPathConfiguration("Name", None)])
errors = _validate(config)
self.assertIn("Path 'Name': ChunkingOptions must be provided.", errors)

def test_path_invalid_chunking_options_is_rejected(self):
# Each path's ChunkingOptions is validated with the path as the error source.
config = _config(
embeddings_path_configurations=[
EmbeddingPathConfiguration("Name", ChunkingOptions(ChunkingMethod.HTML_STRIP, 0, 0))
]
)
errors = _validate(config)
self.assertTrue(
any("Name" in e and "MaxTokensPerChunk" in e for e in errors),
f"expected a per-path MaxTokensPerChunk error, got: {errors}",
)

def test_path_with_valid_chunking_options_passes(self):
config = _config(
embeddings_path_configurations=[
EmbeddingPathConfiguration("Name", ChunkingOptions(ChunkingMethod.HTML_STRIP, 100, 0))
]
)
self.assertEqual([], _validate(config))


if __name__ == "__main__":
unittest.main()
22 changes: 22 additions & 0 deletions ravendb/tests/embeddings_generation_tests/test_tasks_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
UpdateEmbeddingsGenerationOperation,
)
from ravendb.documents.operations.ai.embedded_settings import EmbeddedSettings
from ravendb.documents.operations.ai.embeddings_transformation import EmbeddingsTransformation
from ravendb.documents.operations.connection_string.put_connection_string_operation import PutConnectionStringOperation
from ravendb.documents.operations.connection_string.remove_connection_string_operation import (
RemoveConnectionStringOperation,
Expand Down Expand Up @@ -214,6 +215,27 @@ def test_embeddings_generation_configuration_without_chunking_options_should_pro
self.assertIn("PostContent", error_message)
self.assertIn("ChunkingOptions", error_message)

def test_can_add_task_with_both_paths_and_transformation(self):
"""The server imposes no mutual-exclusivity between paths and transformation, so a config
that sets BOTH must be accepted (the client must not reject it pre-send either)."""
config = EmbeddingsGenerationConfiguration(
name="ai-task-both",
connection_string_name=self.CONNECTION_STRING_NAME,
embeddings_path_configurations=[
EmbeddingPathConfiguration(path="PostContent", chunking_options=self.DEFAULT_CHUNKING_OPTIONS),
],
embeddings_transformation=EmbeddingsTransformation(
script="embeddings.generate(this.Comments)",
chunking_options=self.DEFAULT_CHUNKING_OPTIONS,
),
collection="Posts",
chunking_options_for_querying=self.DEFAULT_CHUNKING_OPTIONS,
)

add_result = self.store.maintenance.send(AddEmbeddingsGenerationOperation(config))
self.assertIsNotNone(add_result.task_id)
self._created_task_ids.append(add_result.task_id)

def test_database_record_contains_embeddings_generations(self):
config = self._create_valid_config()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest

from ravendb.documents.session.tokens.query_tokens.definitions import OrderByToken


def _rql(token: OrderByToken) -> str:
writer = []
token.write_to(writer)
return "".join(writer)


class TestOrderByDistanceWktRql(unittest.TestCase):
"""The WKT ascending distance factory must nest spatial.wkt(...) inside spatial.distance(field, ...),
matching C# (OrderByToken.cs) and its three sibling factories. The bug emitted a stray ')' right
after the field name: 'spatial.distance(loc), spatial.wkt(...)'."""

def test_distance_ascending_wkt_nests_wkt_inside_distance(self):
rql = _rql(OrderByToken.create_distance_ascending_wkt("loc", "p0", None))
self.assertIn("spatial.distance(loc, spatial.wkt($p0))", rql)
self.assertNotIn("spatial.distance(loc),", rql)

def test_distance_ascending_wkt_matches_descending_structure(self):
asc = _rql(OrderByToken.create_distance_ascending_wkt("loc", "p0", None))
desc = _rql(OrderByToken.create_distance_descending_wkt("loc", "p0", None))
# Identical apart from the trailing descending marker.
self.assertEqual(asc, desc.replace(" desc", ""))

def test_distance_ascending_wkt_with_round_factor(self):
rql = _rql(OrderByToken.create_distance_ascending_wkt("loc", "p0", "p1"))
self.assertIn("spatial.distance(loc, spatial.wkt($p0), $p1)", rql)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,28 @@ def test_order_by_distance_descending_with_dynamic_point_field(self):
self.assertGreaterEqual(len(results), 2)
self.assertEqual("geo/2", results[0].Id)

def test_order_by_distance_wkt_ascending_with_dynamic_point_field(self):
# Regression for the create_distance_ascending_wkt RQL bug: a stray ')' produced
# 'spatial.distance(<field>), spatial.wkt(...)', which the server rejects at parse time.
# WKT is "POINT(longitude latitude)"; the point below is geo/1 (Greenwich).
with self.store.open_session() as s:
results = list(
s.query(object_type=_Geo).order_by_distance_wkt(PointField("lat", "lng"), "POINT(0.0015 51.4779)")
)
self.assertGreaterEqual(len(results), 2)
self.assertEqual("geo/1", results[0].Id)

def test_order_by_distance_wkt_descending_with_dynamic_point_field(self):
# Same WKT path, descending: the farthest document (New York) comes first.
with self.store.open_session() as s:
results = list(
s.query(object_type=_Geo).order_by_distance_descending_wkt(
PointField("lat", "lng"), "POINT(0.0015 51.4779)"
)
)
self.assertGreaterEqual(len(results), 2)
self.assertEqual("geo/2", results[0].Id)


if __name__ == "__main__":
unittest.main()
Loading