-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
105 lines (75 loc) · 2.78 KB
/
app.py
File metadata and controls
105 lines (75 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Standard library
import json
import os
from enum import Enum
from typing import List, Union

# Third-party
from dotenv import load_dotenv  # loads a local .env file into os.environ
from fastapi import FastAPI, Response, status
from pydantic import BaseModel

# Project-local: concrete embedding backends, reranker scoring, token splitter
from embedding_models import (
    OllamaEmbeddingModel,
    OpenAIEmbeddingModel,
    SentenceTransformerEmbeddingModel,
)
from reranker import get_scores
# from langchain.schema import Document
from splitter import get_split_documents_using_token_based

# Must run before the os.getenv(...) calls below so .env values are visible.
load_dotenv()
# FastAPI application instance that all route decorators below attach to.
app = FastAPI()
class EmbeddingModelType(Enum):
    """Closed set of embedding backends selectable via EMBEDDING_MODEL_TYPE."""

    SENTENCE_TRANSFORMERS = 1
    OLLAMA = 2
    OPENAI = 3
# Backend selection is driven entirely by environment variables so the
# service can be reconfigured without code changes.
MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "sentence-transformers/gtr-t5-large")
MODEL_TYPE = EmbeddingModelType(int(os.getenv("EMBEDDING_MODEL_TYPE", "1")))
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", None)


def _build_embedding_model():
    """Instantiate the backend selected by MODEL_TYPE; None if no match."""
    if MODEL_TYPE == EmbeddingModelType.SENTENCE_TRANSFORMERS:
        return SentenceTransformerEmbeddingModel(model=MODEL_NAME)
    if MODEL_TYPE == EmbeddingModelType.OLLAMA:
        return OllamaEmbeddingModel(model=MODEL_NAME, base_url=OLLAMA_BASE_URL)
    if MODEL_TYPE == EmbeddingModelType.OPENAI:
        return OpenAIEmbeddingModel(model=MODEL_NAME)
    return None


# Single shared backend instance used by the /get_embeddings route.
embedding_model = _build_embedding_model()
class RequestSchemaForEmbeddings(BaseModel):
    """Request body for /get_embeddings: a single text or a batch of texts."""

    texts: Union[str, List[str]]
class RequestSchemaForTextSplitter(BaseModel):
    """Request body for /split_docs_based_on_tokens."""

    model: str          # model name whose tokenizer drives the split
    documents: str      # JSON-encoded documents payload (parsed by the route)
    chunk_size: int     # maximum tokens per chunk
    chunk_overlap: int  # tokens shared between adjacent chunks
class RequestSchemaForReRankers(BaseModel):
    """Request body for /docs_reranking_scores."""

    query: str            # query to score each document against
    documents: List[str]  # candidate documents to rerank
@app.get("/")
async def home():
    """Landing endpoint: returns a short plain-text banner with HTTP 200."""
    banner = "Embedding handler using models for texts"
    return Response(content=banner, status_code=status.HTTP_200_OK)
@app.post("/get_embeddings")
async def generate_embeddings(item: RequestSchemaForEmbeddings):
    """
    Embed the given text(s) with the backend configured at startup.

    A single string is embedded as a query, a list of strings as a batch of
    documents. Returns an empty list when no backend is configured or the
    payload matches neither shape.
    """
    if not embedding_model:
        return []
    texts = item.texts
    if isinstance(texts, str):
        return embedding_model.embed_query(text=texts)
    if isinstance(texts, list):
        return embedding_model.embed_documents(texts=texts)
    return []
@app.post("/split_docs_based_on_tokens")
async def get_split_docs(item: RequestSchemaForTextSplitter):
    """Parse the JSON documents payload and split it into token-based chunks."""
    parsed_documents = json.loads(item.documents)
    return get_split_documents_using_token_based(
        model_name=item.model,
        documents=parsed_documents,
        chunk_size=item.chunk_size,
        chunk_overlap=item.chunk_overlap,
    )
@app.post("/docs_reranking_scores")
async def get_reranked_docs(item: RequestSchemaForReRankers):
    """Score each candidate document against the query via the reranker."""
    query, documents = item.query, item.documents
    return get_scores(query, documents)