Google AI Edge RAG:高大上的 RAG Android 端上也能实现!

大家好,我是拭心。

随着大语言模型的兴起,RAG(Retrieval-Augmented Generation,检索增强生成)逐渐进入大家的视野,它可以解决 LLM 对特定业务领域知识不足的问题,生成更准确、更相关的回答。与微调相比,RAG 不需要对模型进行大量的重新训练,计算和存储需求较低,同时能够动态访问外部知识。一般用在企业知识库、智能客服等业务场景。

以前这种项目只能在服务端实现,离我们 Android 开发者很远,但如今,时代不一样了!

谷歌在 Google AI Edge 系列技术中,提供了 AI Edge RAG SDK,为端上实现检索增强生成 (RAG) 提供了强力的支持,我们 Android 也能做 RAG 了!

这篇文章我们来了解下这个 SDK。

一、RAG SDK API

1.1 RAG 的组成

在这里插入图片描述

如上图所示,一个 RAG 系统,主要包括这几步:

  1. 将知识库文本转换为嵌入向量
  2. 存储向量化的数据
  3. 将用户的输入转换为向量
  4. 从向量数据库检索与问题最相关的片段
  5. 将检索到的相关信息与用户问题结合后输入模型,生成最终答案

1.2 SDK API

AI Edge RAG SDK,提供了相关 API:

  1. 结构化和非结构化文本转换为嵌入向量:Embedder
public interface Embedder<T> {
    ListenableFuture<ImmutableList<Float>> getEmbeddings(EmbeddingRequest<T> request);

    ListenableFuture<ImmutableList<ImmutableList<Float>>> getBatchEmbeddings(EmbeddingRequest<T> request);
}
  1. 向量存储,存储从数据块派生的嵌入和元数据:VectorStore
public interface VectorStore<T> {
    void insert(VectorStoreRecord<T> record);

    List<VectorStoreRecord<T>> getNearestRecords(List<Float> queryEmbeddings, int topK, float minSimilarityScore);
}
  1. 语义检索器,用于在给定查询时检索相关数据块: SemanticMemory
public abstract class SemanticDataEntry<T> {
    public SemanticDataEntry() {
    }

    public abstract T getData();

    public abstract ImmutableMap<String, Object> getMetadata();

    public abstract Optional<T> getCustomEmbeddingData();

    public abstract Builder<T> toBuilder();

    public static <T> Builder<T> builder() {
        return (new AutoValue_SemanticDataEntry.Builder()).setMetadata(ImmutableMap.of());
    }

    public static <T> SemanticDataEntry<T> create(T data) {
        return create(data, ImmutableMap.of());
    }

    public static <T> SemanticDataEntry<T> create(T data, Map<String, Object> metadata) {
        return create(data, metadata, Optional.empty());
    }

    public static <T> SemanticDataEntry<T> create(T data, Map<String, Object> metadata, Optional<T> customEmbeddingData) {
        return builder().setData(data).setMetadata(metadata).setCustomEmbeddingData(customEmbeddingData).build();
    }
    ...
}
  1. 语言模型的封装接口,实现方式可以是本地或服务器模型: LanguageModel
public interface LanguageModel {
    ListenableFuture<LanguageModelResponse> generateResponse(LanguageModelRequest request, Executor executor);

    ListenableFuture<LanguageModelResponse> generateResponse(LanguageModelRequest request, Executor executor, @Nullable AsyncProgressListener<LanguageModelResponse> progressListener);
}

二、如何使用

除了 SDK,Google 还提供了一个 RAG 的 demo,我们通过这个 demo 来看看 API 如何使用。

2.1 demo 功能

demo 功能:通过 Edge RAG SDK 让 LLM 可以获取 assets/sample_context.txt 文件中的文本,从而能准确回答用户个人信息的问题。

assets/sample_context.txt 的内容:

<chunk_splitter>
My Name is Shixin Zhang
<chunk_splitter>
I am 32 years old
<chunk_splitter>
I am a software engineer
<chunk_splitter>
My personal website: https://gpt4oimageprompt.com
<chunk_splitter>
My Github link is https://github.com/shixinzhang
<chunk_splitter>
I like going to the movies
<chunk_splitter>
I don't like strawberries
<chunk_splitter>
I like apple
<chunk_splitter>
I am afraid of heights
<chunk_splitter>
I went skydiving once
<chunk_splitter>
Skydiving was not as scary as I thought
<chunk_splitter>
I like driving and learning about cars
<chunk_splitter>
I studied computer science in college
<chunk_splitter>
I like traveling
<chunk_splitter>
I have traveled to 18 different countries so far
<chunk_splitter>
I like rock climbing and I am now learning how to lead.
<chunk_splitter>
I once broke my ankle while climbing.
<chunk_splitter>
I like bouldering, and I'm scared of trad climbing.

可以看到,里面提供了 ShixinZhang(一位 32 岁程序员) 的一些个人信息。

demo 中使用到两个模型 Gemma3-1B 和 Gecko-110m-en,我们需要下载并 push 到手机上:

1.下载 Gemma3 并推送到 /data/local/tmp 下:

adb push /tmp/gemma3-1b-it-int4.task /data/local/tmp/gemma3-1b-it-int4.task

2.下载 Gecko-110m-en 和 sentencepiece.model (tokenizer) 并推送到 /data/local/tmp 下:

adb push sentencepiece.model /data/local/tmp/sentencepiece.model
adb push Gecko_256_quant.tflite /data/local/tmp/gecko.tflite

Gecko 是一种紧凑且通用的文本嵌入模型,其核心理念是将大型语言模型(LLM)中的知识提炼到检索器中,从而实现强大的检索性能.

sentencepiece.model 是 tokenizer。

在大语言模型中,tokenizer 负责把人类的自然语言(文字)转换成模型可以理解的「数字序列」。大模型本质上只理解数字向量。

模型训练时,tokenizer 决定了模型的词汇表(vocabulary)和输入粒度。因此训练好的大模型只适配特定的 tokenizer(因为模型学到的是 token 级别的模式,而不是原始文字)。

然后编译 app,启动后可以看到一个对话列表页面,稍等一段时间模型初始化(logcat 日志显示 Initialized)后即可输入内容。

MediaPipeLlmBackend     com.google.ai.edge.samples.rag       I  Initialized.

经过测试,输入 assets/sample_context.txt 中个人信息相关的问题,回答的准确率相当高👍:

在这里插入图片描述

不过,由于 demo 使用的嵌入模型 Gecko-110m-en 训练语料以英语为主,因此测试要用英文。

2.2 实现

了解了 demo 的功能后,接下来通过源码来了解下 RAG SDK 的 API。

2.2.1 RAG Pipeline 初始化

App 启动后,会初始化 RagPipeline,由它封装 RAG 相关功能:

class RagPipeline(private val application: Application) {
    // 配置语言模型
    private val mediaPipeLanguageModelOptions: LlmInferenceOptions =
        LlmInferenceOptions.builder()
            .setModelPath(GEMMA_MODEL_PATH)
            .setPreferredBackend(LlmInference.Backend.GPU)
            .setMaxTokens(1024)
            .build()

    // 配置嵌入模型
    private val embedder: Embedder<String> = if (COMPUTE_EMBEDDINGS_LOCALLY) {
        GeckoEmbeddingModel(
            GECKO_MODEL_PATH,
            Optional.of(TOKENIZER_MODEL_PATH),
            USE_GPU_FOR_EMBEDDINGS,
        )
    } else {
        GeminiEmbedder(
            GEMINI_EMBEDDING_MODEL,
            GEMINI_API_KEY
        )
    }

    // 创建 RAG 配置
    private val config = ChainConfig.create(
        MediaPipeLlmBackend(application.applicationContext, 
                           mediaPipeLanguageModelOptions,
                           mediaPipeLanguageModelSessionOptions),
        PromptBuilder(PROMPT_TEMPLATE),
        DefaultSemanticTextMemory(
            SqliteVectorStore(768), embedder
        )
    )

    // 初始化检索推理链
    private val retrievalAndInferenceChain = RetrievalAndInferenceChain(config)
}

上面的代码,主要做了这些事:

  1. 首先通过 mediaPipeLanguageModelOptions 配置了推理框架,使用的模型是 gemma3-1b-it-int4
  2. 然后初始化嵌入模型,使用的是 GeckoEmbeddingModel,它是 RAG SDK 提供的 Embedder 一个实现,提供了本地模型的嵌入提取功能
  3. 初始化MediaPipeLlmBackend,使用 MediaPipeLLM 实现推理
  4. 初始化SqliteVectorStore,使用 Sqlite 实现向量存储

2.2.2 知识库构建

启动后除了初始化这些配置,还会通过 ragPipeline.memorizeChunks 实现 assets/sample_context.txt 的文本分块和向量化:

fun memorizeChunks(context: Context, filename: String) {
    val reader = BufferedReader(InputStreamReader(context.assets.open(filename)))
    val texts = mutableListOf<String>()
    
    // 文本分块处理
    val sb = StringBuilder()
    generateSequence { reader.readLine() }
        .forEach { line ->
            if (line.startsWith(CHUNK_SEPARATOR)) {
                if (sb.isNotEmpty()) {
                    texts.add(sb.toString())
                }
                sb.clear()
                sb.append(line.removePrefix(CHUNK_SEPARATOR).trim())
            } else {
                sb.append(" ").append(line)
            }
        }
    
    if (sb.isNotEmpty()) {
        texts.add(sb.toString())
    }
    reader.close()
    
    // 存储到语义内存
    if (texts.isNotEmpty()) {
        memorize(texts)
    }
}

private fun memorize(facts: List<String>) {
    val future = config.semanticMemory.getOrNull()?.recordBatchedMemoryItems(
        ImmutableList.copyOf(facts)
    )
    future?.get()
}

recordBatchedMemoryItems 的作用:将文本数据批量向量化并存储。

步骤

  1. 将字符串列表转换为 SemanticDataEntry 列表(如果传入的是字符串列表)。
  2. 如果数据列表为空,直接返回 false
  3. 构建批量嵌入请求 embeddingRequest
  4. 使用嵌入模型批量获取嵌入向量。
  5. 遍历嵌入向量列表,构建 VectorStoreRecord 对象,并将其插入到向量存储中。
public ListenableFuture<Boolean> recordBatchedMemoryItems(ImmutableList<String> texts) {
//1. 将字符串列表转换为 `SemanticDataEntry` 列表(如果传入的是字符串列表)。
    return this.recordBatchedMemoryEntries((ImmutableList)texts.stream().map(SemanticDataEntry::create).collect(ImmutableList.toImmutableList()));
}

public ListenableFuture<Boolean> recordBatchedMemoryEntries(ImmutableList<SemanticDataEntry<String>> dataEntries) {
    if (dataEntries.isEmpty()) {
        return Futures.immediateFuture(false);
    } else {
        //3.构建批量嵌入请求 `embeddingRequest`。
        ImmutableList<EmbedData<String>> entries = (ImmutableList)dataEntries.stream().map((dataEntry) -> EmbedData.builder().setData((String)dataEntry.getCustomEmbeddingData().orElse((String)dataEntry.getData())).setTask(TaskType.RETRIEVAL_DOCUMENT).build()).collect(ImmutableList.toImmutableList());
        EmbeddingRequest<String> request = EmbeddingRequest.create(entries);
        
        //4.使用嵌入模型批量获取嵌入向量
        return Futures.transform(this.embeddingModel.getBatchEmbeddings(request), (embeddingsList) -> {
            if (embeddingsList.size() != dataEntries.size()) {
                throw new AssertionError(String.format("Embeddings list size is not equal to memory entries size, %d != %d", embeddingsList.size(), dataEntries.size()));
            } else {
            
                //5.遍历嵌入向量列表,构建 `VectorStoreRecord` 对象,并将其插入到向量存储中。
                for(int i = 0; i < embeddingsList.size(); ++i) {
                    VectorStoreRecord<String> record = VectorStoreRecord.builder().setData((String)((SemanticDataEntry)dataEntries.get(i)).getData()).setEmbeddings((ImmutableList)embeddingsList.get(i)).setMetadata(((SemanticDataEntry)dataEntries.get(i)).getMetadata()).build();
                    this.vectorStore.insert(record);
                }

                return true;
            }
        }, this.workerExecutor);
    }
}

2.2.3. 查询处理

suspend fun generateResponse(
    prompt: String,
    callback: AsyncProgressListener<LanguageModelResponse>?
): String = coroutineScope {
    // 创建检索请求
    val retrievalRequest = RetrievalRequest.create(
        prompt,
        RetrievalConfig.create(3, 0.0f, TaskType.QUESTION_ANSWERING)
    )
    
    // 执行检索增强生成
    retrievalAndInferenceChain.invoke(retrievalRequest, callback).await().text
}

用户输入后,最终由RetrievalAndInferenceChain 完成层层调用并生成结果:


/**
 * 结合检索和推理步骤的链。
 * 这个类实现了 Chain 接口,处理 RetrievalRequest 并生成 LanguageModelResponse。
 */
public final class RetrievalAndInferenceChain implements Chain<RetrievalRequest<String>, LanguageModelResponse> {
    /**
     * 链的配置,包括语义记忆、提示构建器和语言模型。
     */
    private final ChainConfig<String> config;
    
    /**
     * 用于运行异步任务的执行器。
     */
    private final Executor workerExecutor;

    /**
     * RetrievalAndInferenceChain 的构造函数。
     * @param config 链的配置。
     */
    public RetrievalAndInferenceChain(ChainConfig<String> config) {
        this.config = config;
        this.workerExecutor = Executors.newSingleThreadExecutor(); // 创建一个单线程的执行器用于异步任务
    }

    /**
     * 使用 RetrievalRequest 调用链。
     * @param retrievalRequest 要处理的检索请求。
     * @return 表示异步操作的 ListenableFuture。
     */
    public ListenableFuture<LanguageModelResponse> invoke(RetrievalRequest<String> retrievalRequest) {
        return this.invoke(retrievalRequest, (AsyncProgressListener)null); // 调用重载方法,传入 null 作为进度监听器
    }

    /**
     * 使用 RetrievalRequest 和可选的 AsyncProgressListener 调用链。
     * @param retrievalRequest 要处理的检索请求。
     * @param asyncProgressListener 可选的进度监听器。
     * @return 表示异步操作的 ListenableFuture。
     */
    public ListenableFuture<LanguageModelResponse> invoke(RetrievalRequest<String> retrievalRequest, @Nullable AsyncProgressListener<LanguageModelResponse> asyncProgressListener) {
        // 从配置中获取语义记忆
        SemanticMemory<String> memory = (SemanticMemory)this.config.getSemanticMemory().get();
        // 确保语义记忆不为 null
        Preconditions.checkNotNull(memory, "语义文本记忆不能为空");
        
        // 从语义记忆中检索结果
        ListenableFuture<RetrievalResponse<String>> responseFuture = memory.retrieveResults(retrievalRequest);
        
        // 将检索响应转换为语言模型响应
        return Futures.transformAsync(responseFuture, (response) -> {
            // 从检索到的实体中构建字符串
            StringBuilder memoryStringBuilder = new StringBuilder();
            response.getEntities().forEach((entity) -> memoryStringBuilder.append((String)entity.getData()).append("\n"));
            String memoryString = memoryStringBuilder.toString();
            
            // 使用提示构建器构建提示
            String prompt = ((PromptBuilder)this.config.getPromptBuilder().get()).buildPrompt(new Object[]{memoryString, retrievalRequest.getQuery()});
            
            // 创建带有构建提示的语言模型请求
            LanguageModelRequest languageModelRequest = LanguageModelRequest.builder().setPrompt(prompt).build();
            
            // 从语言模型生成响应
            return ((LanguageModel)this.config.getLanguageModel().get()).generateResponse(languageModelRequest, this.workerExecutor, asyncProgressListener);
        }, this.workerExecutor); // 使用工作执行器进行异步转换
    }
}

上面的代码主要做了这些事:

  1. 从配置中获取 SemanticMemory,并确保其不为 null
  2. 使用 SemanticMemoryretrieveResults 方法检索与请求相关的数据
  3. 将检索到的实体数据拼接成一个字符串。
  4. 使用 PromptBuilder 构建提示字符串,结合检索到的数据和原始查询。
  5. 创建 LanguageModelRequest,并调用语言模型的 generateResponse 方法生成最终响应。

最核心的就是 SemanticMemory#retrieveResults

/**
 * 根据检索请求从语义记忆中检索结果。
 * 
 * @param request 检索请求,包含查询字符串和检索配置。
 * @return 一个 ListenableFuture,表示异步检索操作的结果。
 */
public ListenableFuture<RetrievalResponse<String>> retrieveResults(RetrievalRequest<String> request) {
    // 创建 EmbedData.Builder 对象,设置查询数据
    EmbedData.Builder<String> embedDataBuilder = EmbedData.builder().setData((String)request.getQuery());
    
    // 根据请求的任务类型设置 EmbedData 的任务类型
    switch (request.getConfig().getTask()) {
        case TASK_UNSPECIFIED:
        case RETRIEVAL_QUERY:
            // 如果任务类型未指定或为检索查询,则设置为 RETRIEVAL_QUERY
            embedDataBuilder.setTask(TaskType.RETRIEVAL_QUERY);
            break;
        case QUESTION_ANSWERING:
            // 如果任务类型为问答,则设置为 QUESTION_ANSWERING
            embedDataBuilder.setTask(TaskType.QUESTION_ANSWERING);
    }

    // 创建嵌入请求,包含构建好的 EmbedData 对象
    EmbeddingRequest<String> embeddingRequest = EmbeddingRequest.create(ImmutableList.of(embedDataBuilder.build()));
    
    // 使用嵌入模型获取查询的嵌入向量,并异步转换为检索响应
    return Futures.transform(this.embeddingModel.getEmbeddings(embeddingRequest), (embeddings) -> {
        // 从向量存储中获取与查询嵌入向量最相似的记录
        List<VectorStoreRecord<String>> records = this.vectorStore.getNearestRecords(embeddings, request.getConfig().getTopK(), request.getConfig().getMinSimilarityScore());
        
        // 将记录转换为 RetrievalEntity 对象
        ImmutableList<RetrievalEntity<String>> entities = (ImmutableList)records.stream().map((record) -> 
            RetrievalEntity.builder()
                .setData((String)record.getData()) // 设置数据
                .setEmbeddings(record.getEmbeddings()) // 设置嵌入向量
                .setMetadata(record.getMetadata()) // 设置元数据
                .build()
        ).collect(ImmutableList.toImmutableList());
        
        // 创建并返回检索响应
        return RetrievalResponse.create(entities);
    }, this.workerExecutor); // 使用工作执行器进行异步转换
}

可以看到,和 2.2.2 知识库构建时一样,查询的时候也是通过this.embeddingModel.getEmbeddings(embeddingRequest) 获取文本对应的嵌入向量,拿到向量后,再去向量存储里查询最相似的记录。

总结

OK,以上就是 Google AI Edge RAG SDK 的相关介绍,通过官方的 demo,我们可以了解到 RAG SDK 的 API 及使用方式,核心就是:

  1. 通过嵌入模型实现文本到向量
  2. 使用向量数据库存储向量、文本和元数据
  3. 通过 MediaPipeLLM 实现推理

由于 AI Edge RAG 这个库设计很灵活,核心 API 都支持自定义,我们可以灵活地更改 RAG 流程的任何部分,从而支持自定义数据库、分块方法及检索函数等。

### NebulaGraph RAG介绍 NebulaGraph实现了基于知识图谱的检索增强生成(RAG),这有助于利用领域特定的知识来提高模型输出的质量和准确性[^1]。具体来说,当涉及到复杂查询或是需要高度专业知识的任务时,通过集成知识图谱中的结构化信息,能够有效减少AI系统的幻觉现象。 #### 基于知识图谱的RAG工作原理 在实现上,NebulaGraph发布了业界首个此类解决方案——即基于知识图谱的RAG功能。这一特性允许用户不仅限于传统的文本匹配方式来进行数据检索,而是可以通过关联实体之间的关系网络获取更精准的结果集[^2]。 对于开发者而言,这意味着可以在应用程序中轻松嵌入强大的语义理解和推理能力;而对于最终使用者,则可以获得更加贴合实际需求的信息推荐服务。 #### 使用案例分析 为了更好地理解如何应用这些技术,在官方文档中有两个具体的演示可供参考: - **Graph RAG vs Vector RAG**: 展示了两种不同类型的向量索引方法之间性能差异的同时,也突出了前者在处理具有明确模式的数据上的优势; - **Text2Cypher Visual Comparison**: 提供了一个直观的方式让用户看到自然语言查询是如何被转换成针对图数据库的有效命令序列,并执行相应操作的过程。 ```python from nebula3.gclient.net import ConnectionPool from nebula3.common.ttypes import Vertex, Edge # 连接到NebulaGraph实例并创建会话 connection_pool = ConnectionPool() session = connection_pool.get_session('root', 'password') # 执行Cypher风格的查询语句以展示RAG效果 result_set = session.execute( "MATCH (n)-[r]->(m) WHERE n.name='example' RETURN r LIMIT 10" ) for row in result_set.rows(): edge = row.columns()[0].get_edge() print(f"{edge.src_vertex_id} -> {edge.dst_vertex_id}") ``` 此代码片段展示了连接到NebulaGraph服务器并通过Cypher样式的查询语法进行简单的关系查找过程。虽然这段代码本身并不直接涉及RAG机制的具体细节,但它体现了构建在此基础上的应用程序可能采取的操作形式之一。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

拭心

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值