
本文旨在解决chainlit应用中,用户会话(`cl.user_session`)对象存取不当导致的常见错误。通过对比`set()`和`get()`方法的正确用法,详细解释了如何在`@cl.on_chat_start`和`@cl.on_message`生命周期钩子中正确管理langchain链对象,避免`usersession.set()`参数缺失及langchain输入变量查找失败等问题,并提供了完整的代码示例和最佳实践。
在构建基于Chainlit和Langchain的交互式AI应用时,正确管理用户会话中的状态至关重要。特别是对于需要跨消息传递复杂对象(如Langchain链)的场景,理解cl.user_session的set()和get()方法是避免常见错误的关键。
Chainlit提供了一个cl.user_session对象,用于在单个用户会话期间存储和检索数据。这对于在不同消息处理函数之间共享状态非常有用,例如在聊天开始时初始化一个Langchain链,并在后续的用户消息中复用该链。
cl.user_session主要提供两个核心方法:
用户在使用Chainlit时,常会遇到以下两种与cl.user_session相关的错误:
为了避免上述错误,正确的模式是在@cl.on_chat_start生命周期钩子中初始化并设置(set)链对象,然后在@cl.on_message钩子中获取(get)该对象。
示例代码中的问题点:
在提供的代码中,@cl.on_message函数内存在以下错误:
@cl.on_message
async def main(message):
chain = cl.user_session.set("chain") # 错误:这里应该使用 get
# ... 后续代码这里的问题是,开发者意图是获取之前存储的chain对象,却错误地调用了cl.user_session.set("chain")。由于set()方法需要一个value参数,但这里没有提供,因此会导致UserSession.set() missing 1 required positional argument: 'value'错误。更重要的是,即使不报错,cl.user_session.set("chain")(在没有value的情况下)也不会返回之前存储的链对象,而是可能返回None或一个不符合预期的值,从而导致后续chain.acall()调用失败,引发Langchain相关的验证错误。
修正方案:
将@cl.on_message中的错误行修改为:
chain = cl.user_session.get("chain") # 正确:从会话中获取已存储的链对象这样,chain变量将正确地指向在@cl.on_chat_start中初始化并存储的RetrievalQA链实例。
以下是经过修正的完整代码,展示了如何正确地在Chainlit中管理Langchain链对象:
from langchain.prompts import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
import chainlit as cl
import os
# 确保DB_FAISS_PATH存在,这里假设它在当前脚本的同级目录
# 实际应用中,应确保此路径有效且包含预训练的FAISS索引
DB_FAISS_PATH = "vectorstores/db_faiss"
# 确保llama-2-7b-chat.ggmlv3.q8_0.bin模型文件可访问
# 通常放置在与脚本同级或指定路径下
LLM_MODEL_PATH = "llama-2-7b-chat.ggmlv3.q8_0.bin"
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, please just say that you don't know the answer, don't try to make up
an answer.
Context: {context}
Question: {question}
Only returns the helpful answer below and nothing else.
Helpful answer:
"""
def set_custom_prompt():
"""
为QA检索设置PromptTemplate。
"""
prompt = PromptTemplate(template=custom_prompt_template,
input_variables=['context', 'question'])
return prompt
def load_llm():
"""
加载CTransformers模型。
"""
print("*** Start loading the LLM.")
llm = CTransformers(
model=LLM_MODEL_PATH,
model_type="llama",
max_new_tokens=512,
temperature=0.5
)
print("****** Finished loading the LLM.")
return llm
def retrieval_qa_chain(llm, prompt, db):
"""
创建并返回RetrievalQA链。
"""
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=db.as_retriever(search_kwargs={'k': 2}),
return_source_documents=True,
chain_type_kwargs={'prompt': prompt}
)
return qa_chain
def qa_bot():
"""
初始化并返回一个完整的QA机器人链。
"""
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
model_kwargs={'device': 'cpu'})
# 检查FAISS向量库路径是否存在
if not os.path.exists(DB_FAISS_PATH):
raise FileNotFoundError(f"FAISS vector store not found at {DB_FAISS_PATH}. Please ensure it is created.")
db = FAISS.load_local(DB_FAISS_PATH, embeddings)
print("*******FAISS.load_local() works well.")
# 检查LLM模型文件是否存在
if not os.path.exists(LLM_MODEL_PATH):
raise FileNotFoundError(f"LLM model not found at {LLM_MODEL_PATH}. Please ensure the model file is present.")
llm = load_llm()
print("****** LLM loading step works well.")
qa_prompt = set_custom_prompt()
qa = retrieval_qa_chain(llm, qa_prompt, db)
print("****** QA chain creation step works well.")
return qa
# chainlit ####
@cl.on_chat_start
async def start():
"""
聊天开始时初始化机器人并将其存储在用户会话中。
"""
chain = qa_bot()
msg = cl.Message(content="Starting the bot......")
await msg.send()
msg.content = "Hi, Welcome to the Medical Bot. What is your query?"
await msg.update()
# 将初始化的链对象存储到用户会话中
cl.user_session.set('chain', chain)
@cl.on_message
async def main(message):
"""
处理用户消息,从会话中获取链对象并执行查询。
"""
# 从用户会话中获取之前存储的链对象
chain = cl.user_session.get("chain")
if chain is None:
await cl.Message(content="Error: The bot chain was not initialized correctly. Please restart the chat.").send()
return
cb = cl.AsyncLangchainCallbackHandler(
stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
)
cb.answer_reached = True # 确保回调在答案生成后触发
# 假设message对象有一个content属性,或者可以直接作为查询使用
res = await chain.acall(message.content, callbacks=[cb]) # 传递message.content作为查询
answer = res["result"]
sources = res.get("source_documents", []) # 使用.get()以防没有源文档
if sources:
# 将源文档格式化为更易读的字符串
source_texts = [doc.page_content for doc in sources]
answer += f"\n\n**Sources:**\n" + "\n".join(source_texts)
else:
answer += f"\n\n**No Sources Found**"
await cl.Message(content=answer).send()
正确地使用cl.user_session.set()和cl.user_session.get()是开发稳定、高效Chainlit应用的基础。通过在聊天开始时初始化并存储资源,并在后续交互中按需检索,可以有效避免资源重复加载和因对象状态管理不当引发的运行时错误,从而提供流畅的用户体验。遵循本文提供的修正方案和最佳实践,将有助于您构建更健壮的AI聊天机器人。
以上就是解决Chainlit中用户会话链对象的正确存取方法的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号