RAG, prompt and memory with Mixtral

#201
by edoyen - opened

Hello everyone, I'm trying to create a chatbot that answers specific questions about our software using the RAG method. I've written a prompt that matches my expectations and I call the Mixtral-8x7B model through the HuggingFaceHub class, then run a chain with ConversationalRetrievalChain.from_llm. When the response is displayed, it shows my entire prompt and context, whereas I only want to see the answer. What's more, the memory keeps the whole output, whereas I'd like it to keep only the LLM result. Do you have any idea how to fix this?

Code :

# Imports (paths may differ on newer LangChain versions, where several of these live in langchain_community)
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.chains import ConversationalRetrievalChain


class ChatBot:

    loader = TextLoader(file_path)
    documents = loader.load()

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20, separator="\n# ")
    docs = text_splitter.split_documents(documents)

    embeddings = HuggingFaceEmbeddings()
    vectorstore = FAISS.from_documents(docs, embedding=embeddings)

    # Define the repo ID and connect to the Mixtral model on Hugging Face
    repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    mistral_llm = HuggingFaceHub(
        repo_id=repo_id,
        model_kwargs={"temperature": 0.8, "top_k": 50},
        huggingfacehub_api_token=HUGGINGFACE_API_TOKEN,
    )

    # Store only the 'answer' output in memory
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer', input_key='question')

    template = """
You are a Blabla expert. Users will ask questions about Blabla, its services and products. You need to answer them clearly and concisely.

Use the context and the chat history to answer the question below; return only the answer to this question.
If you don't know the answer, just say you don't know.

Chat history: {chat_history}

Context: {context}
Question: {question}
Answer: """

    prompt = PromptTemplate(template=template, input_variables=["chat_history", "context", "question"])

    rag_chain = ConversationalRetrievalChain.from_llm(
        llm=mistral_llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3}),
        memory=memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": prompt}
    )

Outside the ChatBot class:

bot = ChatBot()
user_input = input("Ask me anything: ")
result = bot.rag_chain({"question": user_input})
print(result["answer"])

I may be able to help you with that. This is my thesis code that I'm currently working on; we have similar code, but mine has no memory because the buffer memory isn't working for me :(
But the whole code and chatbot work!

CHATBOT

HF_TOKEN

import os

HF_TOKEN = 'Your_token'
os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF_TOKEN

Load PDF documents from the upload folder

loader = PyPDFDirectoryLoader(UPLOAD_FOLDER)
loader.requests_per_second = 1
docs = loader.load()

Chunking - Text Splitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
chunks = text_splitter.split_documents(docs)

Embedding Model

embeddings = HuggingFaceEmbeddings()
vectorstore = Chroma.from_documents(chunks, embeddings, persist_directory="db")

Retrieval

query = "What is Clustering?"
search = vectorstore.similarity_search(query)

Retriever

retriever = vectorstore.as_retriever(
    search_type="mmr",  # similarity
    search_kwargs={'k': 4}
)

LLM-Open Source

llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    model_kwargs={"temperature": 0.5, "max_length": 64, "max_new_tokens": 512}
)

RAG RetrievalQA chain

qa = RetrievalQA.from_chain_type(llm=llm, chain_type="refine", retriever=retriever)

Prompt Template

DEFAULT_SYSTEM_PROMPT = """
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
""".strip()

def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    # Prepend the system prompt to the question/context template
    return f"""
{system_prompt}

{prompt}
""".strip()

Chain

SYSTEM_PROMPT = "Use the following pieces of context and history to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."

template = generate_prompt(
    """
context: {context}
question: {query}
answer:
""",
    system_prompt=SYSTEM_PROMPT,
)

prompt = ChatPromptTemplate.from_template(template)

retrieval_chain = RetrievalQA.from_chain_type(
    llm,
    chain_type='stuff',
    retriever=vectorstore.as_retriever(),
    chain_type_kwargs={"prompt": prompt},
)

rag_chain = (
    {"context": retriever, "query": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

Function to retrieve conversation history

def get_chat_history():
    return list(chat_history_collection.find())

@app.route('/chat', methods=['GET', 'POST'])
def chat():
    if request.method == 'GET':
        chat_history = get_chat_history()
        return render_template('chatbot.html', chat_history=chat_history)
    elif request.method == 'POST':
        user_input = request.json['user_input']
        chat_history_collection.insert_one({'sender': 'user', 'message': user_input})
        response = rag_chain.invoke(user_input)
        bot_response = response.split('answer: ')[1].strip()
        chat_history_collection.insert_one({'sender': 'bot', 'message': bot_response})

        # Retrieve the last document if any exists
        last_chat_document = chat_messages_collection.find_one(sort=[('_id', -1)])

        # Check if a document exists and extract the existing chat history
        if last_chat_document:
            chat_history = last_chat_document.get('chat_history', [])
        else:
            chat_history = []

        # Append the current user input and bot response to the chat history
        chat_history.append({'user_input': user_input, 'bot_response': bot_response})

        # Update the chat history in the database
        if last_chat_document:
            chat_messages_collection.update_one({'_id': last_chat_document['_id']}, {'$set': {'chat_history': chat_history}})
        else:
            chat_messages_collection.insert_one({'chat_history': chat_history})

        # Format the chat history
        formatted_chat_history = []
        for entry in chat_history:
            formatted_entry = {
                'user_input': entry['user_input'],
                'bot_response': entry['bot_response']
            }
            formatted_chat_history.append(formatted_entry)
        return response

I'm using Flask and HTML.
I've been trying to integrate the buffer memory code and get it to work, but I can't figure it out.
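For reference, this is roughly the pattern I've been trying, borrowed from the memory setup in the first post and adapted to the names above (llm and vectorstore are the objects defined earlier); it's untested on my side, so treat it as a sketch:

from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain

# Store only the 'answer' output in memory, not the full chain result
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
)

conversational_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,                               # the HuggingFaceHub LLM defined above
    retriever=vectorstore.as_retriever(),  # the Chroma retriever defined above
    memory=memory,
    return_source_documents=True,
)

result = conversational_chain({"question": "What is Clustering?"})
print(result["answer"])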

Check my code above and you'll get the response you want! :)

If you're using Colab or Jupyter, I could give you the notebook code.
The input and response should work for you; I don't know about the buffer memory, but if the code works for you, let me know!

I'm facing the same issue as OP, although I am using LangChain to stack multiple chains for generating summaries. I'm not using RAG, so context doesn't get returned, but my few-shot examples are returned instead...

If you're using Colab or Jupyter, I could give you the notebook code.
The input and response should work for you; I don't know about the buffer memory, but if the code works for you, let me know!

Do you mind sharing your notebook with me?

Finally I found the answer to my problem: using the HuggingFaceEndpoint class instead of HuggingFaceHub, in case that helps.

  repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
  mistral_llm = HuggingFaceEndpoint(
           repo_id=repo_id,
           temperature=0.5,
           top_k=50,
           huggingfacehub_api_token="HUGGINGFACE_API_TOKEN"
   )

Did the buffer memory work as well?

I found a workaround: in my case it was adding the return_full_text parameter and setting it to False. (By default the Transformers text-generation pipeline returns the prompt together with the generated text; return_full_text=False makes it return only the newly generated tokens, which is why the echoed prompt/context disappears.)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=max_length,
    temperature=temperature,
    top_p=top_p,
    repetition_penalty=repetition_penalty,
    device="cuda",
    return_full_text=False
)
self.llm = HuggingFacePipeline(pipeline=pipe)

I can't find the correct imports for HuggingFaceEndpoint.
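In case it helps, the import depends on which LangChain version is installed: on recent releases it lives in the langchain-huggingface partner package, on older ones in langchain_community. A minimal sketch, reusing the repo_id and token placeholder from the answer above:

# Recent LangChain versions (pip install langchain-huggingface)
from langchain_huggingface import HuggingFaceEndpoint

# Older versions keep it in the community package instead:
# from langchain_community.llms import HuggingFaceEndpoint

llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    temperature=0.5,
    top_k=50,
    huggingfacehub_api_token="YOUR_TOKEN",
)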
