Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,23 +108,25 @@ cp .env.example .env
**Step 3**: For populating the `data` folder with OR/ORFS docs, OpenSTA docs and Yosys docs, run:

```bash
python build_docs.py
uv run python build_docs.py

# Alternatively, pull the latest docs
mkdir data
huggingface-cli download The-OpenROAD-Project/ORAssistant_RAG_Dataset --repo-type dataset --local-dir data/
# Note: If you encounter 429 errors, run 'uv run huggingface-cli login' first
# and/or use --max-workers 1 for sequential download
uv run huggingface-cli download The-OpenROAD-Project/ORAssistant_RAG_Dataset --repo-type dataset --local-dir data/ --max-workers 1
```

**Step 4**: To run the server, run:

```bash
python main.py
uv run python main.py
```

**Optionally**: To interact with the chatbot in your terminal, run:

```bash
python chatbot.py
uv run python chatbot.py
```

The backend will then be hosted at [http://0.0.0.0:8000](http://0.0.0.0:8000).
Expand Down
19 changes: 18 additions & 1 deletion backend/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,24 @@ There are 4 variables that need to be set up

### Setting Up Huggingface User Access Token

To set up the `HF_TOKEN` variable in the `.env` file, go through the following instructions:
If you encounter **429 Too Many Requests** errors while downloading the dataset, providing an `HF_TOKEN` is necessary.

1. Go to your [Hugging Face Settings -> Tokens](https://huggingface.co/settings/tokens).
2. Click **"New token"** and select the **"Read"** type.
3. Copy the token and add it to your `.env` file:
```bash
HF_TOKEN=your_token_here
```

Alternatively, you can login via CLI:
```bash
uv run huggingface-cli login
```

If you still encounter 429 errors after logging in, try downloading with a single worker (sequential download):
```bash
uv run huggingface-cli download The-OpenROAD-Project/ORAssistant_RAG_Dataset --repo-type dataset --local-dir data/ --max-workers 1
```

- Go to the official website for [Huggingface](https://huggingface.co/) and either Login or Sign up.
- On the main page click on user access token
Expand Down
1 change: 1 addition & 0 deletions backend/src/api/models/response_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class UserInput(BaseModel):
query: str
list_sources: bool = False
list_context: bool = False
stream: bool = False
conversation_uuid: Optional[UUID] = None


Expand Down
26 changes: 23 additions & 3 deletions backend/src/api/routers/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,31 @@ def get_optional_db(db: Session = Depends(get_db)) -> Session | None:
return db if use_db else None


@router.post("/agent-retriever", response_model=ChatResponse)
@router.post("/agent-retriever", response_model=None)
async def get_agent_response(
user_input: UserInput, db: Session | None = Depends(get_optional_db)
) -> ChatResponse:
"""Processes a user query using the retriever agent, maintains conversation context, and returns the generated response along with relevant context sources and tools used."""
) -> ChatResponse | StreamingResponse:
"""Unified chat endpoint.

Pass ``stream=false`` (default) in the request body to receive a complete
JSON response once generation is done. Pass ``stream=true`` to receive
a ``text/event-stream`` response that yields tokens in real time as the
LLM produces them.

Example non-streaming body::

{"query": "How do I install OpenROAD?", "stream": false}

Example streaming body::

{"query": "How do I install OpenROAD?", "stream": true}
"""
if user_input.stream:
return StreamingResponse(
get_response_stream(user_input, db), media_type="text/event-stream"
)

# --- non-streaming path ---
user_question = user_input.query

conversation_uuid = user_input.conversation_uuid
Expand Down
27 changes: 25 additions & 2 deletions backend/src/api/routers/ui.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,40 @@
from fastapi import APIRouter, Request, Response
from fastapi.responses import StreamingResponse
import httpx
import os

router = APIRouter(prefix="/ui", tags=["ui"])

BACKEND_ENDPOINT = os.getenv("BACKEND_ENDPOINT", "http://localhost:8000")
# Keep a streaming-capable client as well (no timeout for long SSE responses)
client = httpx.AsyncClient(base_url=BACKEND_ENDPOINT)
stream_client = httpx.AsyncClient(base_url=BACKEND_ENDPOINT, timeout=None)


@router.post("/chat")
async def proxy_chat(request: Request) -> Response:
async def proxy_chat(request: Request) -> Response | StreamingResponse:
"""Proxy to the backend chat endpoint.

Reads the ``stream`` field from the JSON body. When ``stream=true`` the
response is forwarded as a Server-Sent Events stream so tokens reach the
browser in real time. When ``stream=false`` (default) the full JSON
payload is returned once generation completes.
"""
data = await request.json()
# TODO: set this route dynamically
wants_stream: bool = bool(data.get("stream", False))

if wants_stream:
# Open a persistent streaming connection and forward chunks as-is.
async def _iter_sse():
async with stream_client.stream(
"POST", "/conversations/agent-retriever", json=data
) as resp:
async for chunk in resp.aiter_text():
yield chunk

return StreamingResponse(_iter_sse(), media_type="text/event-stream")

# Non-streaming: wait for the full response, then return it.
resp = await client.post("/conversations/agent-retriever", json=data)
return Response(
content=resp.content,
Expand Down
80 changes: 76 additions & 4 deletions backend/tests/test_api_conversations_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def sample_user_input():
query="What is OpenROAD?",
list_sources=False,
list_context=False,
stream=False,
conversation_uuid=uuid4(),
)

Expand Down Expand Up @@ -125,7 +126,7 @@ async def test_get_response_stream_creates_conversation(
from src.api.routers.conversations import get_response_stream

user_input = UserInput(
query="Test question", list_sources=False, conversation_uuid=None
query="Test question", list_sources=False, stream=False, conversation_uuid=None
)

async def mock_astream_events(*args, **kwargs):
Expand Down Expand Up @@ -171,7 +172,7 @@ async def test_get_response_stream_with_chat_history(
content="Previous answer",
)

user_input = UserInput(query="Follow-up question", conversation_uuid=conv_uuid)
user_input = UserInput(query="Follow-up question", stream=False, conversation_uuid=conv_uuid)

captured_inputs = []

Expand Down Expand Up @@ -386,7 +387,7 @@ async def test_get_response_stream_title_truncation(
from src.api.routers.conversations import get_response_stream

long_query = "A" * 150 # 150 characters
user_input = UserInput(query=long_query, conversation_uuid=None)
user_input = UserInput(query=long_query, stream=False, conversation_uuid=None)

async def mock_astream_events(*args, **kwargs):
yield {"event": "on_chat_model_end", "data": {}}
Expand All @@ -413,7 +414,7 @@ async def test_get_agent_response_streaming_endpoint(
from src.api.routers.conversations import get_agent_response_streaming
from starlette.responses import StreamingResponse

user_input = UserInput(query="Test", conversation_uuid=uuid4())
user_input = UserInput(query="Test", stream=False, conversation_uuid=uuid4())

async def mock_astream_events(*args, **kwargs):
yield {"event": "on_chat_model_end", "data": {}}
Expand Down Expand Up @@ -502,3 +503,74 @@ async def mock_astream_events(*args, **kwargs):
assert any("Sources:" in c for c in chunks)
sources_chunk = [c for c in chunks if "Sources:" in c][0]
assert sources_chunk.strip() == "Sources:"


class TestUnifiedEndpointStreamBranching:
    """Verify the unified /agent-retriever endpoint branches on stream field."""

    @pytest.mark.asyncio
    async def test_get_agent_response_stream_true(
        self, db_session: Session, mock_retriever_graph
    ):
        """When stream=True the unified endpoint must return a StreamingResponse."""
        # Imported inside the test so any fixture/app setup happens first.
        from src.api.routers.conversations import get_agent_response
        from starlette.responses import StreamingResponse

        user_input = UserInput(
            query="What is OpenROAD?",
            stream=True,
            conversation_uuid=uuid4(),
        )

        # Minimal fake event stream: a model-end marker followed by one
        # streamed token chunk, standing in for the real LangChain graph.
        async def mock_astream_events(*args, **kwargs):
            yield {"event": "on_chat_model_end", "data": {}}
            yield {
                "event": "on_chat_model_stream",
                "data": {"chunk": AIMessageChunk(content="streamed token")},
            }

        mock_retriever_graph.astream_events = mock_astream_events

        response = await get_agent_response(user_input, db_session)

        # stream=True must select the SSE branch of the unified endpoint.
        assert isinstance(response, StreamingResponse), (
            "Expected StreamingResponse when stream=True, got "
            f"{type(response).__name__}"
        )
        assert response.media_type == "text/event-stream"

    @pytest.mark.asyncio
    async def test_get_agent_response_stream_false_returns_chat_response(
        self, db_session: Session
    ):
        """When stream=False the unified endpoint must return a plain ChatResponse."""
        from src.api.routers.conversations import get_agent_response
        from src.api.models.response_model import ChatResponse

        user_input = UserInput(
            query="What is OpenROAD?",
            stream=False,
            conversation_uuid=uuid4(),
        )

        # Build a minimal graph output that parse_agent_output can handle
        fake_output = [
            {"classify": {"agent_type": ["rag_agent"]}},
            {"retrieve_general": {"context": "ctx", "sources": [], "urls": [], "context_list": []}},
            {
                "rag_generate": {
                    "messages": ["OpenROAD is a chip design tool."]
                }
            },
        ]

        # Patch the retriever-graph module object so the non-streaming path
        # consumes this canned output instead of calling a real LLM.
        with patch("src.api.routers.conversations.rg") as mock_rg:
            mock_rg.graph = MagicMock()
            mock_rg.graph.stream.return_value = iter(fake_output)
            response = await get_agent_response(user_input, db_session)

        # stream=False must take the non-streaming branch and return the
        # fully-parsed ChatResponse model.
        assert isinstance(response, ChatResponse), (
            "Expected ChatResponse when stream=False, got "
            f"{type(response).__name__}"
        )
        assert response.response == "OpenROAD is a chip design tool."