diff --git a/README.md b/README.md index 799536c3..3a25dab1 100644 --- a/README.md +++ b/README.md @@ -108,23 +108,25 @@ cp .env.example .env **Step 3**: For populating the `data` folder with OR/ORFS docs, OpenSTA docs and Yosys docs, run: ```bash -python build_docs.py +uv run python build_docs.py # Alternatively, pull the latest docs mkdir data -huggingface-cli download The-OpenROAD-Project/ORAssistant_RAG_Dataset --repo-type dataset --local-dir data/ +# Note: If you encounter 429 errors, run 'uv run huggingface-cli login' first +# and/or use --max-workers 1 for sequential download +uv run huggingface-cli download The-OpenROAD-Project/ORAssistant_RAG_Dataset --repo-type dataset --local-dir data/ --max-workers 1 ``` **Step 4**: To run the server, run: ```bash -python main.py +uv run python main.py ``` **Optionally**: To interact with the chatbot in your terminal, run: ```bash -python chatbot.py +uv run python chatbot.py ``` The backend will then be hosted at [http://0.0.0.0:8000](http://0.0.0.0:8000). diff --git a/backend/README.md b/backend/README.md index 9470f09c..88b3a0bd 100644 --- a/backend/README.md +++ b/backend/README.md @@ -99,7 +99,24 @@ There are 4 variables that needs to be set up ### Setting Up Huggingface User Access Token -To set up the `HF_TOKEN` variable in `.env` file , go through the following instructions: +If you encounter **429 Too Many Requests** errors while downloading the dataset, providing an `HF_TOKEN` is necessary. + +1. Go to your [Hugging Face Settings -> Tokens](https://huggingface.co/settings/tokens). +2. Click **"New token"** and select the **"Read"** type. +3. 
Copy the token and add it to your `.env` file: + ```bash + HF_TOKEN=your_token_here + ``` + +Alternatively, you can log in via the CLI: +```bash +uv run huggingface-cli login +``` + +If you still encounter 429 errors after logging in, try downloading with a single worker (sequential download): +```bash +uv run huggingface-cli download The-OpenROAD-Project/ORAssistant_RAG_Dataset --repo-type dataset --local-dir data/ --max-workers 1 +``` - Go the official website for [Huggingface](https://huggingface.co/) and either Login or Sign up. - On the main page click on user access token diff --git a/backend/src/api/models/response_model.py b/backend/src/api/models/response_model.py index 6d8c3818..d12dc77c 100644 --- a/backend/src/api/models/response_model.py +++ b/backend/src/api/models/response_model.py @@ -8,6 +8,7 @@ class UserInput(BaseModel): query: str list_sources: bool = False list_context: bool = False + stream: bool = False conversation_uuid: Optional[UUID] = None diff --git a/backend/src/api/routers/conversations.py b/backend/src/api/routers/conversations.py index f0450628..2cdebe72 100644 --- a/backend/src/api/routers/conversations.py +++ b/backend/src/api/routers/conversations.py @@ -245,11 +245,31 @@ def get_optional_db(db: Session = Depends(get_db)) -> Session | None: return db if use_db else None -@router.post("/agent-retriever", response_model=ChatResponse) +@router.post("/agent-retriever", response_model=None) async def get_agent_response( user_input: UserInput, db: Session | None = Depends(get_optional_db) -) -> ChatResponse: - """Processes a user query using the retriever agent, maintains conversation context, and returns the generated response along with relevant context sources and tools used.""" +) -> ChatResponse | StreamingResponse: + """Unified chat endpoint. + + Pass ``stream=false`` (default) in the request body to receive a complete + JSON response once generation is done. 
Pass ``stream=true`` to receive + a ``text/event-stream`` response that yields tokens in real time as the + LLM produces them. + + Example non-streaming body:: + + {"query": "How do I install OpenROAD?", "stream": false} + + Example streaming body:: + + {"query": "How do I install OpenROAD?", "stream": true} + """ + if user_input.stream: + return StreamingResponse( + get_response_stream(user_input, db), media_type="text/event-stream" + ) + + # --- non-streaming path --- user_question = user_input.query conversation_uuid = user_input.conversation_uuid diff --git a/backend/src/api/routers/ui.py b/backend/src/api/routers/ui.py index ecce0f1b..d829c725 100644 --- a/backend/src/api/routers/ui.py +++ b/backend/src/api/routers/ui.py @@ -1,17 +1,40 @@ from fastapi import APIRouter, Request, Response +from fastapi.responses import StreamingResponse import httpx import os router = APIRouter(prefix="/ui", tags=["ui"]) BACKEND_ENDPOINT = os.getenv("BACKEND_ENDPOINT", "http://localhost:8000") +# Keep a streaming-capable client as well (no timeout for long SSE responses) client = httpx.AsyncClient(base_url=BACKEND_ENDPOINT) +stream_client = httpx.AsyncClient(base_url=BACKEND_ENDPOINT, timeout=None) @router.post("/chat") -async def proxy_chat(request: Request) -> Response: +async def proxy_chat(request: Request) -> Response | StreamingResponse: + """Proxy to the backend chat endpoint. + + Reads the ``stream`` field from the JSON body. When ``stream=true`` the + response is forwarded as a Server-Sent Events stream so tokens reach the + browser in real time. When ``stream=false`` (default) the full JSON + payload is returned once generation completes. + """ data = await request.json() - # TODO: set this route dynamically + wants_stream: bool = bool(data.get("stream", False)) + + if wants_stream: + # Open a persistent streaming connection and forward chunks as-is. 
+ async def _iter_sse(): + async with stream_client.stream( + "POST", "/conversations/agent-retriever", json=data + ) as resp: + async for chunk in resp.aiter_text(): + yield chunk + + return StreamingResponse(_iter_sse(), media_type="text/event-stream") + + # Non-streaming: wait for the full response, then return it. resp = await client.post("/conversations/agent-retriever", json=data) return Response( content=resp.content, diff --git a/backend/tests/test_api_conversations_streaming.py b/backend/tests/test_api_conversations_streaming.py index a3a921de..de0765ba 100644 --- a/backend/tests/test_api_conversations_streaming.py +++ b/backend/tests/test_api_conversations_streaming.py @@ -47,6 +47,7 @@ def sample_user_input(): query="What is OpenROAD?", list_sources=False, list_context=False, + stream=False, conversation_uuid=uuid4(), ) @@ -125,7 +126,7 @@ async def test_get_response_stream_creates_conversation( from src.api.routers.conversations import get_response_stream user_input = UserInput( - query="Test question", list_sources=False, conversation_uuid=None + query="Test question", list_sources=False, stream=False, conversation_uuid=None ) async def mock_astream_events(*args, **kwargs): @@ -171,7 +172,7 @@ async def test_get_response_stream_with_chat_history( content="Previous answer", ) - user_input = UserInput(query="Follow-up question", conversation_uuid=conv_uuid) + user_input = UserInput(query="Follow-up question", stream=False, conversation_uuid=conv_uuid) captured_inputs = [] @@ -386,7 +387,7 @@ async def test_get_response_stream_title_truncation( from src.api.routers.conversations import get_response_stream long_query = "A" * 150 # 150 characters - user_input = UserInput(query=long_query, conversation_uuid=None) + user_input = UserInput(query=long_query, stream=False, conversation_uuid=None) async def mock_astream_events(*args, **kwargs): yield {"event": "on_chat_model_end", "data": {}} @@ -413,7 +414,7 @@ async def 
test_get_agent_response_streaming_endpoint( from src.api.routers.conversations import get_agent_response_streaming from starlette.responses import StreamingResponse - user_input = UserInput(query="Test", conversation_uuid=uuid4()) + user_input = UserInput(query="Test", stream=False, conversation_uuid=uuid4()) async def mock_astream_events(*args, **kwargs): yield {"event": "on_chat_model_end", "data": {}} @@ -502,3 +503,74 @@ async def mock_astream_events(*args, **kwargs): assert any("Sources:" in c for c in chunks) sources_chunk = [c for c in chunks if "Sources:" in c][0] assert sources_chunk.strip() == "Sources:" + + +class TestUnifiedEndpointStreamBranching: + """Verify the unified /agent-retriever endpoint branches on stream field.""" + + @pytest.mark.asyncio + async def test_get_agent_response_stream_true( + self, db_session: Session, mock_retriever_graph + ): + """When stream=True the unified endpoint must return a StreamingResponse.""" + from src.api.routers.conversations import get_agent_response + from starlette.responses import StreamingResponse + + user_input = UserInput( + query="What is OpenROAD?", + stream=True, + conversation_uuid=uuid4(), + ) + + async def mock_astream_events(*args, **kwargs): + yield {"event": "on_chat_model_end", "data": {}} + yield { + "event": "on_chat_model_stream", + "data": {"chunk": AIMessageChunk(content="streamed token")}, + } + + mock_retriever_graph.astream_events = mock_astream_events + + response = await get_agent_response(user_input, db_session) + + assert isinstance(response, StreamingResponse), ( + "Expected StreamingResponse when stream=True, got " + f"{type(response).__name__}" + ) + assert response.media_type == "text/event-stream" + + @pytest.mark.asyncio + async def test_get_agent_response_stream_false_returns_chat_response( + self, db_session: Session + ): + """When stream=False the unified endpoint must return a plain ChatResponse.""" + from src.api.routers.conversations import get_agent_response + from 
src.api.models.response_model import ChatResponse + + user_input = UserInput( + query="What is OpenROAD?", + stream=False, + conversation_uuid=uuid4(), + ) + + # Build a minimal graph output that parse_agent_output can handle + fake_output = [ + {"classify": {"agent_type": ["rag_agent"]}}, + {"retrieve_general": {"context": "ctx", "sources": [], "urls": [], "context_list": []}}, + { + "rag_generate": { + "messages": ["OpenROAD is a chip design tool."] + } + }, + ] + + with patch("src.api.routers.conversations.rg") as mock_rg: + mock_rg.graph = MagicMock() + mock_rg.graph.stream.return_value = iter(fake_output) + response = await get_agent_response(user_input, db_session) + + assert isinstance(response, ChatResponse), ( + "Expected ChatResponse when stream=False, got " + f"{type(response).__name__}" + ) + assert response.response == "OpenROAD is a chip design tool."