构建流式检索增强生成系统:融合GraphQL、SSE与Service Worker的实时前端交互架构


标准的检索增强生成(RAG)交互模式存在一个根本性的延迟问题。用户提交查询,后端执行一个包含向量检索、文档重排、上下文构建和LLM推理的复杂工作流,整个过程可能耗时5到30秒。在这期间,前端界面通常显示一个静态的加载动画,用户体验极差。在真实项目中,这种不可预测的、漫长的等待是不可接受的。我们需要将这个黑盒过程透明化,为用户提供实时的、流式的反馈。

我们的目标是构建一个系统,它能够在后端处理RAG流程的每一步时,都将中间状态和最终结果实时推送到前端。用户将能看到“正在检索”、“正在生成”,并逐字看到最终答案的呈现。这要求我们彻底抛弃传统的请求-响应模式。

初步构想与技术选型决策

要实现这种流式交互,技术栈的选择至关重要。这不是简单地选择最新潮的技术,而是要找到最适合解决当前问题的、在生产环境中稳定可靠的组合。

  1. 流式通信协议:Server-Sent Events (SSE) vs. WebSockets
    WebSockets 提供双向通信,但对于我们的场景——服务器向客户端单向推送更新——则显得过于重型。SSE 基于标准HTTP,更轻量,易于实现,并且在需要断线重连时有内建支持。在云服务商的无服务器函数(如AWS Lambda)环境中,管理长时间的WebSocket连接状态比处理短暂的HTTP请求要复杂得多,成本也更高。SSE 成为了更务实的选择。

  2. API 接口协议:GraphQL vs. REST
    发起一个流式任务,可以用一个简单的REST POST 请求。但GraphQL的强类型Schema能为我们的 startRAGStream 操作提供更清晰的契约。更重要的是,在复杂前端应用中,GraphQL已经成为数据聚合的主流方案。我们可以将这个流式任务的启动无缝地集成到现有的GraphQL API中,而不是引入一个独立的REST端点。我们将设计一个GraphQL Mutation,它不直接返回结果,而是返回一个用于建立SSE连接的唯一streamId。这是一种关注点分离的体现:GraphQL负责发起命令,SSE负责传递事件流。

  3. 计算核心:NumPy 与 云服务商的无服务器函数
    RAG的核心是向量计算。虽然生产环境最终会演进到专用的向量数据库,但在项目初期或中等规模数据量(百万级向量)下,使用NumPy在内存中进行暴力余弦相似度搜索,部署在AWS Lambda或Google Cloud Functions这类无服务器函数上,是性价比极高的方案。它避免了维护数据库集群的运维开销,并且能够根据请求量自动伸缩。我们将把预计算好的向量索引文件(例如.npy格式)存储在对象存储(如S3)中,函数实例在冷启动时加载到内存。

  4. 前端增强:Service Workers
    为了提升二次查询的性能和提供离线能力,Service Worker是理想的选择。它不仅仅是缓存静态资源。我们可以设计一个智能缓存策略,拦截对GraphQL端点的请求。对于相似的查询,Service Worker可以返回缓存的检索结果,甚至直接返回最终的生成答案,从而完全绕过后端处理流程。这能极大地改善用户体验并降低云服务成本。

架构概览

整个工作流程被设计为一个解耦的、事件驱动的系统。

sequenceDiagram
    participant Client as 客户端 (React + Service Worker)
    participant APIGW as API网关 (GraphQL)
    participant Initiator as 启动函数 (Lambda)
    participant StreamHandler as 流处理函数 (Lambda)
    participant S3 as 对象存储 (向量索引)
    participant LLM as 大语言模型服务

    Client->>+APIGW: GraphQL Mutation: startRAGStream(query: "...")
    APIGW->>+Initiator: 调用启动函数
    Initiator-->>-APIGW: { streamId: "unique-id-123" }
    APIGW-->>-Client: 返回 streamId
    Client->>+StreamHandler: 建立SSE连接 (GET /stream?id=unique-id-123)
    Note over StreamHandler: 开始处理RAG工作流
    StreamHandler->>S3: 加载NumPy向量索引
    StreamHandler-->>Client: event: status, data: "retrieving"
    Note over StreamHandler: 使用NumPy执行向量检索
    StreamHandler-->>Client: event: chunks, data: [{...}, {...}]
    StreamHandler-->>Client: event: status, data: "generating"
    StreamHandler->>+LLM: 发起流式推理请求
    LLM-->>StreamHandler: stream token 1
    StreamHandler-->>Client: data: token_1
    LLM-->>StreamHandler: stream token 2
    StreamHandler-->>Client: data: token_2
    LLM-->>-StreamHandler: ...
    StreamHandler-->>Client: ...
    StreamHandler-->>-Client: event: end, data: "done"

后端实现:Python on AWS Lambda

我们将创建两个核心的Lambda函数。一个是处理GraphQL突变的Initiator,另一个是处理SSE连接的StreamHandler

1. 项目结构与依赖

/rag-streaming-backend
├── src
│   ├── graphql_handler.py    # Initiator Lambda
│   ├── sse_handler.py        # StreamHandler Lambda
│   ├── vector_search.py      # NumPy向量检索逻辑
│   └── llm_streaming.py      # LLM流式调用封装
├── requirements.txt
└── template.yaml             # Serverless Application Model (SAM) 配置文件

requirements.txt

mangum
fastapi
uvicorn
python-dotenv
boto3
numpy
sse-starlette
graphql-core
strawberry-graphql[fastapi]
# 假设使用AWS Bedrock
boto3
botocore

2. Initiator Lambda: 启动任务

这个函数的核心职责是验证输入、生成一个唯一的任务ID,并将其与查询一起暂存(例如存入DynamoDB或Redis),然后返回ID给客户端。我们使用FastAPI和Strawberry来构建GraphQL端点,并用Mangum将其适配到Lambda。

src/graphql_handler.py

import uuid
import json
import logging
import boto3
import os
from contextlib import asynccontextmanager

import strawberry
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
from mangum import Mangum

# 日志配置
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# 资源初始化
# 在真实项目中,这些配置应通过环境变量管理
DYNAMODB_TABLE_NAME = os.environ.get("DYNAMODB_TABLE_NAME", "RAGStreamTasks")
dynamodb = None

def get_dynamodb_resource():
    """惰性初始化DynamoDB客户端,便于测试和冷启动优化"""
    global dynamodb
    if dynamodb is None:
        dynamodb = boto3.resource('dynamodb')
    return dynamodb

@strawberry.type
class StartStreamResponse:
    streamId: str
    message: str

@strawberry.type
class Query:
    @strawberry.field
    def health_check(self) -> str:
        return "GraphQL service is operational."

@strawberry.type
class Mutation:
    @strawberry.mutation
    def start_rag_stream(self, query: str) -> StartStreamResponse:
        """
        接收用户查询,创建一个流任务,并返回streamId
        """
        if not query or len(query.strip()) < 5:
            # 生产级代码必须有严格的输入验证
            raise ValueError("Query is too short or invalid.")

        stream_id = str(uuid.uuid4())
        logger.info(f"Generated streamId: {stream_id} for query: '{query}'")

        try:
            table = get_dynamodb_resource().Table(DYNAMODB_TABLE_NAME)
            # TTL设置为10分钟,自动清理过期的任务
            ttl = int(time.time()) + 600 
            
            table.put_item(
                Item={
                    'streamId': stream_id,
                    'query': query,
                    'status': 'PENDING',
                    'ttl': ttl
                }
            )
            
            return StartStreamResponse(
                streamId=stream_id,
                message="Stream initiated successfully."
            )
        except Exception as e:
            logger.error(f"Failed to store stream task in DynamoDB: {e}")
            # 这里的错误处理很关键,不能让内部错误泄露给客户端
            raise Exception("Internal server error while initiating stream.")


@asynccontextmanager
async def lifespan(app: FastAPI):
    # 应用启动时执行的逻辑,例如预热连接池
    get_dynamodb_resource()
    yield
    # 应用关闭时执行的逻辑
    pass

schema = strawberry.Schema(query=Query, mutation=Mutation)
graphql_app = GraphQLRouter(schema)

app = FastAPI(lifespan=lifespan)
app.include_router(graphql_app, prefix="/graphql")

handler = Mangum(app)
  • 关键点:
    • 使用DynamoDB并设置TTL来暂存任务状态,这是一个无服务器架构中常见的、可靠的状态传递方式。
    • mangum 适配器让我们可以用熟悉的FastAPI框架来编写Lambda,极大提升了开发效率。
    • 包含了基础的日志和错误处理,这是生产代码的最低要求。

3. StreamHandler Lambda: 处理SSE流

这是系统的核心。它通过API Gateway的HTTP API触发,接收streamId,然后执行整个RAG流程,并通过SSE持续推送结果。

src/sse_handler.py

import asyncio
import json
import logging
import os
import time

import boto3
from fastapi import FastAPI, Request
from sse_starlette.sse import EventSourceResponse
from mangum import Mangum

from vector_search import VectorSearch
from llm_streaming import get_llm_response_stream

# 日志配置
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# 全局变量以利用Lambda执行环境的复用
# 在第一次调用时加载,后续调用可直接使用
vector_search_instance = None

def initialize_vector_search():
    """
    冷启动时初始化向量搜索引擎。
    这是一个耗时操作,应放在处理器函数外部。
    """
    global vector_search_instance
    if vector_search_instance is None:
        logger.info("Initializing VectorSearch instance...")
        # 配置从环境变量读取
        bucket_name = os.environ.get("VECTOR_INDEX_BUCKET")
        index_key = os.environ.get("VECTOR_INDEX_KEY")
        metadata_key = os.environ.get("VECTOR_METADATA_KEY")
        
        if not all([bucket_name, index_key, metadata_key]):
            raise ValueError("Missing environment variables for vector search.")
            
        vector_search_instance = VectorSearch(bucket_name, index_key, metadata_key)
        logger.info("VectorSearch instance initialized.")

# 在模块加载时执行初始化
initialize_vector_search()

DYNAMODB_TABLE_NAME = os.environ.get("DYNAMODB_TABLE_NAME", "RAGStreamTasks")
dynamodb = boto3.resource('dynamodb')


async def rag_event_generator(request: Request):
    """
    SSE事件生成器,这是RAG工作流的核心实现
    """
    stream_id = request.query_params.get('id')
    if not stream_id:
        # 必须对输入进行校验
        yield {"event": "error", "data": "Missing streamId"}
        return

    try:
        table = dynamodb.Table(DYNAMODB_TABLE_NAME)
        response = table.get_item(Key={'streamId': stream_id})
        item = response.get('Item')
        if not item:
            yield {"event": "error", "data": "Invalid or expired streamId"}
            return
        query = item.get('query')
        logger.info(f"Starting RAG stream for id: {stream_id}, query: '{query}'")

        # 1. 状态更新:检索中
        yield {"event": "status", "data": "retrieving"}
        
        # 2. 执行向量检索
        start_time = time.perf_counter()
        retrieved_chunks = vector_search_instance.search(query, top_k=5)
        end_time = time.perf_counter()
        logger.info(f"Vector search completed in {end_time - start_time:.4f} seconds.")
        
        # 3. 推送检索到的块
        # 在真实项目中,内容可能很大,需要考虑分块或只发送摘要
        yield {"event": "chunks", "data": json.dumps([chunk.to_dict() for chunk in retrieved_chunks])}
        
        # 4. 状态更新:生成中
        yield {"event": "status", "data": "generating"}
        
        # 5. 构建上下文并调用LLM流式接口
        context = "\n".join([chunk.content for chunk in retrieved_chunks])
        prompt = f"Based on the following context:\n\n{context}\n\nAnswer the question: {query}"
        
        # 6. 流式推送LLM的token
        async for token in get_llm_response_stream(prompt):
            yield {"event": "message", "data": token}
            # 短暂休眠以避免网络拥塞,这在生产中可能需要更复杂的流控
            await asyncio.sleep(0.01)

    except Exception as e:
        logger.error(f"Error during RAG stream for id {stream_id}: {e}", exc_info=True)
        yield {"event": "error", "data": "An internal error occurred."}
    finally:
        # 7. 发送结束信号
        logger.info(f"RAG stream for id: {stream_id} finished.")
        yield {"event": "end", "data": "Stream completed"}


app = FastAPI()

@app.get("/stream")
async def stream_endpoint(request: Request):
    return EventSourceResponse(rag_event_generator(request))

handler = Mangum(app)

4. 向量检索模块 (vector_search.py)

这里展示了如何用NumPy实现一个简单的、可用于生产的向量搜索引擎。

import logging
import numpy as np
import boto3
from dataclasses import dataclass, asdict

logger = logging.getLogger()
logger.setLevel(logging.INFO)

@dataclass
class DocumentChunk:
    id: str
    content: str
    source: str
    
    def to_dict(self):
        return asdict(self)

class VectorSearch:
    def __init__(self, bucket: str, index_key: str, metadata_key: str):
        self.s3_client = boto3.client('s3')
        self.bucket = bucket
        
        # 加载向量索引和元数据是IO密集型和CPU密集型操作
        # 这是Lambda冷启动性能的主要瓶颈之一
        try:
            logger.info(f"Loading index from s3://{bucket}/{index_key}")
            index_obj = self.s3_client.get_object(Bucket=bucket, Key=index_key)
            self.vectors = np.load(index_obj['Body'])
            # L2归一化,为余弦相似度计算做准备
            self.vectors = self.vectors / np.linalg.norm(self.vectors, axis=1, keepdims=True)

            logger.info(f"Loading metadata from s3://{bucket}/{metadata_key}")
            metadata_obj = self.s3_client.get_object(Bucket=bucket, Key=metadata_key)
            self.metadata = json.load(metadata_obj['Body'])
            
            logger.info(f"Loaded {len(self.vectors)} vectors and {len(self.metadata)} metadata records.")
        except Exception as e:
            logger.error(f"Failed to load vector data from S3: {e}")
            # 失败后必须抛出异常,否则服务处于不可用状态
            raise

    def search(self, query_text: str, top_k: int = 5) -> list[DocumentChunk]:
        # 实际项目中,这里的 embedding model 也应被打包或作为层提供
        # 为简化,我们假设已经有一个 embedding 函数
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer('all-MiniLM-L6-v2') # 仅为示例
        
        query_vector = model.encode([query_text])[0]
        
        # 归一化查询向量
        query_vector = query_vector / np.linalg.norm(query_vector)

        # 计算余弦相似度 (等价于归一化向量的点积)
        scores = np.dot(self.vectors, query_vector)
        
        # 获取top_k的索引
        # argpartition比argsort更高效,因为它只保证第k个元素在正确位置
        top_k_indices = np.argpartition(scores, -top_k)[-top_k:]
        
        # 根据分数排序
        top_k_scores = scores[top_k_indices]
        sorted_indices = top_k_indices[np.argsort(top_k_scores)[::-1]]

        results = []
        for idx in sorted_indices:
            meta = self.metadata[idx]
            results.append(DocumentChunk(
                id=meta.get('id'),
                content=meta.get('content'),
                source=meta.get('source')
            ))
        return results

前端实现:React 与 Service Worker

前端的核心是消费SSE流并动态更新UI,同时利用Service Worker进行智能缓存。

1. UI 组件与SSE消费逻辑

// RAGStreamComponent.jsx
import React, { useState, useEffect, useCallback } from 'react';
import { useMutation, gql } from '@apollo/client';

const START_RAG_STREAM_MUTATION = gql`
  mutation StartRAGStream($query: String!) {
    startRagStream(query: $query) {
      streamId
      message
    }
  }
`;

export const RAGStreamComponent = () => {
  const [query, setQuery] = useState('');
  const [status, setStatus] = useState('idle');
  const [chunks, setChunks] = useState([]);
  const [llmResponse, setLlmResponse] = useState('');
  const [error, setError] = useState(null);

  const [startStream, { loading }] = useMutation(START_RAG_STREAM_MUTATION);

  const handleQuerySubmit = useCallback(async (e) => {
    e.preventDefault();
    if (!query.trim()) return;

    // 重置状态
    setStatus('initiating');
    setChunks([]);
    setLlmResponse('');
    setError(null);

    try {
      const { data } = await startStream({ variables: { query } });
      const { streamId } = data.startRagStream;
      
      if (!streamId) {
          throw new Error("Failed to get streamId from server.");
      }
      
      const eventSource = new EventSource(
        // 在生产环境中,这个URL应该是配置化的
        `/api/stream?id=${streamId}`
      );
      
      eventSource.addEventListener('status', (e) => {
        console.log('Status update:', e.data);
        setStatus(e.data);
      });

      eventSource.addEventListener('chunks', (e) => {
        const retrievedChunks = JSON.parse(e.data);
        console.log('Retrieved chunks:', retrievedChunks);
        setChunks(retrievedChunks);
      });

      // 默认的message事件用于接收LLM token
      eventSource.onmessage = (e) => {
        // 服务端可能会发送心跳包或其他非token信息,需要过滤
        if (e.data) {
             setLlmResponse(prev => prev + e.data);
        }
      };

      eventSource.addEventListener('end', () => {
        console.log('Stream ended.');
        setStatus('completed');
        eventSource.close();
      });

      eventSource.onerror = (err) => {
        console.error('EventSource failed:', err);
        setError('Connection to the stream failed.');
        setStatus('error');
        eventSource.close();
      };

    } catch (err) {
      console.error('Failed to start stream:', err);
      setError(err.message);
      setStatus('error');
    }
  }, [query, startStream]);

  return (
    <div>
      <form onSubmit={handleQuerySubmit}>
        <input 
          type="text" 
          value={query} 
          onChange={(e) => setQuery(e.target.value)}
          placeholder="Ask a question..."
          disabled={loading || (status !== 'idle' && status !== 'completed' && status !== 'error')}
        />
        <button type="submit" disabled={loading || (status !== 'idle' && status !== 'completed' && status !== 'error')}>
          {loading ? 'Initiating...' : 'Ask'}
        </button>
      </form>
      
      <div>
        <p><strong>Status:</strong> {status}</p>
        {chunks.length > 0 && (
          <div>
            <h4>Retrieved Documents:</h4>
            <ul>
              {chunks.map(chunk => <li key={chunk.id}>{chunk.source}</li>)}
            </ul>
          </div>
        )}
        {llmResponse && (
          <div>
            <h4>Answer:</h4>
            <p>{llmResponse}</p>
          </div>
        )}
        {error && <p style={{color: 'red'}}>Error: {error}</p>}
      </div>
    </div>
  );
};

2. Service Worker 缓存策略

这里的挑战在于,我们缓存的不是简单的GET请求,而是一个复杂操作的结果。我们将采用一种Stale-While-Revalidate的变种策略,拦截GraphQL请求。

public/sw.js

const CACHE_NAME = 'rag-cache-v1';
const GRAPHQL_ENDPOINT = '/graphql'; // 假设的GraphQL端点

self.addEventListener('install', event => {
  // 预缓存应用外壳
  event.waitUntil(
    caches.open(CACHE_NAME).then(cache => {
      // return cache.addAll(['/']);
    })
  );
});

self.addEventListener('fetch', event => {
  const { url, method } = event.request;

  if (url.endsWith(GRAPHQL_ENDPOINT) && method === 'POST') {
    event.respondWith(handleGraphQLRequest(event.request));
  } else {
    // 对其他请求使用标准网络优先策略
    event.respondWith(
      caches.match(event.request).then(response => {
        return response || fetch(event.request);
      })
    );
  }
});

async function handleGraphQLRequest(request) {
  const requestClone = request.clone();
  const body = await request.json();

  // 关键:只有特定类型的查询/突变才应该被缓存
  if (body.operationName === 'StartRAGStream') {
    const cacheKey = await createCacheKey(body);
    
    // 尝试从缓存中获取
    const cachedResponse = await caches.match(cacheKey);
    if (cachedResponse) {
      console.log('SW: Serving RAG stream from cache');
      // 注意:我们不能缓存SSE流本身,但可以缓存启动流的GraphQL响应
      // 一个更高级的策略是缓存最终的LLM答案,然后模拟一个流式响应
      // 这里为简化,我们只缓存启动流的响应,让客户端自己处理后续逻辑
      return cachedResponse;
    }
  }
  
  // 如果没有缓存,或者是不应缓存的操作,则走网络
  return fetch(requestClone);
}


async function createCacheKey(body) {
    // 基于操作名和查询变量创建一个稳定的key
    // 这是一个简化版本,生产中需要更鲁棒的哈希
    const query = body.variables.query;
    const keyString = `${body.operationName}-${query}`;
    
    // 使用 SubtleCrypto 来生成一个更可靠的哈希作为缓存键
    const encoder = new TextEncoder();
    const data = encoder.encode(keyString);
    const hashBuffer = await self.crypto.subtle.digest('SHA-256', data);
    const hashArray = Array.from(new Uint8Array(hashBuffer));
    const hashHex = hashArray.map(b => b.toString(16).padStart(2, '0')).join('');
    return new Request(`${GRAPHQL_ENDPOINT}?hash=${hashHex}`);
}
  • 单元测试思路:
    • 后端:
      • VectorSearch: 模拟S3的get_object,测试向量加载和归一化是否正确。提供一个固定的查询向量,断言返回的文档ID和顺序是否符合预期。
      • sse_handler: 使用HTTP测试客户端(如httpx)模拟SSE请求,传入有效的和无效的streamId,断言返回的事件流是否符合预期的顺序和内容。Mock掉VectorSearchLLM的调用。
      • graphql_handler: 发送模拟的GraphQL请求,测试start_rag_stream突变,断言DynamoDB的put_item是否被正确调用。
    • 前端:
      • 使用@testing-library/react来渲染组件,模拟用户输入和提交。Mock useMutation钩子和EventSource构造函数,断言UI状态(status, chunks, llmResponse)是否根据模拟的SSE事件正确更新。

局限性与未来迭代路径

这套架构解决了RAG交互的实时性问题,但在生产环境中仍有几个方面需要深化。

  1. 向量检索的瓶颈: NumPy暴力搜索在向量数量超过几百万时性能会急剧下降。下一步必然是迁移到专用的向量数据库(如Pinecone, Weaviate)或使用FAISS等库构建更高效的索引,并将其部署为独立的服务。
  2. SSE连接的健壮性: 当前实现没有处理网络中断后的流恢复。一个健壮的系统需要在客户端和服务器端实现状态同步,允许客户端从上一个收到的token或事件处恢复流,而不是从头开始。
  3. 成本考量: 长时间运行的StreamHandler Lambda可能比处理短请求的Lambda成本更高。对于非常耗时的LLM生成任务,可能需要考虑使用更适合长任务的计算服务,如AWS Fargate或EC2。同时需要对函数的内存和超时进行精细调优。
  4. Service Worker缓存的复杂性: 当前的缓存策略相对简单。一个更高级的系统可以实现语义缓存,理解查询的意图,为相似但不同措辞的查询返回相同缓存。此外,缓存失效策略也需要更精细的设计,以避免用户看到过时的信息。
  5. 安全性: SSE端点需要严格的认证和授权机制,确保只有发起任务的用户才能访问自己的数据流。这通常通过在启动任务时生成一个带签名的、有时效的token来实现,客户端在连接SSE时携带此token。

  目录