
Commit d0309be

Authored by armandmasseaugit, ElenaKhaustova, ankatiyar and merelcht
feat(datasets): add experimental chromadb dataset (#1216)
* docs(datasets): chromadb markdown docs Signed-off-by: Armand <[email protected]> * docs(datasets): chromadb mkdocs Signed-off-by: Armand <[email protected]> * chore(datasets): chromadb deps update Signed-off-by: Armand <[email protected]> * feat(datasets): chromadb implementation file Signed-off-by: Armand <[email protected]> * test(datasets): chromadb test files Signed-off-by: Armand <[email protected]> * style(datasets): precommit hook manual run Signed-off-by: Armand <[email protected]> * style(datasets): end of file space Signed-off-by: Armand <[email protected]> * fix(datasets): better handling of client Signed-off-by: Armand <[email protected]> * test(datasets): update the test with the fix Signed-off-by: Armand <[email protected]> * fix(datasets): fix load method query Signed-off-by: Armand <[email protected]> * test(datasets): update the test to work Signed-off-by: Armand <[email protected]> * chore(datasets): add dataseterror catching Co-authored-by: ElenaKhaustova <[email protected]> Signed-off-by: Armand Masseau <[email protected]> * refactor(datasets): suggestions on chromadb implementation Signed-off-by: Armand <[email protected]> * test(datasets): adapt tests to new chromadb implementation Signed-off-by: Armand <[email protected]> * docs(datasets): add chromadb to nav Signed-off-by: Armand <[email protected]> * chore(datasets): update chromadb version Signed-off-by: Armand <[email protected]> * chore(datasets): place uuid import at the beginning Signed-off-by: Armand <[email protected]> * refactor(datasets): handle the exception of inexistent chroma collection Co-authored-by: ElenaKhaustova <[email protected]> Signed-off-by: Armand Masseau <[email protected]> * feat(datasets): add the possibility to query the db without loading it all Signed-off-by: Armand <[email protected]> * docs(datasets): clarify the create_if_missing return behavior Signed-off-by: Armand <[email protected]> * chore(datasets): add typing in chromadb class Signed-off-by: Armand <[email protected]> * Move added dataset to latest release notes section Signed-off-by: Merel Theisen <[email protected]> --------- Signed-off-by: Armand <[email protected]> Signed-off-by: Armand Masseau <[email protected]> Signed-off-by: Merel Theisen <[email protected]> Co-authored-by: ElenaKhaustova <[email protected]> Co-authored-by: Ankita Katiyar <[email protected]> Co-authored-by: Merel Theisen <[email protected]> Co-authored-by: Merel Theisen <[email protected]>
1 parent dc7d36a commit d0309be

File tree

9 files changed: +586 −0 lines changed


kedro-datasets/RELEASE.md

Lines changed: 9 additions & 0 deletions
@@ -8,9 +8,17 @@
 |-----------------------|----------------------------------------------------------------------------------------------------------------|-------------------------|
 | `spark.SparkDatasetV2` | A Spark dataset with Spark Connect, Databricks Connect support, and automatic Pandas-to-Spark conversion | `kedro_datasets.spark` |
 
+- Added the following new **experimental** datasets:
+
+| Type | Description | Location |
+|------------------------------------|-----------------------------------------------------------------|-----------------------------------------|
+| `chromadb.ChromaDBDataset` | A dataset for loading and saving data to ChromaDB vector database collections | `kedro_datasets_experimental.chromadb` |
+
 ## Bug fixes and other changes
 ## Community contributions
 
+- [Armand Masseau](https://github.com/armandmasseaugit)
+
 # Release 9.0.0
 
 ## Major features and improvements

@@ -42,6 +50,7 @@
 | `langchain.LangChainPromptDataset` | Kedro dataset for loading LangChain prompts | `kedro_datasets_experimental.langchain` |
 | `pypdf.PDFDataset` | Kedro dataset to read PDF files and extract text using pypdf | `kedro_datasets_experimental.pypdf` |
 | `langfuse.LangfusePromptDataset` | Kedro dataset for managing Langfuse prompts | `kedro_datasets_experimental.langfuse` |
+| `chromadb.ChromaDBDataset` | A dataset for loading and saving data to ChromaDB vector database collections | `kedro_datasets_experimental.chromadb` |
 | `opik.OpikPromptDataset` | A dataset to provide Opik integration for handling prompts | `kedro_datasets_experimental.opik` |
 | `opik.OpikTraceDataset` | Kedro dataset to provide Opik tracing clients and callbacks | `kedro_datasets_experimental.opik` |

kedro-datasets/docs/api/kedro_datasets_experimental/chromadb.ChromaDBDataset.md

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
# ChromaDBDataset

`ChromaDBDataset` loads and saves data to ChromaDB collections.

::: kedro_datasets_experimental.chromadb.ChromaDBDataset
    options:
        members: true
        show_source: true
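
For context, a hedged sketch of how catalog entries for this dataset might look once the experimental package is installed. The collection name, storage path, server host/port and the fully qualified `type` path are illustrative assumptions, not part of this commit (the class docstring further down uses the shorter `chromadb.ChromaDBDataset` form):

```yaml
# Illustrative catalog entries only; names, paths and hosts are placeholders.
document_store:
  type: kedro_datasets_experimental.chromadb.ChromaDBDataset
  collection_name: documents
  client_type: persistent
  client_settings:
    path: data/07_model_output/chroma_db

similar_documents:
  type: kedro_datasets_experimental.chromadb.ChromaDBDataset
  collection_name: documents
  client_type: http
  client_settings:
    host: localhost
    port: 8000
  load_args:
    query_texts: ["vector databases"]
    n_results: 5
    include: ["documents", "metadatas", "distances"]
```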

kedro-datasets/docs/api/kedro_datasets_experimental/index.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 
 Name | Description
 -----|------------
+[chromadb.ChromaDBDataset](chromadb.ChromaDBDataset.md) | ``ChromaDBDataset`` loads and saves data to ChromaDB vector database collections.
 [databricks.ExternalTableDataset](databricks.ExternalTableDataset.md) | ``ExternalTableDataset`` implementation to access external tables in Databricks.
 [langchain.LangChainPromptDataset](langchain.LangChainPromptDataset.md) | ``LangChainPromptDataset`` loads a `langchain` prompt template.
 [langfuse.LangfusePromptDataset](langfuse.LangfusePromptDataset.md) | ``LangfusePromptDataset`` provides a seamless integration between local prompt files (JSON/YAML) and Langfuse prompt management, supporting version control, labeling, and different synchronization policies.
kedro-datasets/kedro_datasets_experimental/chromadb/__init__.py

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
"""``AbstractDataset`` implementation for ChromaDB collections."""

from typing import Any

import lazy_loader as lazy

try:
    from .chromadb_dataset import ChromaDBDataset
except (ImportError, RuntimeError):
    # For documentation builds that might fail due to dependency issues
    # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
    ChromaDBDataset: Any

__getattr__, __dir__, __all__ = lazy.attach(
    __name__, submod_attrs={"chromadb_dataset": ["ChromaDBDataset"]}
)
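
As an aside, the `lazy.attach` call above defers importing `chromadb_dataset` (and therefore `chromadb` itself) until `ChromaDBDataset` is first accessed on the package. Roughly, and only as an illustration rather than the actual `lazy_loader` implementation, it behaves like this module-level `__getattr__`:

```python
# Rough equivalent of what lazy.attach sets up for this package (illustrative only).
import importlib


def __getattr__(name: str):
    if name == "ChromaDBDataset":
        # Import the submodule on first attribute access and hand back the class.
        submodule = importlib.import_module(".chromadb_dataset", __name__)
        return getattr(submodule, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```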

kedro-datasets/kedro_datasets_experimental/chromadb/chromadb_dataset.py

Lines changed: 275 additions & 0 deletions

@@ -0,0 +1,275 @@
"""``ChromaDBDataset`` loads and saves data from/to ChromaDB collections."""

from __future__ import annotations

from typing import Any

import chromadb
from chromadb.api.models.Collection import Collection
from chromadb.errors import NotFoundError
from kedro.io.core import AbstractDataset, DatasetError


class ChromaDBDataset(AbstractDataset[dict[str, Any], dict[str, Any]]):
    """``ChromaDBDataset`` loads and saves data from/to ChromaDB collections.

    ChromaDB is a vector database for building AI applications. This dataset allows you to
    interact with ChromaDB collections for storing and retrieving documents with embeddings.

    Examples:
        Using the [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/):

        ```yaml
        my_collection:
          type: chromadb.ChromaDBDataset
          collection_name: "documents"
          client_type: "persistent"
          client_settings:
            path: "./chroma_db"
        ```

        Using the [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/):

        >>> from kedro_datasets_experimental.chromadb import ChromaDBDataset
        >>>
        >>> # Save data to ChromaDB
        >>> data = {
        ...     "documents": ["This is a document", "This is another document"],
        ...     "metadatas": [{"type": "text"}, {"type": "text"}],
        ...     "ids": ["doc1", "doc2"]
        ... }
        >>> dataset = ChromaDBDataset(collection_name="test_collection")
        >>> dataset.save(data)
        >>>
        >>> # Load data from ChromaDB
        >>> loaded_data = dataset.load()
        >>> print(loaded_data["documents"])  # ['This is a document', 'This is another document']
        >>>
        >>> # Query for similar vectors (efficient for large datasets)
        >>> query_dataset = ChromaDBDataset(
        ...     collection_name="documents",
        ...     load_args={
        ...         "query_texts": ["machine learning"],
        ...         "n_results": 5,
        ...         "include": ["documents", "metadatas", "distances"]
        ...     }
        ... )
        >>> results = query_dataset.load()  # Returns top-5 similar documents

    """
    # Attribute annotations for IDEs / type-checkers (instance values set in __init__)
    _client: chromadb.Client | None
    _collection: Collection | None

    def __init__(  # noqa: PLR0913
        self,
        *,
        collection_name: str,
        client_type: str = "ephemeral",
        client_settings: dict[str, Any] | None = None,
        load_args: dict[str, Any] | None = None,
        save_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Creates a new instance of ``ChromaDBDataset``.

        Args:
            collection_name: The name of the ChromaDB collection.
            client_type: Type of ChromaDB client. Options: "ephemeral", "persistent", "http".
                Defaults to "ephemeral".
            client_settings: Settings for the ChromaDB client. For "persistent", use {"path": "/path/to/db"}.
                For "http", use {"host": "localhost", "port": 8000}.
            load_args: Additional arguments for loading data from ChromaDB collection.
                Can include "where", "where_document", "include", "n_results", etc.
                For vector similarity queries, use:
                - "query_embeddings": List of embeddings to query for similarity
                - "query_texts": List of texts to query for similarity
                - "n_results": Number of results to return (default: 10)
                - "where": Metadata filter conditions
                - "where_document": Document content filter conditions
            save_args: Additional arguments for saving data to ChromaDB collection.
                Can include "embeddings" if you want to provide custom embeddings.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.
        """
        self._collection_name = collection_name
        self._client_type = client_type
        self._client_settings = client_settings or {}
        self._load_args = load_args or {}
        self._save_args = save_args or {}
        self.metadata = metadata
        # Initialize instance attributes (actual annotations are at class-level)
        self._client = None
        self._collection = None

    def _create_client(self) -> chromadb.Client:
        """Create ChromaDB client based on configuration."""
        if self._client_type == "ephemeral":
            return chromadb.EphemeralClient()
        elif self._client_type == "persistent":
            path = self._client_settings.get("path", "./chroma_db")
            return chromadb.PersistentClient(path=path, **{
                k: v for k, v in self._client_settings.items() if k != "path"
            })
        elif self._client_type == "http":
            host = self._client_settings.get("host", "localhost")
            port = self._client_settings.get("port", 8000)
            return chromadb.HttpClient(host=host, port=port, **{
                k: v for k, v in self._client_settings.items() if k not in ["host", "port"]
            })
        else:
            raise DatasetError(
                f"Unsupported client_type: {self._client_type}. "
                f"Must be one of: 'ephemeral', 'persistent', 'http'"
            )

    def _get_client(self) -> chromadb.Client:
        """Get or create the ChromaDB client."""
        if self._client is None:
            self._client = self._create_client()
        return self._client

    def _get_collection(self, create_if_missing: bool = True) -> Collection | None:
        """Get or create the ChromaDB collection.

        Args:
            create_if_missing: If True, creates the collection if it doesn't exist.
                If False, returns None when collection is not found.

        Returns:
            Collection object if found/created, None if not found and create_if_missing=False.
        """
        if self._collection is None:
            client = self._get_client()
            try:
                self._collection = client.get_collection(name=self._collection_name)
            except NotFoundError:
                if create_if_missing:
                    # Collection doesn't exist, create it
                    self._collection = client.create_collection(name=self._collection_name)
                else:
                    # Don't create collection, return None instead of raising
                    return None
        return self._collection

    def _describe(self) -> dict[str, Any]:
        """Returns a dictionary describing the dataset configuration."""
        return {
            "collection_name": self._collection_name,
            "client_type": self._client_type,
            "client_settings": self._client_settings,
            "load_args": self._load_args,
            "save_args": self._save_args,
        }

    def load(self) -> dict[str, Any]:
        """Loads data from the ChromaDB collection.

        Returns:
            dict[str, Any]: A dictionary containing the collection data with keys:
                - "documents": List of document texts
                - "metadatas": List of metadata dictionaries
                - "ids": List of document IDs
                - "embeddings": List of embeddings (if included)
        """
        collection = self._get_collection(create_if_missing=False)

        # If collection doesn't exist, return empty result rather than creating it
        if collection is None:
            return {"documents": [], "metadatas": [], "ids": [], "embeddings": []}

        # Prepare load arguments
        load_args = {
            "include": ["documents", "metadatas", "embeddings"],
            **self._load_args
        }

        try:
            # Use query() for vector similarity search or filtering
            if any(key in load_args for key in ["query_embeddings", "query_texts", "where", "where_document"]):
                # Vector similarity query - more efficient for large datasets
                if "n_results" not in load_args:
                    load_args["n_results"] = 10  # Default limit for queries
                result = collection.query(**load_args)
            else:
                # Use get() for retrieving all documents (not recommended for large collections)
                if "n_results" in load_args:
                    # Convert n_results to limit for get() method
                    load_args["limit"] = load_args.pop("n_results")
                result = collection.get(**load_args)

            return {
                "documents": result.get("documents", []),
                "metadatas": result.get("metadatas", []),
                "ids": result.get("ids", []),
                "embeddings": result.get("embeddings", [])
            }
        except Exception as e:
            raise DatasetError(
                f"Failed to load data from ChromaDB collection '{self._collection_name}': {e}"
            ) from e

    def save(self, data: dict[str, Any]) -> None:
        """Saves data to the ChromaDB collection.

        Args:
            data: A dictionary containing the data to save. Expected keys:
                - "documents": List of document texts (required)
                - "ids": List of document IDs (required)
                - "metadatas": List of metadata dictionaries (optional)
                - "embeddings": List of embeddings (optional, will be auto-generated if not provided)
        """
        if not isinstance(data, dict):
            raise DatasetError(f"Data must be a dictionary, got {type(data)}")

        if "documents" not in data or "ids" not in data:
            raise DatasetError("Data must contain 'documents' and 'ids' keys")

        collection = self._get_collection(create_if_missing=True)

        if collection is None:
            raise DatasetError(f"Failed to access or create ChromaDB collection '{self._collection_name}'")

        try:
            # Prepare the data for ChromaDB
            add_kwargs = {
                "documents": data["documents"],
                "ids": data["ids"],
                **self._save_args
            }

            # Add optional fields if present
            if "metadatas" in data:
                add_kwargs["metadatas"] = data["metadatas"]
            if "embeddings" in data:
                add_kwargs["embeddings"] = data["embeddings"]

            # Add documents to collection
            collection.add(**add_kwargs)

        except Exception as e:
            raise DatasetError(
                f"Failed to save data to ChromaDB collection '{self._collection_name}': {e}"
            ) from e

    def exists(self) -> bool:
        """Checks if the collection exists and contains data."""
        try:
            # Use the same collection instance if we already have it,
            # otherwise try to get the collection from the client without creating it
            collection = self._collection or self._get_collection(create_if_missing=False)
            # In case both return None, the collection does not exist
            if collection is None:
                return False
            return collection.count() > 0
        except Exception:
            return False
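
Putting the implementation above together, here is a hedged round-trip sketch using the Python API. It assumes `chromadb` is installed; the collection name, storage path and document texts are placeholders, and ChromaDB's default embedding function is used because no embeddings are supplied:

```python
from kedro_datasets_experimental.chromadb import ChromaDBDataset

# Persistent client writing to a local directory (placeholder path).
dataset = ChromaDBDataset(
    collection_name="articles",
    client_type="persistent",
    client_settings={"path": "./chroma_db"},
)

# save() requires "documents" and "ids"; "metadatas" and "embeddings" are optional.
dataset.save(
    {
        "documents": ["Kedro is a pipeline framework", "ChromaDB stores embeddings"],
        "ids": ["a1", "a2"],
        "metadatas": [{"source": "docs"}, {"source": "docs"}],
    }
)

assert dataset.exists()

# Without query-style load_args, load() falls back to collection.get().
everything = dataset.load()
print(everything["ids"])  # e.g. ['a1', 'a2']

# With query_texts in load_args, load() runs a similarity query via collection.query().
query_dataset = ChromaDBDataset(
    collection_name="articles",
    client_type="persistent",
    client_settings={"path": "./chroma_db"},
    load_args={"query_texts": ["pipeline framework"], "n_results": 1},
)
top_hit = query_dataset.load()
```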

kedro-datasets/kedro_datasets_experimental/tests/chromadb/__init__.py

Whitespace-only changes.
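
The actual test files added by this commit are not reproduced on this page. Purely as an illustration of what tests against the default ephemeral client could look like (not the shipped test suite), a minimal pytest sketch:

```python
# Illustrative only; not the tests shipped in this commit.
import pytest
from kedro.io.core import DatasetError

from kedro_datasets_experimental.chromadb import ChromaDBDataset


def test_save_and_load_round_trip():
    # Ephemeral client is the default, so nothing is written to disk.
    dataset = ChromaDBDataset(collection_name="test_collection")
    dataset.save({"documents": ["hello world"], "ids": ["doc1"]})

    loaded = dataset.load()
    assert loaded["ids"] == ["doc1"]
    assert loaded["documents"] == ["hello world"]


def test_save_requires_documents_and_ids():
    dataset = ChromaDBDataset(collection_name="test_collection")
    with pytest.raises(DatasetError, match="must contain 'documents' and 'ids'"):
        dataset.save({"documents": ["missing ids"]})
```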
