| from copy import deepcopy |
| from typing import Dict, List, Any, Optional |
|
|
| import faiss |
|
|
| from langchain.docstore import InMemoryDocstore |
| from langchain.embeddings import OpenAIEmbeddings |
| from langchain.schema import Document |
| from langchain.vectorstores import Chroma, FAISS |
| from langchain.vectorstores.base import VectorStoreRetriever |
|
|
| from flows.base_flows import AtomicFlow |
| import hydra |
|
|
|
|
class VectorStoreFlow(AtomicFlow):
    """Atomic flow that writes text to, and reads text from, a vector store.

    The store (Chroma or FAISS) is wrapped in a LangChain retriever. ``run``
    dispatches on ``input_data["operation"]``: ``"write"`` adds documents to
    the store, ``"read"`` returns the page contents of the documents the
    retriever yields for a query string.
    """

    REQUIRED_KEYS_CONFIG = ["type"]

    # Retriever built by ``_set_up_retriever`` over the selected vector store.
    vector_db: VectorStoreRetriever

    def __init__(self, backend, vector_db, **kwargs):
        """Initialise the flow.

        Args:
            backend: Backend object produced by ``_set_up_backend``. It is
                accepted so that ``instantiate_from_config`` can pass it, but
                it is not stored on the instance by this class.
            vector_db: Retriever used by ``run`` for reads and writes.
            **kwargs: Forwarded unchanged to ``AtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        self.vector_db = vector_db

    @classmethod
    def _set_up_backend(cls, config):
        """Instantiate the backend described by ``config["backend"]``.

        Args:
            config: Flow configuration; must contain a ``"backend"`` entry
                that ``hydra.utils.instantiate`` can build.

        Returns:
            A dict ``{"backend": <instantiated backend>}``.
        """
        return {
            "backend": hydra.utils.instantiate(config["backend"], _convert_="partial"),
        }

    @classmethod
    def _set_up_retriever(cls, api_information, config: Dict[str, Any]) -> Dict[str, Any]:
        """Build the vector store selected by ``config["type"]`` and wrap it.

        Args:
            api_information: Object whose ``api_key`` attribute is passed to
                ``OpenAIEmbeddings``.
            config: Flow configuration. ``"type"`` selects ``"chroma"`` or
                ``"faiss"``; ``"name"`` is required for Chroma;
                ``"embedding_size"`` (default 1536) sizes the FAISS index;
                ``"retriever_config"`` (default ``{}``) is expanded into
                ``as_retriever``.

        Returns:
            A dict ``{"vector_db": <retriever>}``.

        Raises:
            NotImplementedError: If ``config["type"]`` is neither
                ``"chroma"`` nor ``"faiss"``.
        """
        embeddings = OpenAIEmbeddings(openai_api_key=api_information.api_key)

        vs_type = config["type"]

        if vs_type == "chroma":
            vectorstore = Chroma(config["name"], embedding_function=embeddings)
        elif vs_type == "faiss":
            index = faiss.IndexFlatL2(config.get("embedding_size", 1536))
            # FAISS is given the query-embedding callable, not the embeddings
            # object, and starts with an empty in-memory docstore.
            vectorstore = FAISS(
                embedding_function=embeddings.embed_query,
                index=index,
                docstore=InMemoryDocstore({}),
                index_to_docstore_id={},
            )
        else:
            raise NotImplementedError(f"Vector store '{vs_type}' not implemented")

        return {
            "vector_db": vectorstore.as_retriever(**config.get("retriever_config", {})),
        }

    @classmethod
    def instantiate_from_config(cls, config: Dict[str, Any]):
        """Create a flow instance from a configuration dict.

        The configuration is deep-copied first, so the caller's dict is not
        shared with the new flow.

        Args:
            config: Flow configuration; see ``_set_up_backend`` and
                ``_set_up_retriever`` for the keys that are read.

        Returns:
            A new instance of ``cls``.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # The backend must exist before the retriever: its key feeds the
        # embeddings used by the vector store.
        kwargs.update(cls._set_up_backend(flow_config))
        api_information = kwargs["backend"].get_key()

        kwargs.update(cls._set_up_retriever(api_information, flow_config))

        return cls(**kwargs)

    @staticmethod
    def package_documents(documents: List[str]) -> List[Document]:
        """Wrap each text in a ``Document``.

        Args:
            documents: Texts to wrap, one ``Document`` per element.

        Returns:
            The ``Document`` objects, in input order.
        """
        # NOTE(review): the placeholder metadata {"": ""} is kept as-is;
        # presumably some store rejects empty metadata -- confirm before
        # changing it.
        return [Document(page_content=doc, metadata={"": ""}) for doc in documents]

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a read or write against the vector store.

        Args:
            input_data: Must contain ``"operation"`` (``"read"`` or
                ``"write"``) and ``"content"``. For a read, content is the
                query string. For a write, content is a string or a list of
                texts to store.

        Returns:
            ``{"retrieved": [...]}`` holding the page contents of the
            retrieved documents for a read, or ``{"retrieved": ""}`` for a
            write.

        Raises:
            KeyError: If ``"operation"`` or ``"content"`` is missing.
            AssertionError: If the operation is unsupported or the content
                has the wrong type.
        """
        response = {}

        # Validation raises AssertionError explicitly instead of using
        # ``assert``: assert statements are removed under ``python -O``, which
        # would let bad input through silently. The exception type and
        # messages are the same as the assert statements produced.
        operation = input_data["operation"]
        if operation not in ("write", "read"):
            raise AssertionError(f"Operation '{operation}' not supported")

        content = input_data["content"]
        if operation == "read":
            if not isinstance(content, str):
                raise AssertionError(f"Content must be a string, got {type(content)}")
            retrieved_documents = self.vector_db.get_relevant_documents(content)
            response["retrieved"] = [doc.page_content for doc in retrieved_documents]
        else:
            # A single string is treated as a one-element list.
            if isinstance(content, str):
                content = [content]
            if not isinstance(content, list):
                raise AssertionError(
                    f"Content must be a list of strings, got {type(content)}"
                )
            self.vector_db.add_documents(self.package_documents(content))
            response["retrieved"] = ""

        return response
|
|