From 91e94cf9335f76f1e7af9cbbe0df825fa2fdd6fd Mon Sep 17 00:00:00 2001 From: Harsh Kumar Date: Tue, 3 Mar 2026 17:35:51 +0530 Subject: [PATCH 1/2] Add Vertex AI Ranking API as reranker option Add support for Google Vertex AI Ranking API alongside the existing HuggingFace CrossEncoder reranker. The reranker type is controlled by the RERANKER_TYPE env var (default: HF). When RERANKER_TYPE=VERTEX_AI, uses langchain-google-community's VertexAIRank as a drop-in replacement for CrossEncoderReranker in ContextualCompressionRetriever. Signed-off-by: Harsh Kumar Patwa Signed-off-by: Harsh Kumar --- backend/.env.example | 5 + backend/pyproject.toml | 1 + backend/src/chains/hybrid_retriever_chain.py | 34 +++++- backend/tests/test_hybrid_retriever_chain.py | 115 ++++++++++++++++++- 4 files changed, 150 insertions(+), 5 deletions(-) diff --git a/backend/.env.example b/backend/.env.example index 5ce3d60e..aeb9606d 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -27,6 +27,11 @@ GOOGLE_EMBEDDINGS=gemini-embedding-001 HF_EMBEDDINGS=thenlper/gte-large HF_RERANKER=BAAI/bge-reranker-base +# Reranker type: 'HF' for HuggingFace CrossEncoder, 'VERTEX_AI' for Google Vertex AI Ranking API +RERANKER_TYPE=HF +VERTEX_AI_PROJECT_ID= +VERTEX_AI_LOCATION=global + # FAISS database path FAISS_DB_PATH=./.faissdb/faiss_index diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 9f59fa37..49a23006 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "huggingface-hub[cli]==0.34.4", "langchain==0.3.27", "langchain-community==0.3.27", + "langchain-google-community[vertexaisearch]>=2.0.0", "langchain-google-genai==2.1.9", "langchain-google-vertexai==2.0.28", "langchain-huggingface==0.3.1", diff --git a/backend/src/chains/hybrid_retriever_chain.py b/backend/src/chains/hybrid_retriever_chain.py index 1b68c14f..7dd2ae7f 100644 --- a/backend/src/chains/hybrid_retriever_chain.py +++ b/backend/src/chains/hybrid_retriever_chain.py @@ -1,4 +1,5 @@ import os +import logging from typing import Optional, Union, Any from langchain.retrievers import EnsembleRetriever @@ -121,10 +122,35 @@ def create_hybrid_retriever(self) -> None: ) if self.contextual_rerank: - compressor = CrossEncoderReranker( - model=HuggingFaceCrossEncoder(model_name=self.reranking_model_name), - top_n=self.search_k, - ) + reranker_type = os.getenv("RERANKER_TYPE", "HF").upper() + + if reranker_type == "VERTEX_AI": + from langchain_google_community.vertex_rank import VertexAIRank + + project_id = os.getenv("VERTEX_AI_PROJECT_ID", "") + location_id = os.getenv("VERTEX_AI_LOCATION", "global") + + if not project_id: + raise ValueError( + "VERTEX_AI_PROJECT_ID must be set when using RERANKER_TYPE=VERTEX_AI" + ) + + compressor = VertexAIRank( + project_id=project_id, + location_id=location_id, + ranking_config="default_ranking_config", + top_n=self.search_k, + ) + logging.info("Using Vertex AI reranker") + else: + compressor = CrossEncoderReranker( + model=HuggingFaceCrossEncoder( + model_name=self.reranking_model_name + ), + top_n=self.search_k, + ) + logging.info("Using HuggingFace CrossEncoder reranker") + self.retriever = ContextualCompressionRetriever( base_compressor=compressor, base_retriever=ensemble_retriever ) diff --git a/backend/tests/test_hybrid_retriever_chain.py b/backend/tests/test_hybrid_retriever_chain.py index f2061ea8..99aab4f3 100644 --- a/backend/tests/test_hybrid_retriever_chain.py +++ b/backend/tests/test_hybrid_retriever_chain.py @@ -141,6 +141,7 @@ def test_create_hybrid_retriever_with_provided_vector_db( assert chain.retriever == mock_ensemble_instance + @patch.dict("os.environ", {"RERANKER_TYPE": "HF"}) @patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain") @patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain") @patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain") @@ -158,7 +159,7 @@ def test_create_hybrid_retriever_with_contextual_rerank( mock_mmr_chain, mock_sim_chain, ): - """Test creating hybrid retriever with contextual reranking enabled.""" + """Test creating hybrid retriever with HF contextual reranking enabled.""" mock_vector_db = Mock() mock_vector_db.processed_docs = [Mock(), Mock()] @@ -209,6 +210,118 @@ def test_create_hybrid_retriever_with_contextual_rerank( assert chain.retriever == mock_compression_instance + @patch.dict( + "os.environ", + { + "RERANKER_TYPE": "VERTEX_AI", + "VERTEX_AI_PROJECT_ID": "test-project", + "VERTEX_AI_LOCATION": "global", + }, + ) + @patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain") + @patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain") + @patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain") + @patch("src.chains.hybrid_retriever_chain.EnsembleRetriever") + @patch("src.chains.hybrid_retriever_chain.ContextualCompressionRetriever") + def test_create_hybrid_retriever_with_vertex_ai_rerank( + self, + mock_compression, + mock_ensemble, + mock_bm25_chain, + mock_mmr_chain, + mock_sim_chain, + ): + """Test creating hybrid retriever with Vertex AI reranking enabled.""" + mock_vector_db = Mock() + mock_vector_db.processed_docs = [Mock(), Mock()] + + chain = HybridRetrieverChain( + vector_db=mock_vector_db, + contextual_rerank=True, + search_k=5, + ) + + # Setup mocks + mock_sim_instance = Mock() + mock_sim_instance.retriever = Mock() + mock_sim_chain.return_value = mock_sim_instance + + mock_mmr_instance = Mock() + mock_mmr_instance.retriever = Mock() + mock_mmr_chain.return_value = mock_mmr_instance + + mock_bm25_instance = Mock() + mock_bm25_instance.retriever = Mock() + mock_bm25_chain.return_value = mock_bm25_instance + + mock_ensemble_instance = Mock() + mock_ensemble.return_value = mock_ensemble_instance + + mock_compression_instance = Mock() + mock_compression.return_value = mock_compression_instance + + with patch( + "langchain_google_community.vertex_rank.VertexAIRank" + ) as mock_vertex_rank: + mock_vertex_rank_instance = Mock() + mock_vertex_rank.return_value = mock_vertex_rank_instance + + chain.create_hybrid_retriever() + + mock_vertex_rank.assert_called_once_with( + project_id="test-project", + location_id="global", + ranking_config="default_ranking_config", + top_n=5, + ) + mock_compression.assert_called_once_with( + base_compressor=mock_vertex_rank_instance, + base_retriever=mock_ensemble_instance, + ) + + assert chain.retriever == mock_compression_instance + + @patch.dict( + "os.environ", + {"RERANKER_TYPE": "VERTEX_AI", "VERTEX_AI_PROJECT_ID": ""}, + ) + @patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain") + @patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain") + @patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain") + @patch("src.chains.hybrid_retriever_chain.EnsembleRetriever") + def test_vertex_ai_rerank_raises_without_project_id( + self, + mock_ensemble, + mock_bm25_chain, + mock_mmr_chain, + mock_sim_chain, + ): + """Test that Vertex AI reranker raises error without project ID.""" + mock_vector_db = Mock() + mock_vector_db.processed_docs = [Mock(), Mock()] + + chain = HybridRetrieverChain( + vector_db=mock_vector_db, + contextual_rerank=True, + ) + + mock_sim_instance = Mock() + mock_sim_instance.retriever = Mock() + mock_sim_chain.return_value = mock_sim_instance + + mock_mmr_instance = Mock() + mock_mmr_instance.retriever = Mock() + mock_mmr_chain.return_value = mock_mmr_instance + + mock_bm25_instance = Mock() + mock_bm25_instance.retriever = Mock() + mock_bm25_chain.return_value = mock_bm25_instance + + mock_ensemble.return_value = Mock() + + with pytest.raises(ValueError, match="VERTEX_AI_PROJECT_ID must be set"): + chain.create_hybrid_retriever() + @patch("src.chains.hybrid_retriever_chain.os.path.isdir") @patch("src.chains.hybrid_retriever_chain.os.listdir") @patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain") From 1dfd5e8a20d202ef697d1814cde92de88caca6de Mon Sep 17 00:00:00 2001 From: Harsh Kumar Date: Tue, 3 Mar 2026 19:25:47 +0530 Subject: [PATCH 2/2] Fix potential UnboundLocalError in ensemble retriever construction Initialize bm25_retriever to None and raise a clear ValueError if the ensemble retriever cannot be constructed due to missing sub-retrievers. Signed-off-by: Harsh Kumar Patwa Signed-off-by: Harsh Kumar --- backend/src/chains/hybrid_retriever_chain.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backend/src/chains/hybrid_retriever_chain.py b/backend/src/chains/hybrid_retriever_chain.py index 7dd2ae7f..bd51e2a9 100644 --- a/backend/src/chains/hybrid_retriever_chain.py +++ b/backend/src/chains/hybrid_retriever_chain.py @@ -104,6 +104,7 @@ def create_hybrid_retriever(self) -> None: mmr_retriever = mmr_retriever_chain.retriever bm25_retriever_chain = BM25RetrieverChain() + bm25_retriever = None if self.vector_db is not None and self.vector_db.processed_docs: bm25_retriever_chain.create_bm25_retriever( @@ -120,6 +121,11 @@ def create_hybrid_retriever(self) -> None: retrievers=[similarity_retriever, mmr_retriever, bm25_retriever], weights=self.weights, ) + else: + raise ValueError( + "Failed to create ensemble retriever: one or more sub-retrievers " + "could not be initialized. Ensure vector_db has processed documents." + ) if self.contextual_rerank: reranker_type = os.getenv("RERANKER_TYPE", "HF").upper()