From 4967c3f906b32184aa2e617fea034900bb564aa2 Mon Sep 17 00:00:00 2001
From: jiahao.he@vtradex.com <794629435@qq.com>
Date: 星期日, 16 三月 2025 20:01:34 +0800
Subject: [PATCH] 本地向量化

---
 ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java  |   83 ++++++++++++++++++
 ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java         |   27 ++++++
 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java |   22 +++++
 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java       |  111 ++++++++++++++++++++++++
 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java  |   45 ++++++++++
 ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java       |   74 +++++++++++++---
 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java               |   27 ++++++
 7 files changed, 374 insertions(+), 15 deletions(-)

diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java
new file mode 100644
index 0000000..4ca71ba
--- /dev/null
+++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java
@@ -0,0 +1,45 @@
+package org.ruoyi.common.chat.entity.models;
+
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * @program: RUOYIAI
+ * @ClassName LocalModelsSearchRequest
+ * @description: Request body posted to the local vectorization service.
+ *               Field names are snake_case and are serialized under those names;
+ *               presumably they mirror the JSON keys the service reads (confirm
+ *               against the service before renaming).
+ * @author: hejh
+ * @create: 2025-03-15 17:22
+ * @Version 1.0
+ **/
+@Data
+public class LocalModelsSearchRequest {
+
+    /** Texts to vectorize. */
+    private List<String> text;
+    /** Name of the embedding model. */
+    private String model_name;
+    /** Delimiter used to split the text. */
+    private String delimiter;
+    /** Number of results to return. */
+    private int k;
+    /** Size of one text block. */
+    private int block_size;
+    /** Number of overlapping characters. */
+    private int overlap_chars;
+
+    /**
+     * Creates a fully populated request; getters and setters are generated by Lombok.
+     */
+    public LocalModelsSearchRequest(List<String> text, String model_name, String delimiter, int k, int block_size, int overlap_chars) {
+        this.text = text;
+        this.model_name = model_name;
+        this.delimiter = delimiter;
+        this.k = k;
+        this.block_size = block_size;
+        this.overlap_chars = overlap_chars;
+    }
+}
diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java
new file mode 100644
index 0000000..12025d5
--- /dev/null
+++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java
@@ -0,0 +1,22 @@
+package org.ruoyi.common.chat.entity.models;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * Response body returned by the local vectorization service; unknown JSON properties are ignored.
+ */
+@Data
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class LocalModelsSearchResponse {
+
+    /** Embeddings as a three-level nested array. */
+    @JsonProperty("topKEmbeddings")
+    private List<List<List<Double>>> topKEmbeddings;
+
+    /** Default constructor. */
+    public LocalModelsSearchResponse() {}
+}
diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java
new file mode 100644
index 0000000..606a7c2
--- /dev/null
+++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java
@@ -0,0 +1,111 @@
+package org.ruoyi.common.chat.localModels;
+
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.OkHttpClient;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
+import org.springframework.stereotype.Service;
+import retrofit2.Response;
+import retrofit2.Retrofit;
+import retrofit2.converter.jackson.JacksonConverterFactory;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Client for the local vectorization service reachable at {@code BASE_URL}.
+ */
+@Slf4j
+@Service
+public class LocalModelsofitClient {
+    /** URL of the Flask service. */
+    private static final String BASE_URL = "http://127.0.0.1:5000";
+
+    private static final String DEFAULT_MODEL_NAME = "msmarco-distilbert-base-tas-b";
+    private static final String DEFAULT_DELIMITER = ".";
+    private static final int DEFAULT_TOP_K = 3;
+    private static final int DEFAULT_BLOCK_SIZE = 500;
+    private static final int DEFAULT_OVERLAP_CHARS = 50;
+
+    /**
+     * Lazy holder: class initialization builds the Retrofit instance exactly once, so
+     * concurrent first calls can no longer race on an unsynchronized null check.
+     */
+    private static final class RetrofitHolder {
+        private static final Retrofit INSTANCE = new Retrofit.Builder()
+                .baseUrl(BASE_URL)
+                .client(new OkHttpClient.Builder().build())
+                .addConverterFactory(JacksonConverterFactory.create()) // JSON conversion via Jackson
+                .build();
+    }
+
+    /**
+     * Returns the shared Retrofit instance.
+     *
+     * @return the Retrofit instance bound to the service URL
+     */
+    public static Retrofit getRetrofitInstance() {
+        return RetrofitHolder.INSTANCE;
+    }
+
+    /**
+     * Sends a text vectorization request to the Flask service and waits for the answer.
+     *
+     * @param queries      texts to vectorize
+     * @param modelName    model name; {@code null} or empty selects the default model
+     * @param delimiter    text delimiter; {@code null} or empty selects {@code "."}
+     * @param topK         number of results to return; values {@code <= 0} select 3
+     * @param blockSize    text block size; values {@code <= 0} select 500
+     * @param overlapChars number of overlapping characters; values {@code <= 0} select 50
+     * @return the first element of the {@code topKEmbeddings} array in the response
+     * @throws IllegalStateException if the request fails, the service answers with an
+     *                               HTTP error, or the response carries no embeddings
+     */
+    public static List<List<Double>> getTopKEmbeddings(
+            List<String> queries,
+            String modelName,
+            String delimiter,
+            int topK,
+            int blockSize,
+            int overlapChars) {
+
+        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
+                queries,
+                isEmpty(modelName) ? DEFAULT_MODEL_NAME : modelName,
+                isEmpty(delimiter) ? DEFAULT_DELIMITER : delimiter,
+                topK > 0 ? topK : DEFAULT_TOP_K,
+                blockSize > 0 ? blockSize : DEFAULT_BLOCK_SIZE,
+                overlapChars > 0 ? overlapChars : DEFAULT_OVERLAP_CHARS);
+
+        SearchService service = getRetrofitInstance().create(SearchService.class);
+
+        Response<LocalModelsSearchResponse> response;
+        try {
+            // The result is needed before returning, so run the call synchronously
+            // instead of enqueueing it and parking the thread on a latch.
+            response = service.vectorize(request).execute();
+        } catch (IOException e) {
+            throw new IllegalStateException("Request failed: " + e.getMessage(), e);
+        }
+
+        if (!response.isSuccessful()) {
+            throw new IllegalStateException("Request failed. HTTP error code: " + response.code());
+        }
+        LocalModelsSearchResponse searchResponse = response.body();
+        if (searchResponse == null) {
+            throw new IllegalStateException("Response body is null");
+        }
+        List<List<List<Double>>> embeddings = searchResponse.getTopKEmbeddings();
+        if (embeddings == null || embeddings.isEmpty()) {
+            throw new IllegalStateException("Response contains no embeddings");
+        }
+
+        log.info("Successfully retrieved embeddings");
+        return embeddings.get(0);
+    }
+
+    /** Returns {@code true} when {@code value} is {@code null} or has no characters. */
+    private static boolean isEmpty(String value) {
+        return value == null || value.isEmpty();
+    }
+}
diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java
new file mode 100644
index 0000000..3fa131e
--- /dev/null
+++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java
@@ -0,0 +1,27 @@
+package org.ruoyi.common.chat.localModels;
+
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
+import retrofit2.Call;
+import retrofit2.http.Body;
+import retrofit2.http.POST;
+
+/**
+ * @program: RUOYIAI
+ * @ClassName SearchService
+ * @description: Retrofit API of the model service
+ * @author: hejh
+ * @create: 2025-03-15 17:27
+ * @Version 1.0
+ **/
+public interface SearchService {
+
+    /**
+     * Posts a vectorization request to the {@code /vectorize} route of the Flask service.
+     *
+     * @param request request payload, sent as the request body
+     * @return the call that yields the service response
+     */
+    @POST("/vectorize")
+    Call<LocalModelsSearchResponse> vectorize(@Body LocalModelsSearchRequest request);
+}
diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java
new file mode 100644
index 0000000..d7dff25
--- /dev/null
+++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java
@@ -0,0 +1,83 @@
+package org.ruoyi.knowledge.chain.vectorizer;
+
+import jakarta.annotation.Resource;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.ruoyi.common.chat.config.ChatConfig;
+import org.ruoyi.common.chat.localModels.LocalModelsofitClient;
+import org.ruoyi.common.chat.openai.OpenAiStreamClient;
+import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
+import org.ruoyi.knowledge.service.IKnowledgeInfoService;
+import org.springframework.stereotype.Component;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Vectorizes text chunks through the local model service.
+ */
+@Component
+@Slf4j
+@RequiredArgsConstructor
+public class LocalModelsVectorization {
+    @Resource
+    private IKnowledgeInfoService knowledgeInfoService;
+
+    @Getter
+    private OpenAiStreamClient openAiStreamClient;
+
+    private final ChatConfig chatConfig;
+
+    /**
+     * Batch vectorization.
+     *
+     * @param chunkList list of text chunks
+     * @param kid       knowledge ID
+     * @return vectorization result
+     * @throws RuntimeException if the local model service call fails
+     */
+    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
+        log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size());
+        // NOTE(review): this client is only exposed through the getter and is not used for
+        // local vectorization; confirm whether it is still needed here.
+        openAiStreamClient = chatConfig.getOpenAiStreamClient();
+        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
+        try {
+            // getTopKEmbeddings is static, so call it on the class rather than on a bean
+            return LocalModelsofitClient.getTopKEmbeddings(
+                    chunkList,
+                    knowledgeInfoVo.getVector(),
+                    knowledgeInfoVo.getKnowledgeSeparator(),
+                    knowledgeInfoVo.getRetrieveLimit(),
+                    knowledgeInfoVo.getTextBlockSize(),
+                    knowledgeInfoVo.getOverlapChar()
+            );
+        } catch (Exception e) {
+            log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e);
+            throw new RuntimeException("Batch vectorization failed", e);
+        }
+    }
+
+    /**
+     * Single text chunk vectorization.
+     *
+     * @param chunk single text chunk
+     * @param kid   knowledge ID
+     * @return the first vector of the batch result, or an empty list when there is none
+     */
+    public List<Double> singleVectorization(String chunk, String kid) {
+        List<String> chunkList = new ArrayList<>();
+        chunkList.add(chunk);
+
+        List<List<Double>> vectorList = batchVectorization(chunkList, kid);
+
+        // Guard against null as well: isEmpty() alone would throw on a missing result
+        if (vectorList == null || vectorList.isEmpty()) {
+            log.warn("Vectorization returned empty list for chunk: {}", chunk);
+            return new ArrayList<>();
+        }
+
+        return vectorList.get(0);
+    }
+}
diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java
index 0f2d0ba..764c2c1 100644
--- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java
+++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java
@@ -18,6 +18,7 @@
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
 
 @Component
 @Slf4j
@@ -27,6 +28,9 @@
     @Lazy
     @Resource
     private IKnowledgeInfoService knowledgeInfoService;
+    @Lazy
+    @Resource
+    private LocalModelsVectorization localModelsVectorization;
 
     @Getter
     private OpenAiStreamClient openAiStreamClient;
@@ -35,25 +39,65 @@
 
     @Override
     public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
-        openAiStreamClient = chatConfig.getOpenAiStreamClient();
-        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
-        Embedding embedding = Embedding.builder()
-                .input(chunkList)
-                .model(knowledgeInfoVo.getVectorModel())
-                .build();
-        EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
         List<List<Double>> vectorList = new ArrayList<>();
-        embeddings.getData().forEach(data -> {
-            List<BigDecimal> vector = data.getEmbedding();
-            List<Double> doubleVector = new ArrayList<>();
-            for (BigDecimal bd : vector) {
-                doubleVector.add(bd.doubleValue());
-            }
-            vectorList.add(doubleVector);
-        });
+
+        // Load the knowledge base information
+        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
+
+        // Try the local model first
+        try {
+            return localModelsVectorization.batchVectorization(chunkList, kid);
+        } catch (Exception e) {
+            log.error("Local models vectorization failed, falling back to OpenAI embeddings", e);
+        }
+
+        // The local model failed: fall back to the OpenAI service. Fetch the client
+        // before use; this method no longer does so at its start.
+        openAiStreamClient = chatConfig.getOpenAiStreamClient();
+        Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
+        EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
+
+        // Convert the embeddings returned by OpenAI
+        vectorList = processOpenAiEmbeddings(embeddings);
+
         return vectorList;
     }
 
+    /**
+     * Builds the Embedding request object.
+     */
+    private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
+        return Embedding.builder()
+                .input(chunkList)
+                .model(knowledgeInfoVo.getVectorModel())
+                .build();
+    }
+
+    /**
+     * Converts the embedding data returned by OpenAI into lists of doubles.
+     */
+    private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
+        List<List<Double>> vectorList = new ArrayList<>();
+
+        embeddings.getData().forEach(data -> {
+            List<BigDecimal> vector = data.getEmbedding();
+            List<Double> doubleVector = convertToDoubleList(vector);
+            vectorList.add(doubleVector);
+        });
+
+        return vectorList;
+    }
+
+    /**
+     * Converts a BigDecimal list into a Double list.
+     */
+    private List<Double> convertToDoubleList(List<BigDecimal> vector) {
+        return vector.stream()
+                .map(BigDecimal::doubleValue)
+                .collect(Collectors.toList());
+    }
+
+
     @Override
     public List<Double> singleVectorization(String chunk, String kid) {
         List<String> chunkList = new ArrayList<>();
diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java
new file mode 100644
index 0000000..a9d370d
--- /dev/null
+++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java
@@ -0,0 +1,27 @@
+package org.ruoyi.knowledge.chain.vectorizer;
+
+/**
+ * Supported vectorization back ends.
+ */
+public enum VectorizationType {
+    /** OpenAI vectorization. */
+    OPENAI,
+    /** Local model vectorization. */
+    LOCAL;
+
+    /**
+     * Resolves a type from its name, ignoring case.
+     *
+     * @param type name of the type
+     * @return the matching constant
+     * @throws IllegalArgumentException if no constant matches {@code type}
+     */
+    public static VectorizationType fromString(String type) {
+        for (VectorizationType v : values()) {
+            if (v.name().equalsIgnoreCase(type)) {
+                return v;
+            }
+        }
+        throw new IllegalArgumentException("Unknown VectorizationType: " + type);
+    }
+}
--
Gitblit v1.9.3