使用Amazon SageMaker JumpStart、Llama 2和Amazon OpenSearch Serverless with Vector Engine构建一个针对金融服务的上下文聊天机器人

打造金融服务上下文聊天机器人:利用Amazon SageMaker JumpStart、Llama 2和Amazon OpenSearch Serverless with Vector Engine实现

金融服务(FinServ)行业具有与特定领域数据、数据安全、监管控制和行业合规标准相关的独特生成人工智能需求。此外,客户正在寻求选择最高性能和成本效益的机器学习(ML)模型,并且具备执行必要定制(微调)以适应其业务用例的能力。Amazon SageMaker JumpStart非常适用于FinServ客户的生成AI用例,因为它提供了所需的数据安全控制,并满足合规标准要求。

在本文中,我们展示了在SageMaker JumpStart中使用基于检索增强生成(RAG)方法和大型语言模型(LLM)的问答任务,使用了一个简单的金融领域用例。RAG是一种改进文本生成质量的框架,通过将LLM与信息检索(IR)系统相结合。LLM生成的文本,IR系统从知识库中检索相关信息。然后使用检索到的信息来增强LLM的输入,这可以帮助提高模型生成的文本的准确性和相关性。RAG已被证明对于各种文本生成任务(如问答和摘要)非常有效。它是改善文本生成模型的质量和准确性的有希望的方法。

使用SageMaker JumpStart的优势

借助SageMaker JumpStart,ML从业者可以从广泛的最先进模型中选择,用于内容编写、图像生成、代码生成、问答、文案编写、摘要、分类、信息检索等用例。ML从业者可以将基础模型部署到从网络隔离环境中的专用Amazon SageMaker实例,并使用SageMaker进行模型训练和部署的定制。

SageMaker JumpStart非常适用于FinServ客户的生成AI用例,因为它提供以下特点:

  • 定制能力 – SageMaker JumpStart提供示例笔记本和详细帖子,以逐步指导基础模型的领域适应。您可以按照这些资源进行微调、领域适应和基础模型的指导,或者构建基于RAG的应用程序。
  • 数据安全 – 确保推理有效负载数据的安全至关重要。使用SageMaker JumpStart,您可以在网络隔离中部署模型,并通过私有模型中心功能管理对选定模型的访问控制,符合个人安全要求。
  • 监管控制和合规性 – 符合HIPAA BAA、SOC123、PCI和HITRUST CSF等标准的合规性是SageMaker的核心特性,确保与金融行业严格的监管环境保持一致。
  • 模型选择 – SageMaker JumpStart提供一系列在业界公认的HELM基准测试中始终排名前列的最先进ML模型。其中包括但不限于Llama 2、Falcon 40B、AI21 J2 Ultra、AI21 Summarize、Hugging Face MiniLM和BGE模型。

在本文中,我们将使用在SageMaker JumpStart中提供的Llama 2基础模型和Hugging Face GPTJ-6B-FP16嵌入模型,利用RAG架构构建金融服务组织的上下文聊天机器人。我们还使用Vector Engine作为存储嵌入的Amazon OpenSearch Serverless(目前处于预览版)数据存储。

大型语言模型的局限性

LLM是基于大量非结构化数据进行训练的,并在一般文本生成方面表现出色。通过这种训练,LLM获得并存储事实知识。然而,现成的LLM存在以下局限性:

  • 它们的离线训练使它们不了解最新的信息。
  • 它们在主要是广义数据上的训练降低了它们在特定领域任务中的效果。例如,金融公司可能更喜欢其Q&A机器人从最新的内部文件中获取答案,以确保准确性并符合其业务规则。
  • 它们对嵌入式信息的依赖损害了可解释性。

使用LLM中的特定数据有三种常见的方法:

  • 嵌入数据到模型提示中,使其在生成输出时利用此上下文。这可以是零样本(没有示例),少样本(有限示例)或多样本(丰富的示例)。这种上下文提示使模型产生更细致的结果。
  • 使用提示和完成的成对数据对模型进行微调。
  • RAG(检索增强生成),它检索外部数据(非参数式)并将此数据集成到提示中,丰富上下文。

然而,第一种方法在上下文大小方面存在模型限制的问题,使得输入冗长的文档变得困难,并可能增加成本。微调方法虽然有效,但资源密集,特别是在外部数据不断演化时,导致延迟部署和增加成本。RAG结合LLMs则提供了解决前述限制的解决方案。

检索增强生成

RAG检索外部数据(非参数式)并将此数据集成到ML提示中,丰富上下文。Lewis等人在2020年引入了RAG模型,将其概念化为预训练序列到序列模型(参数化内存)和维基百科的密集向量索引(非参数化内存)的融合,通过神经检索器访问。

RAG的操作如下:

  • 数据源 – RAG可以从各种数据源获取数据,包括文档存储库、数据库或API。
  • 数据格式化 – 用户的查询和文档都被转换为适合相关性比较的格式。
  • 嵌入 – 为了便于比较,使用语言模型将查询和文档集合(或知识库)转换为数值嵌入。这些嵌入将文本概念以数值方式表达。
  • 相关性搜索 – 用户查询的嵌入将与文档集合的嵌入进行比较,在嵌入空间中通过相似性搜索识别相关文本。
  • 上下文丰富 – 将识别出的相关文本附加到用户的原始提示中,从而增强其上下文。
  • LLM处理 – 使用丰富的上下文,将提示输入LLM,由于包含相关的外部数据,生成相关和精确的输出。
  • 异步更新 – 为了确保参考文档保持最新,可以异步更新它们以及其嵌入表示。这确保未来的模型响应基于最新的信息,保证准确性。

简而言之,RAG提供了一种动态方法,可以将实时的相关信息注入LLMs中,确保生成精确和及时的输出。

下图显示了使用RAG和LLMs的概念流程:

解决方案概述

为了为金融服务应用程序创建一个有上下文的问答聊天机器人,需要以下步骤:

  1. 使用SageMaker JumpStart GPT-J-6B嵌入模型为上传到亚马逊简单存储服务(Amazon S3)上传目录中的每个PDF文档生成嵌入。
  2. 使用以下步骤识别相关文档:
    • 使用相同模型为用户的查询生成嵌入。
    • 使用带有向量引擎功能的OpenSearch Serverless在嵌入空间中搜索前K个最相关的文档索引。
    • 使用识别出的索引检索相应的文档。
  3. 将检索的文档与用户的提示和问题作为上下文组合,将其发送到SageMaker LLM进行响应生成。

我们使用了LangChain,这是一个流行的框架,用于组织这个过程。LangChain专为由LLMs驱动的应用程序设计,为各种LLMs提供了通用接口。它简化了多个LLMs的集成,确保在调用之间无缝保持状态。此外,它通过提供可定制的提示模板、全面的应用程序构建代理和专门的搜索和检索索引等功能,提高了开发效率。有关深入理解,请参考LangChain文档

先决条件

建立我们的上下文感知聊天机器人需要以下先决条件:

有关如何设置OpenSearch无服务器向量引擎的说明,请参阅 介绍面向Amazon OpenSearch无服务器的向量引擎(预览版)

要获得以下解决方案的详细步骤,请克隆 GitHub存储库 并参阅 Jupyter笔记本。

使用SageMaker JumpStart部署ML模型

要部署ML模型,请完成以下步骤:

  1. 从SageMaker JumpStart部署Llama 2 LLM:

    from sagemaker.jumpstart.model import JumpStartModelllm_model = JumpStartModel(model_id = "meta-textgeneration-llama-2-7b-f")llm_predictor = llm_model.deploy()llm_endpoint_name = llm_predictor.endpoint_name
  2. 部署GPT-J嵌入模型:

    embeddings_model = JumpStartModel(model_id = "huggingface-textembedding-gpt-j-6b-fp16")embed_predictor = embeddings_model.deploy()embeddings_endpoint_name = embed_predictor.endpoint_name

分块数据并创建文档嵌入对象

在本节中,您将把数据分块成较小的文档。分块是一种将大型文本拆分成较小块的技术。这是一个关键步骤,因为它优化了RAG模型对搜索查询的相关性,从而提高了聊天机器人的质量。分块大小取决于诸如文档类型和所使用模型的因素。选择了一个块大小为1600,因为这是段落的大约大小。随着模型的改进,它们的上下文窗口大小将增加,从而可以容纳更大的块大小。

有关完整解决方案,请参阅GitHub存储库中的 Jupyter笔记本

  1. 扩展LangChain SageMakerEndpointEmbeddings类,创建一个使用之前创建的gpt-j-6b-fp16 SageMaker端点的自定义嵌入函数:

    from langchain.embeddings import SagemakerEndpointEmbeddingsfrom langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandlerlogger = logging.getLogger(__name__)# extend the SagemakerEndpointEmbeddings class from langchain to provide a custom embedding functionclass SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):    def embed_documents(        self, texts: List[str], chunk_size: int = 1    ) → List[List[float]]:        """Compute doc embeddings using a SageMaker Inference Endpoint.         Args:            texts: The list of texts to embed.            chunk_size: The chunk size defines how many input texts will                be grouped together as request. If None, will use the                chunk size specified by the class.        Returns:            List of embeddings, one for each text.        """        results = []        _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size        st = time.time()        for i in range(0, len(texts), _chunk_size):            response = self._embedding_func(texts[i : i + _chunk_size])            results.extend(response)        time_taken = time.time() - st        logger.info(            f"got results for {len(texts)} in {time_taken}s, length of embeddings list is {len(results)}"        )        print(            f"got results for {len(texts)} in {time_taken}s, length of embeddings list is {len(results)}"        )        return results# class for serializing/deserializing requests/responses to/from the embeddings modelclass ContentHandler(EmbeddingsContentHandler):    content_type = "application/json"    accepts = "application/json"     def transform_input(self, prompt: str, model_kwargs={}) → bytes:         input_str = json.dumps({"text_inputs": prompt, **model_kwargs})        return input_str.encode("utf-8")     def transform_output(self, output: bytes) → str:         response_json = json.loads(output.read().decode("utf-8"))        embeddings = response_json["embedding"]        if len(embeddings) == 1:            return [embeddings[0]]        return embeddingsdef create_sagemaker_embeddings_from_js_model(    embeddings_endpoint_name: str, aws_region: str) → SagemakerEndpointEmbeddingsJumpStart:     content_handler = ContentHandler()    embeddings = SagemakerEndpointEmbeddingsJumpStart(        endpoint_name=embeddings_endpoint_name,        region_name=aws_region,        content_handler=content_handler,    )    return embeddings
  2. 创建嵌入对象并批量创建文档嵌入:

    embeddings = create_sagemaker_embeddings_from_js_model(embeddings_endpoint_name, aws_region)
  3. 这些嵌入存储在向量引擎中,使用LangChain OpenSearchVectorSearch。您将在下一节中存储这些嵌入。将文档嵌入存储在OpenSearch无服务器中。现在您准备好遍历分块的文档,创建嵌入,并将这些嵌入存储在向量搜索集合中创建的OpenSearch无服务器向量索引中。请参见以下代码:

    docsearch = OpenSearchVectorSearch.from_texts(texts = [d.page_content for d in docs],embedding=embeddings,opensearch_url=[{'host': _aoss_host, 'port': 443}],http_auth=awsauth,timeout = 300,use_ssl = True,verify_certs = True,connection_class = RequestsHttpConnection,index_name=_aos_index)

关于文档的问题和回答

到目前为止,您已将大型文档分成了小片段,并创建了向量嵌入,并将它们存储在向量引擎中。现在,您可以根据这些文档数据回答问题。由于您已对数据进行了索引,因此可以进行语义搜索;这样,只有最相关的文档会通过提示传递给LLM。这样可以节省时间和金钱,只将相关的文档传递给LLM。有关使用文档链的更多详细信息,请参阅文档

完成以下步骤以使用文档回答问题:

  1. 使用LangChain的SageMaker LLM端点,您可以使用langchain.llms.sagemaker_endpoint.SagemakerEndpoint,该端点抽象了SageMaker LLM端点。根据所选择的LLM模型的content_type和accepts格式,您需要对ContentHandler中的代码进行调整。

    content_type = "application/json"accepts = "application/json"def transform_input(self, prompt: str, model_kwargs: dict) → bytes:        payload = {            "inputs": [                [                    {                        "role": "system",                        "content": prompt,                    },                    {"role": "user", "content": prompt},                ],            ],            "parameters": {                "max_new_tokens": 1000,                "top_p": 0.9,                "temperature": 0.6,            },        }        input_str = json.dumps(            payload,        )        return input_str.encode("utf-8")def transform_output(self, output: bytes) → str:    response_json = json.loads(output.read().decode("utf-8"))    content = response_json[0]["generation"]["content"]    return contentcontent_handler = ContentHandler()sm_jumpstart_llm=SagemakerEndpoint(        endpoint_name=llm_endpoint_name,        region_name=aws_region,        model_kwargs={"max_new_tokens": 300},        endpoint_kwargs={"CustomAttributes": "accept_eula=true"},        content_handler=content_handler,    )

现在,您已经准备好与财务文档互动了。

  1. 使用以下查询和提示模板来提问有关文档的问题:

    from langchain import PromptTemplate, SagemakerEndpointfrom langchain.llms.sagemaker_endpoint import LLMContentHandlerquery = "总结收益报告,报告是哪一年的"prompt_template = """只使用上下文来回答最后的问题。 {context} 问题:{question}回答:"""prompt = PromptTemplate(    template=prompt_template, input_variables=["context", "question"])  class ContentHandler(LLMContentHandler):    content_type = "application/json"    accepts = "application/json"    def transform_input(self, prompt: str, model_kwargs: dict) → bytes:        payload = {            "inputs": [                [                    {                        "role": "system",                        "content": prompt,                    },                    {"role": "user", "content": prompt},                ],            ],            "parameters": {                "max_new_tokens": 1000,                "top_p": 0.9,                "temperature": 0.6,            },        }        input_str = json.dumps(            payload,        )        return input_str.encode("utf-8")     def transform_output(self, output: bytes) → str:        response_json = json.loads(output.read().decode("utf-8"))        content = response_json[0]["generation"]["content"]        return contentcontent_handler = ContentHandler() chain = load_qa_chain(    llm=SagemakerEndpoint(        endpoint_name=llm_endpoint_name,        region_name=aws_region,        model_kwargs={"max_new_tokens": 300},        endpoint_kwargs={"CustomAttributes": "accept_eula=true"},        content_handler=content_handler,    ),    prompt=prompt,)sim_docs = docsearch.similarity_search(query, include_metadata=False)chain({"input_documents": sim_docs, "question": query}, return_only_outputs=True)

清理

为了避免未来产生费用,请删除在此笔记本中创建的SageMaker推断端点。您可以在SageMaker Studio笔记本中运行以下命令来删除:

# Delete LLMllm_predictor.delete_model()llm_predictor.delete_predictor(delete_endpoint_config=True)# Delete Embeddings Modelembed_predictor.delete_model()embed_predictor.delete_predictor(delete_endpoint_config=True)

如果您为此示例创建了一个OpenSearch Serverless集合,并且不再需要它,您可以通过OpenSearch Serverless控制台进行删除。

结论

在本文中,我们讨论了使用RAG作为一种方法,为LLMs提供领域特定上下文的方法。我们展示了如何使用SageMaker JumpStart基于RAG构建金融服务组织的上下文聊天机器人,使用Llama 2和OpenSearch Serverless作为矢量数据存储引擎。这种方法通过动态获取相关上下文来优化文本生成。我们很期待您在SageMaker JumpStart上使用这种基于RAG的策略,带入您的自定义数据并进行创新!