大家好,我是拭心。
随着大语言模型的兴起,RAG(Retrieval-Augmented Generation,检索增强生成)逐渐进入大家的视野,它可以解决 LLM 对特定业务领域知识不足的问题,生成更准确、更相关的回答。与微调相比,RAG 不需要对模型进行大量的重新训练,计算和存储需求较低,同时能够动态访问外部知识。一般用在企业知识库、智能客服等业务场景。
以前这种项目只能在服务端实现,离我们 Android 开发者很远,但如今,时代不一样了!
谷歌在 Google AI Edge 系列技术中,提供了 AI Edge RAG SDK,为端上实现检索增强生成 (RAG) 提供了强力的支持,我们 Android 也能做 RAG 了!
这篇文章我们来了解下这个 SDK。
一、RAG SDK API
1.1 RAG 的组成
如上图所示,一个 RAG 系统,主要包括这几步:
- 将知识库文本转换为嵌入向量
- 存储向量化的数据
- 将用户的输入转换为向量
- 从向量数据库检索与问题最相关的片段
- 将检索到的相关信息与用户问题结合后输入模型,生成最终答案
1.2 SDK API
AI Edge RAG SDK,提供了相关 API:
- 结构化和非结构化文本转换为嵌入向量:Embedder
public interface Embedder<T> {
ListenableFuture<ImmutableList<Float>> getEmbeddings(EmbeddingRequest<T> request);
ListenableFuture<ImmutableList<ImmutableList<Float>>> getBatchEmbeddings(EmbeddingRequest<T> request);
}
- 向量存储,存储从数据块派生的嵌入和元数据:VectorStore
public interface VectorStore<T> {
void insert(VectorStoreRecord<T> record);
List<VectorStoreRecord<T>> getNearestRecords(List<Float> queryEmbeddings, int topK, float minSimilarityScore);
}
- 语义检索器,用于在给定查询时检索相关数据块: SemanticMemory
public abstract class SemanticDataEntry<T> {
public SemanticDataEntry() {
}
public abstract T getData();
public abstract ImmutableMap<String, Object> getMetadata();
public abstract Optional<T> getCustomEmbeddingData();
public abstract Builder<T> toBuilder();
public static <T> Builder<T> builder() {
return (new AutoValue_SemanticDataEntry.Builder()).setMetadata(ImmutableMap.of());
}
public static <T> SemanticDataEntry<T> create(T data) {
return create(data, ImmutableMap.of());
}
public static <T> SemanticDataEntry<T> create(T data, Map<String, Object> metadata) {
return create(data, metadata, Optional.empty());
}
public static <T> SemanticDataEntry<T> create(T data, Map<String, Object> metadata, Optional<T> customEmbeddingData) {
return builder().setData(data).setMetadata(metadata).setCustomEmbeddingData(customEmbeddingData).build();
}
...
}
- 语言模型的封装接口,实现方式可以是本地或服务器模型: LanguageModel
public interface LanguageModel {
ListenableFuture<LanguageModelResponse> generateResponse(LanguageModelRequest request, Executor executor);
ListenableFuture<LanguageModelResponse> generateResponse(LanguageModelRequest request, Executor executor, @Nullable AsyncProgressListener<LanguageModelResponse> progressListener);
}
二、如何使用
除了 SDK,Google 还提供了一个 RAG 的 demo,我们通过这个 demo 来看看 API 如何使用。
2.1 demo 功能
demo 功能:通过 Edge RAG SDK 让 LLM 可以获取 assets/sample_context.txt 文件中的文本,从而能准确回答用户个人信息的问题。
assets/sample_context.txt
的内容:
<chunk_splitter>
My Name is Shixin Zhang
<chunk_splitter>
I am 32 years old
<chunk_splitter>
I am a software engineer
<chunk_splitter>
My personal website: https://gpt4oimageprompt.com
<chunk_splitter>
My Github link is https://github.com/shixinzhang
<chunk_splitter>
I like going to the movies
<chunk_splitter>
I don't like strawberries
<chunk_splitter>
I like apple
<chunk_splitter>
I am afraid of heights
<chunk_splitter>
I went skydiving once
<chunk_splitter>
Skydiving was not as scary as I thought
<chunk_splitter>
I like driving and learning about cars
<chunk_splitter>
I studied computer science in college
<chunk_splitter>
I like traveling
<chunk_splitter>
I have traveled to 18 different countries so far
<chunk_splitter>
I like rock climbing and I am now learning how to lead.
<chunk_splitter>
I once broke my ankle while climbing.
<chunk_splitter>
I like bouldering, and I'm scared of trad climbing.
可以看到,里面提供了 ShixinZhang(一位 32 岁程序员) 的一些个人信息。
demo 中使用到两个模型 Gemma3-1B 和 Gecko-110m-en,我们需要下载并 push 到手机上:
1.下载 Gemma3 并推送到 /data/local/tmp 下:
adb push /tmp/gemma3-1b-it-int4.task /data/local/tmp/gemma3-1b-it-int4.task
2.下载 Gecko-110m-en 和 sentencepiece.model (tokenizer) 并推送到 /data/local/tmp 下:
adb push sentencepiece.model /data/local/tmp/sentencepiece.model
adb push Gecko_256_quant.tflite /data/local/tmp/gecko.tflite
Gecko 是一种紧凑且通用的文本嵌入模型,其核心理念是将大型语言模型(LLM)中的知识提炼到检索器中,从而实现强大的检索性能.
sentencepiece.model 是 tokenizer。
在大语言模型中,tokenizer 负责把人类的自然语言(文字)转换成模型可以理解的「数字序列」。大模型本质上只理解数字向量。
模型训练时,tokenizer 决定了模型的词汇表(vocabulary)和输入粒度。因此训练好的大模型只适配特定的 tokenizer(因为模型学到的是 token 级别的模式,而不是原始文字)。
然后编译 app,启动后可以看到一个对话列表页面,稍等一段时间模型初始化(logcat 日志显示 Initialized)后即可输入内容。
MediaPipeLlmBackend com.google.ai.edge.samples.rag I Initialized.
经过测试,输入 assets/sample_context.txt
中个人信息相关的问题,回答的准确率相当高👍:
不过,由于 demo 使用的嵌入模型 Gecko-110m-en 训练语料以英语为主,因此测试要用英文。
2.2 实现
了解了 demo 的功能后,接下来通过源码来了解下 RAG SDK 的 API。
2.2.1 RAG Pipeline 初始化
App 启动后,会初始化 RagPipeline,由它封装 RAG 相关功能:
class RagPipeline(private val application: Application) {
// 配置语言模型
private val mediaPipeLanguageModelOptions: LlmInferenceOptions =
LlmInferenceOptions.builder()
.setModelPath(GEMMA_MODEL_PATH)
.setPreferredBackend(LlmInference.Backend.GPU)
.setMaxTokens(1024)
.build()
// 配置嵌入模型
private val embedder: Embedder<String> = if (COMPUTE_EMBEDDINGS_LOCALLY) {
GeckoEmbeddingModel(
GECKO_MODEL_PATH,
Optional.of(TOKENIZER_MODEL_PATH),
USE_GPU_FOR_EMBEDDINGS,
)
} else {
GeminiEmbedder(
GEMINI_EMBEDDING_MODEL,
GEMINI_API_KEY
)
}
// 创建 RAG 配置
private val config = ChainConfig.create(
MediaPipeLlmBackend(application.applicationContext,
mediaPipeLanguageModelOptions,
mediaPipeLanguageModelSessionOptions),
PromptBuilder(PROMPT_TEMPLATE),
DefaultSemanticTextMemory(
SqliteVectorStore(768), embedder
)
)
// 初始化检索推理链
private val retrievalAndInferenceChain = RetrievalAndInferenceChain(config)
}
上面的代码,主要做了这些事:
- 首先通过 mediaPipeLanguageModelOptions 配置了推理框架,使用的模型是
gemma3-1b-it-int4
- 然后初始化嵌入模型,使用的是
GeckoEmbeddingModel
,它是 RAG SDK 提供的Embedder
一个实现,提供了本地模型的嵌入提取功能 - 初始化
MediaPipeLlmBackend
,使用 MediaPipeLLM 实现推理 - 初始化
SqliteVectorStore
,使用 Sqlite 实现向量存储
2.2.2 知识库构建
启动后除了初始化这些配置,还会通过 ragPipeline.memorizeChunks
实现 assets/sample_context.txt
的文本分块和向量化:
fun memorizeChunks(context: Context, filename: String) {
val reader = BufferedReader(InputStreamReader(context.assets.open(filename)))
val texts = mutableListOf<String>()
// 文本分块处理
val sb = StringBuilder()
generateSequence { reader.readLine() }
.forEach { line ->
if (line.startsWith(CHUNK_SEPARATOR)) {
if (sb.isNotEmpty()) {
texts.add(sb.toString())
}
sb.clear()
sb.append(line.removePrefix(CHUNK_SEPARATOR).trim())
} else {
sb.append(" ").append(line)
}
}
if (sb.isNotEmpty()) {
texts.add(sb.toString())
}
reader.close()
// 存储到语义内存
if (texts.isNotEmpty()) {
memorize(texts)
}
}
private fun memorize(facts: List<String>) {
val future = config.semanticMemory.getOrNull()?.recordBatchedMemoryItems(
ImmutableList.copyOf(facts)
)
future?.get()
}
recordBatchedMemoryItems
的作用:将文本数据批量向量化并存储。
步骤:
- 将字符串列表转换为
SemanticDataEntry
列表(如果传入的是字符串列表)。 - 如果数据列表为空,直接返回
false
。 - 构建批量嵌入请求
embeddingRequest
。 - 使用嵌入模型批量获取嵌入向量。
- 遍历嵌入向量列表,构建
VectorStoreRecord
对象,并将其插入到向量存储中。
public ListenableFuture<Boolean> recordBatchedMemoryItems(ImmutableList<String> texts) {
//1. 将字符串列表转换为 `SemanticDataEntry` 列表(如果传入的是字符串列表)。
return this.recordBatchedMemoryEntries((ImmutableList)texts.stream().map(SemanticDataEntry::create).collect(ImmutableList.toImmutableList()));
}
public ListenableFuture<Boolean> recordBatchedMemoryEntries(ImmutableList<SemanticDataEntry<String>> dataEntries) {
if (dataEntries.isEmpty()) {
return Futures.immediateFuture(false);
} else {
//3.构建批量嵌入请求 `embeddingRequest`。
ImmutableList<EmbedData<String>> entries = (ImmutableList)dataEntries.stream().map((dataEntry) -> EmbedData.builder().setData((String)dataEntry.getCustomEmbeddingData().orElse((String)dataEntry.getData())).setTask(TaskType.RETRIEVAL_DOCUMENT).build()).collect(ImmutableList.toImmutableList());
EmbeddingRequest<String> request = EmbeddingRequest.create(entries);
//4.使用嵌入模型批量获取嵌入向量
return Futures.transform(this.embeddingModel.getBatchEmbeddings(request), (embeddingsList) -> {
if (embeddingsList.size() != dataEntries.size()) {
throw new AssertionError(String.format("Embeddings list size is not equal to memory entries size, %d != %d", embeddingsList.size(), dataEntries.size()));
} else {
//5.遍历嵌入向量列表,构建 `VectorStoreRecord` 对象,并将其插入到向量存储中。
for(int i = 0; i < embeddingsList.size(); ++i) {
VectorStoreRecord<String> record = VectorStoreRecord.builder().setData((String)((SemanticDataEntry)dataEntries.get(i)).getData()).setEmbeddings((ImmutableList)embeddingsList.get(i)).setMetadata(((SemanticDataEntry)dataEntries.get(i)).getMetadata()).build();
this.vectorStore.insert(record);
}
return true;
}
}, this.workerExecutor);
}
}
2.2.3. 查询处理
suspend fun generateResponse(
prompt: String,
callback: AsyncProgressListener<LanguageModelResponse>?
): String = coroutineScope {
// 创建检索请求
val retrievalRequest = RetrievalRequest.create(
prompt,
RetrievalConfig.create(3, 0.0f, TaskType.QUESTION_ANSWERING)
)
// 执行检索增强生成
retrievalAndInferenceChain.invoke(retrievalRequest, callback).await().text
}
用户输入后,最终由RetrievalAndInferenceChain
完成层层调用并生成结果:
/**
* 结合检索和推理步骤的链。
* 这个类实现了 Chain 接口,处理 RetrievalRequest 并生成 LanguageModelResponse。
*/
public final class RetrievalAndInferenceChain implements Chain<RetrievalRequest<String>, LanguageModelResponse> {
/**
* 链的配置,包括语义记忆、提示构建器和语言模型。
*/
private final ChainConfig<String> config;
/**
* 用于运行异步任务的执行器。
*/
private final Executor workerExecutor;
/**
* RetrievalAndInferenceChain 的构造函数。
* @param config 链的配置。
*/
public RetrievalAndInferenceChain(ChainConfig<String> config) {
this.config = config;
this.workerExecutor = Executors.newSingleThreadExecutor(); // 创建一个单线程的执行器用于异步任务
}
/**
* 使用 RetrievalRequest 调用链。
* @param retrievalRequest 要处理的检索请求。
* @return 表示异步操作的 ListenableFuture。
*/
public ListenableFuture<LanguageModelResponse> invoke(RetrievalRequest<String> retrievalRequest) {
return this.invoke(retrievalRequest, (AsyncProgressListener)null); // 调用重载方法,传入 null 作为进度监听器
}
/**
* 使用 RetrievalRequest 和可选的 AsyncProgressListener 调用链。
* @param retrievalRequest 要处理的检索请求。
* @param asyncProgressListener 可选的进度监听器。
* @return 表示异步操作的 ListenableFuture。
*/
public ListenableFuture<LanguageModelResponse> invoke(RetrievalRequest<String> retrievalRequest, @Nullable AsyncProgressListener<LanguageModelResponse> asyncProgressListener) {
// 从配置中获取语义记忆
SemanticMemory<String> memory = (SemanticMemory)this.config.getSemanticMemory().get();
// 确保语义记忆不为 null
Preconditions.checkNotNull(memory, "语义文本记忆不能为空");
// 从语义记忆中检索结果
ListenableFuture<RetrievalResponse<String>> responseFuture = memory.retrieveResults(retrievalRequest);
// 将检索响应转换为语言模型响应
return Futures.transformAsync(responseFuture, (response) -> {
// 从检索到的实体中构建字符串
StringBuilder memoryStringBuilder = new StringBuilder();
response.getEntities().forEach((entity) -> memoryStringBuilder.append((String)entity.getData()).append("\n"));
String memoryString = memoryStringBuilder.toString();
// 使用提示构建器构建提示
String prompt = ((PromptBuilder)this.config.getPromptBuilder().get()).buildPrompt(new Object[]{memoryString, retrievalRequest.getQuery()});
// 创建带有构建提示的语言模型请求
LanguageModelRequest languageModelRequest = LanguageModelRequest.builder().setPrompt(prompt).build();
// 从语言模型生成响应
return ((LanguageModel)this.config.getLanguageModel().get()).generateResponse(languageModelRequest, this.workerExecutor, asyncProgressListener);
}, this.workerExecutor); // 使用工作执行器进行异步转换
}
}
上面的代码主要做了这些事:
- 从配置中获取
SemanticMemory
,并确保其不为null
。 - 使用
SemanticMemory
的retrieveResults
方法检索与请求相关的数据。 - 将检索到的实体数据拼接成一个字符串。
- 使用
PromptBuilder
构建提示字符串,结合检索到的数据和原始查询。 - 创建
LanguageModelRequest
,并调用语言模型的generateResponse
方法生成最终响应。
最核心的就是 SemanticMemory#retrieveResults
:
/**
* 根据检索请求从语义记忆中检索结果。
*
* @param request 检索请求,包含查询字符串和检索配置。
* @return 一个 ListenableFuture,表示异步检索操作的结果。
*/
public ListenableFuture<RetrievalResponse<String>> retrieveResults(RetrievalRequest<String> request) {
// 创建 EmbedData.Builder 对象,设置查询数据
EmbedData.Builder<String> embedDataBuilder = EmbedData.builder().setData((String)request.getQuery());
// 根据请求的任务类型设置 EmbedData 的任务类型
switch (request.getConfig().getTask()) {
case TASK_UNSPECIFIED:
case RETRIEVAL_QUERY:
// 如果任务类型未指定或为检索查询,则设置为 RETRIEVAL_QUERY
embedDataBuilder.setTask(TaskType.RETRIEVAL_QUERY);
break;
case QUESTION_ANSWERING:
// 如果任务类型为问答,则设置为 QUESTION_ANSWERING
embedDataBuilder.setTask(TaskType.QUESTION_ANSWERING);
}
// 创建嵌入请求,包含构建好的 EmbedData 对象
EmbeddingRequest<String> embeddingRequest = EmbeddingRequest.create(ImmutableList.of(embedDataBuilder.build()));
// 使用嵌入模型获取查询的嵌入向量,并异步转换为检索响应
return Futures.transform(this.embeddingModel.getEmbeddings(embeddingRequest), (embeddings) -> {
// 从向量存储中获取与查询嵌入向量最相似的记录
List<VectorStoreRecord<String>> records = this.vectorStore.getNearestRecords(embeddings, request.getConfig().getTopK(), request.getConfig().getMinSimilarityScore());
// 将记录转换为 RetrievalEntity 对象
ImmutableList<RetrievalEntity<String>> entities = (ImmutableList)records.stream().map((record) ->
RetrievalEntity.builder()
.setData((String)record.getData()) // 设置数据
.setEmbeddings(record.getEmbeddings()) // 设置嵌入向量
.setMetadata(record.getMetadata()) // 设置元数据
.build()
).collect(ImmutableList.toImmutableList());
// 创建并返回检索响应
return RetrievalResponse.create(entities);
}, this.workerExecutor); // 使用工作执行器进行异步转换
}
可以看到,和 2.2.2 知识库构建时一样,查询的时候也是通过this.embeddingModel.getEmbeddings(embeddingRequest)
获取文本对应的嵌入向量,拿到向量后,再去向量存储里查询最相似的记录。
总结
OK,以上就是 Google AI Edge RAG SDK 的相关介绍,通过官方的 demo,我们可以了解到 RAG SDK 的 API 及使用方式,核心就是:
- 通过嵌入模型实现文本到向量
- 使用向量数据库存储向量、文本和元数据
- 通过 MediaPipeLLM 实现推理
由于 AI Edge RAG 这个库设计很灵活,核心 API 都支持自定义,我们可以灵活地更改 RAG 流程的任何部分,从而支持自定义数据库、分块方法及检索函数等。