We show a very simple example of Retrieval-Augmented Generation (RAG) in which the retrieval step is performed by a Support Vector Machine (SVM) instead of a vector database. RAG is a technique rather than a model in itself: relevant passages are retrieved from a knowledge base and passed to a large language model (LLM) as context, so that the generated answer is grounded in that information.
This notebook is based on [1] and on the original LangChain notebook [2], which it simplifies considerably: instead of a vector database, a linear SVM ranks the document chunks by their relevance to each question.
To run this notebook, you need a Python + Jupyter environment and a Groq Cloud Playground API key (signing in requires an email address or an account from another supported service).
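Rather than pasting the key into the notebook, it can be kept in an environment variable. The sketch below is a hypothetical helper (`ensure_groq_api_key` is our name, not part of any library) that reads `GROQ_API_KEY` and prompts for it with `getpass` only if it is missing. It assumes, as we understand the `langchain_groq` package to behave, that `ChatGroq` falls back to `GROQ_API_KEY` when no `groq_api_key` argument is passed.

```python
import os
from getpass import getpass


def ensure_groq_api_key() -> str:
    """Return the Groq API key, asking for it once if GROQ_API_KEY is unset.

    Hypothetical helper: GROQ_API_KEY is the variable name we assume
    ChatGroq reads by default.
    """
    key = os.environ.get("GROQ_API_KEY")
    if not key:
        key = getpass("Groq API key: ")
        # Store it for this process so later code can read it from the environment
        os.environ["GROQ_API_KEY"] = key
    return key
```

The model can then be created without a literal key, e.g. `ChatGroq(temperature=0, model_name="llama3-8b-8192", groq_api_key=ensure_groq_api_key())`.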
The required packages are listed below (uncomment the line to install them).
# !pip install langchain langchain_community langchain_groq gpt4all scikit-learn numpy tiktoken beautifulsoup4
### Index
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
# Web pages that make up the knowledge base
urls = [
    "https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/",
    "https://research.ibm.com/blog/retrieval-augmented-generation-RAG",
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
# Load each web page; every loader returns a list of documents
docs = [WebBaseLoader(url).load() for url in urls]
# Flatten the list of lists into a single list of documents
docs_list = [item for sublist in docs for item in sublist]
# Split the documents into chunks of about 300 tokens (tiktoken encoding), with no overlap
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=300, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
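To make `chunk_size` and `chunk_overlap` concrete, here is a simplified fixed-window splitter over a list of tokens. It is a swapped-in illustration, not LangChain's algorithm: `RecursiveCharacterTextSplitter` also tries to break at separators such as paragraphs and sentences, and with `from_tiktoken_encoder` it measures length in tiktoken tokens, whereas this sketch just splits on whitespace. The function name `chunk_tokens` is hypothetical.

```python
def chunk_tokens(tokens, chunk_size, chunk_overlap=0):
    """Split a list of tokens into windows of at most chunk_size tokens.

    Consecutive windows share chunk_overlap tokens. Simplified sketch only:
    no separator handling, unlike LangChain's recursive splitter.
    """
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("need 0 <= chunk_overlap < chunk_size")
    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):  # this window reached the end
            break
    return chunks


words = "one two three four five six seven".split()  # whitespace "tokens"
print(chunk_tokens(words, chunk_size=3))
# [['one', 'two', 'three'], ['four', 'five', 'six'], ['seven']]
print(chunk_tokens(words, chunk_size=3, chunk_overlap=1))
# [['one', 'two', 'three'], ['three', 'four', 'five'], ['five', 'six', 'seven']]
```

With `chunk_overlap=0`, as in the notebook, every token ends up in exactly one chunk.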
### Generate
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
# Prompt template for the question-answering task
prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise <|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Context: {context}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question", "context"],
)
# Initializes a ChatGroq language model with the specified temperature, model name, and API key.
llm = ChatGroq(temperature=0, model_name="llama3-8b-8192", groq_api_key='***INSERT API KEY HERE***')
# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
# Chain
rag_chain = prompt | llm | StrOutputParser()
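The chain's prompt is filled from the keys passed to `rag_chain.invoke`, so those keys must match the `{...}` placeholders in the template; `PromptTemplate` uses Python's f-string-style placeholders by default. The plain-Python sketch below, independent of LangChain, lists the placeholder names of a shortened version of the template and shows what `format_docs` produces. `Doc` is a hypothetical stand-in for LangChain's `Document` class, of which only the `page_content` attribute is used.

```python
from dataclasses import dataclass
from string import Formatter

# Shortened version of the prompt template above (special tokens omitted)
template = "Question: {question}\nContext: {context}\nAnswer:"


def placeholder_names(template: str) -> set:
    """Return the {field} names used in a str.format-style template."""
    return {field for _, field, _, _ in Formatter().parse(template) if field}


@dataclass
class Doc:
    """Hypothetical stand-in for LangChain's Document; only page_content is used."""
    page_content: str


def format_docs(docs):
    # Same post-processing as in the notebook: join chunk texts with blank lines
    return "\n\n".join(doc.page_content for doc in docs)


context = format_docs([Doc("RAG retrieves documents."), Doc("The LLM then answers.")])
print(sorted(placeholder_names(template)))  # ['context', 'question']
print(template.format(question="What is RAG?", context=context))
```

Passing a list of documents directly as `context` would insert its Python string representation into the prompt instead of plain text, which is why joining the chunks with `format_docs` first is preferable.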
# Load the GPT4All embeddings model
from langchain_community.embeddings import GPT4AllEmbeddings
embeddings = GPT4AllEmbeddings()
The following code defines a class, `Retriever`, that retrieves the documents most relevant to a given query (question). It takes two inputs: a list of documents (`docs`) and an embeddings model (`embeddings`), which the class uses to compute one embedding per document.
On initialization, the class embeds the text of every document and stores the result as a NumPy array (`embeds_np`, kept as `self.embeds`) whose rows are L2-normalized, i.e. scaled to unit length. After normalization all vectors lie on the unit sphere, so they differ only in direction, and the dot product of two rows equals their cosine similarity.
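A small NumPy illustration of the row-wise L2 normalization used in `Retriever.__init__`, on made-up 2-D vectors rather than real embeddings: after dividing each row by its length, every row has norm 1 and dot products become cosine similarities.

```python
import numpy as np

# Three made-up "embeddings" (real embeddings have many more dimensions)
embeds_np = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]])

# Divide each row by its Euclidean (L2) norm, as in Retriever.__init__
unit = embeds_np / np.sqrt((embeds_np**2).sum(1, keepdims=True))
print(unit)
print(np.linalg.norm(unit, axis=1))  # every row now has length 1

# For unit vectors, the dot product equals the cosine similarity
print(unit @ unit[0])
```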
The main functionality of the class is provided by the `query` method. It takes a question as input, plus an optional parameter `k` that specifies how many of the most relevant documents to return (default 3).
The `query` method first embeds the input `question` and L2-normalizes it. It then stacks the normalized question embedding on top of the normalized document embeddings (`embeds_np`) into a single array (`x`). It also creates a label array (`y`) in which the first element, corresponding to the question vector, is labeled positive (1) and all the document vectors are labeled negative (0).
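Next, the code trains a linear Support Vector Machine (SVM) classifier on `x` and `y`. An SVM finds a hyperplane that separates the positive from the negative examples with the largest possible margin. Here the question is the only positive example and every document is a negative, so the SVM learns a hyperplane that separates the question from all the documents (an approach sometimes called an exemplar SVM). The option `class_weight='balanced'` compensates for this one-versus-many imbalance. Documents that resemble the question are hard to separate from it and therefore end up close to the question's side of the hyperplane.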
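The sketch below builds `x` and `y` as `query` does, using a made-up question vector and three made-up document vectors, and uses scikit-learn's `compute_class_weight` to show what `class_weight='balanced'` means for this data: each class is weighted by `n_samples / (n_classes * count)`, so the single positive example carries as much total weight as all the negatives together.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Made-up, already L2-normalized vectors: one question and three documents
query = np.array([1.0, 0.0])
doc_embeds = np.array([[0.6, 0.8], [0.0, 1.0], [-1.0, 0.0]])

x = np.concatenate([[query], doc_embeds])  # row 0: question, rows 1..N: documents
y = np.zeros(len(x))
y[0] = 1  # the question is the only positive example

# Weights that class_weight='balanced' would assign: n_samples / (n_classes * count)
weights = compute_class_weight("balanced", classes=np.array([0.0, 1.0]), y=y)
print(x.shape, y, weights)  # negatives get 4 / (2 * 3), the positive gets 4 / (2 * 1)
```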
After training, the code scores each document with the SVM's `decision_function` method, which returns a value proportional to the signed distance of a vector from the hyperplane: the higher the score, the closer the document lies to the question's side. The scores are sorted in descending order, the question itself is excluded, and the indices of the top `k` documents are kept.
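For a binary `LinearSVC`, `decision_function` is the affine function w·x + b, with w and b available as `coef_` and `intercept_`; dividing it by the norm of w would give the geometric distance to the hyperplane. The sketch checks this on random unit vectors (made-up data standing in for embeddings) and then ranks the document rows, skipping row 0, the question.

```python
import numpy as np
from sklearn import svm

rng = np.random.default_rng(42)
x = rng.normal(size=(6, 4))
x /= np.linalg.norm(x, axis=1, keepdims=True)  # L2-normalize the rows
y = np.zeros(len(x))
y[0] = 1  # row 0 plays the role of the question

clf = svm.LinearSVC(class_weight="balanced", max_iter=50000, tol=1e-5, C=1)
clf.fit(x, y)

# For a binary LinearSVC the score is the affine function w . x + b
scores = clf.decision_function(x)
manual = x @ clf.coef_.ravel() + clf.intercept_[0]
print(np.allclose(scores, manual))  # True

# Document scores are rows 1..5; argsort maps them to document indices 0..4
doc_order = np.argsort(-scores[1:])
print(doc_order)
```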
Finally, the method returns a list of the `k` most relevant documents, in the order given by the sorted indices.
The key data transformations are the L2 normalization of all embeddings and the stacking of the question embedding with the document embeddings into one training set. Normalization keeps the SVM from being biased by vector length, and stacking turns retrieval into a small binary classification problem. Note that a new SVM is trained for every query, which is affordable for a small collection of chunks like this one.
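The design choice here is to rank with an SVM instead of plain k-nearest-neighbour search by cosine similarity, which is what a vector database would typically do. The sketch below puts the two side by side on made-up data: 20 random "documents" in 16 dimensions and a query that is a slightly perturbed copy of document 3. `knn_rank` and `svm_rank` are hypothetical names; `svm_rank` follows the same recipe as `Retriever.query`. We make no claim that the two rankings agree in general; comparing them on real embeddings is one way to see what the SVM adds.

```python
import numpy as np
from sklearn import svm


def l2_normalize(a):
    # Scale vectors (rows, or a single vector) to unit length
    return a / np.linalg.norm(a, axis=-1, keepdims=True)


def knn_rank(query, doc_embeds, k):
    """Baseline: indices of the k documents with the highest cosine similarity."""
    scores = l2_normalize(doc_embeds) @ l2_normalize(query)
    return np.argsort(-scores)[:k]


def svm_rank(query, doc_embeds, k, C=1.0):
    """Exemplar SVM: the query is the only positive, every document a negative."""
    docs = l2_normalize(doc_embeds)
    x = np.concatenate([[l2_normalize(query)], docs])
    y = np.zeros(len(x))
    y[0] = 1
    clf = svm.LinearSVC(class_weight="balanced", max_iter=50000, tol=1e-5, C=C)
    clf.fit(x, y)
    scores = clf.decision_function(docs)  # score the documents only, never the query
    return np.argsort(-scores)[:k]


rng = np.random.default_rng(0)
doc_embeds = rng.normal(size=(20, 16))  # 20 made-up "documents"
query = doc_embeds[3] + 0.05 * rng.normal(size=16)  # a query close to document 3
print("kNN:", knn_rank(query, doc_embeds, k=3))
print("SVM:", svm_rank(query, doc_embeds, k=3))
```

Scoring only the document rows, rather than all of `x`, guarantees that the query can never appear among the results.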
import numpy as np
from sklearn import svm

class Retriever:
    def __init__(self, docs, embeddings):
        self.embeddings = embeddings
        self.docs = docs
        x = [doc_split.page_content for doc_split in docs]
        embeds = embeddings.embed_documents(x)
        embeds_np = np.array(embeds)
        embeds_np = embeds_np / np.sqrt((embeds_np**2).sum(1, keepdims=True))  # L2 normalize the rows
        self.embeds = embeds_np

    # This method is responsible for retrieving the top k most relevant
    # documents from a collection of documents based on a given query.
    def query(self, question, k=3):  # k is the number of top results to return
        query = np.array(self.embeddings.embed_query(question))
        query = query / np.sqrt((query**2).sum())  # L2 normalize the query
        x = np.concatenate([[query], self.embeds])
        y = np.zeros(len(x))  # initialize labels
        y[0] = 1  # set the first element (the query) as positive

        # SVM training
        # https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html
        clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=50000, tol=1e-5, C=1)
        clf.fit(x, y)  # train

        # score only the documents, so the query itself can never be returned
        similarities = clf.decision_function(self.embeds)
        # sort in descending order and keep the k highest-scoring documents
        sorted_ix = np.argsort(-similarities)[:k]
        return [self.docs[i] for i in sorted_ix]
retriever = Retriever(doc_splits, embeddings)
The test cells below retrieve relevant information from the knowledge base and generate an answer to a question with the RAG chain.
The input is a question string `q`; the output is the generated answer (for the first cell, a list of types of adversarial attacks).
Each cell works as follows. The question is assigned to `q`. `retriever.query(q)` returns the chunks most relevant to `q`, which are stored in `d`. `rag_chain.invoke` then receives the retrieved context and the original question `q`, fills them into the prompt, and has the LLM generate an answer grounded in that context. The answer is stored in `generation` and printed.
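Each test cell repeats the same three steps, so they can be wrapped in a small function. `ask` below is a hypothetical helper, and `StubRetriever` and `StubChain` are made-up stand-ins for `retriever` and `rag_chain` that let the sketch run without downloading pages, computing embeddings, or calling the Groq API. It joins the retrieved chunks into one string with `format_docs` before invoking the chain, so that the prompt receives plain text.

```python
from dataclasses import dataclass


@dataclass
class Doc:
    """Hypothetical stand-in for LangChain's Document."""
    page_content: str


def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


def ask(question, retriever, chain, k=3):
    """Retrieve the top-k chunks for `question`, then generate an answer."""
    docs = retriever.query(question, k=k)
    return chain.invoke({"context": format_docs(docs), "question": question})


class StubRetriever:
    """Stand-in for Retriever: returns fixed chunks whatever the question."""
    def query(self, question, k=3):
        return [Doc("chunk A"), Doc("chunk B"), Doc("chunk C")][:k]


class StubChain:
    """Stand-in for rag_chain: echoes its inputs instead of calling an LLM."""
    def invoke(self, inputs):
        return f"{inputs['question']}|{inputs['context']}"


print(ask("What is RAG?", StubRetriever(), StubChain(), k=2))
```

With the notebook's real objects, each test cell reduces to `print(ask(q, retriever, rag_chain))`.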
Test Cells
Try running the cells below to evaluate the model capabilities.
q = 'List the types of Adversarial Attacks'
d = retriever.query(q)
generation = rag_chain.invoke({"context": format_docs(d), "question": q})
print(generation)
q = 'What is agent memory?'
d = retriever.query(q)
generation = rag_chain.invoke({"context": format_docs(d), "question": q})
print(generation)
q = 'List all memory types'
d = retriever.query(q)
generation = rag_chain.invoke({"context": format_docs(d), "question": q})
print(generation)
q = 'Explain "Tree of Thoughts"'
d = retriever.query(q)
generation = rag_chain.invoke({"context": format_docs(d), "question": q})
print(generation)
q = 'What is RAG?'
d = retriever.query(q)
generation = rag_chain.invoke({"context": format_docs(d), "question": q})
print(generation)
Useful links
LangChain notebook (link)
Groq Cloud Playground (link)
Support Vector Classification (link)