ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java
package org.ruoyi.common.chat.entity.models;

import lombok.Data;

import java.util.List;

/**
 * @program: RUOYIAI
 * @ClassName LocalModelsSearchRequest
 * @description: Request payload for a local-model vectorization call. The field
 *               names are snake_case, so the Lombok-generated accessors are
 *               snake_case as well (for example {@code getModel_name()}).
 * @author: hejh
 * @create: 2025-03-15 17:22
 * @Version 1.0
 **/
@Data
public class LocalModelsSearchRequest {

    /** Texts to be processed. */
    private List<String> text;

    /** Name of the model to use. */
    private String model_name;

    /** Delimiter used to split the text. */
    private String delimiter;

    /** Number of results to return. */
    private int k;

    /** Size of a text block. */
    private int block_size;

    /** Number of overlapping characters. */
    private int overlap_chars;

    /**
     * Creates a fully populated request. Getters and setters are generated by
     * {@code @Data}.
     *
     * @param text          texts to be processed
     * @param model_name    name of the model to use
     * @param delimiter     delimiter used to split the text
     * @param k             number of results to return
     * @param block_size    size of a text block
     * @param overlap_chars number of overlapping characters
     */
    public LocalModelsSearchRequest(
            List<String> text,
            String model_name,
            String delimiter,
            int k,
            int block_size,
            int overlap_chars) {
        // Scalar tuning values first, then the textual settings; the
        // assignments are independent of one another.
        this.k = k;
        this.block_size = block_size;
        this.overlap_chars = overlap_chars;
        this.model_name = model_name;
        this.delimiter = delimiter;
        this.text = text;
    }
}

// Next file in this diff listing:
// ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java
package org.ruoyi.common.chat.entity.models;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;

import java.util.List;

/**
 * Response payload of a local-model vectorization call.
 *
 * <p>JSON properties that have no matching field are ignored during
 * deserialization ({@code ignoreUnknown = true}).
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@Data
public class LocalModelsSearchResponse {

    /**
     * Value of the {@code topKEmbeddings} JSON property: an array nested three
     * levels deep.
     */
    @JsonProperty("topKEmbeddings")
    private List<List<List<Double>>> topKEmbeddings;

    /** Default constructor. */
    public LocalModelsSearchResponse() {
    }
}

// Next file in this diff listing:
// ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java
package org.ruoyi.common.chat.localModels;

import io.micrometer.common.util.StringUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
import org.springframework.stereotype.Service;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;
import retrofit2.Retrofit;
import retrofit2.converter.jackson.JacksonConverterFactory;

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Client for the local vectorization service, built on Retrofit with a Jackson
 * JSON converter.
 *
 * <p>Example:
 * <pre>{@code
 * List<List<Double>> embeddings = LocalModelsofitClient.getTopKEmbeddings(
 *     Arrays.asList("What is artificial intelligence?", "AI is transforming industries."),
 *     "msmarco-distilbert-base-tas-b", ".", 3, 500, 50);
 * if (embeddings == null) {
 *     // the request failed or returned nothing usable; details are in the log
 * }
 * }</pre>
 */
@Slf4j
@Service
public class LocalModelsofitClient {

    /** URL of the Flask service. */
    private static final String BASE_URL = "http://127.0.0.1:5000";

    /** Model name used when the caller passes a null or empty one. */
    private static final String DEFAULT_MODEL_NAME = "msmarco-distilbert-base-tas-b";

    /** Delimiter used when the caller passes a null or empty one. */
    private static final String DEFAULT_DELIMITER = ".";

    /** Result count used when the caller passes a non-positive one. */
    private static final int DEFAULT_TOP_K = 3;

    /** Text block size used when the caller passes a non-positive one. */
    private static final int DEFAULT_BLOCK_SIZE = 500;

    /** Overlap character count used when the caller passes a non-positive one. */
    private static final int DEFAULT_OVERLAP_CHARS = 50;

    private static Retrofit retrofit = null;

    /**
     * Returns the shared Retrofit instance, creating it on first use.
     *
     * <p>The lazy initialization is not synchronized, so concurrent first
     * calls may each build an instance; the last one assigned is kept.
     *
     * @return the Retrofit instance bound to {@link #BASE_URL}
     */
    public static Retrofit getRetrofitInstance() {
        if (retrofit == null) {
            OkHttpClient client = new OkHttpClient.Builder()
                    .build();

            retrofit = new Retrofit.Builder()
                    .baseUrl(BASE_URL)
                    .client(client)
                    .addConverterFactory(JacksonConverterFactory.create()) // Jackson handles the JSON conversion
                    .build();
        }
        return retrofit;
    }

    /**
     * Sends a text vectorization request to the Flask service and blocks until
     * the request has completed or failed.
     *
     * <p>Only the first element of the response's {@code topKEmbeddings} list
     * is returned.
     *
     * @param queries      list of query texts
     * @param modelName    model name; a null or empty value selects the default model
     * @param delimiter    text delimiter; a null or empty value selects {@code "."}
     * @param topK         number of results to return; values {@code <= 0} select 3
     * @param blockSize    text block size; values {@code <= 0} select 500
     * @param overlapChars number of overlapping characters; values {@code <= 0} select 50
     * @return the Top K embedding vectors, or {@code null} when the request
     *         fails, the HTTP status is not successful, the response carries no
     *         embeddings, or the waiting thread is interrupted before a result
     *         arrives
     */
    public static List<List<Double>> getTopKEmbeddings(
            List<String> queries,
            String modelName,
            String delimiter,
            int topK,
            int blockSize,
            int overlapChars) {

        String effectiveModelName = StringUtils.isEmpty(modelName) ? DEFAULT_MODEL_NAME : modelName;
        String effectiveDelimiter = StringUtils.isEmpty(delimiter) ? DEFAULT_DELIMITER : delimiter;
        int effectiveTopK = (topK > 0) ? topK : DEFAULT_TOP_K;
        int effectiveBlockSize = (blockSize > 0) ? blockSize : DEFAULT_BLOCK_SIZE;
        int effectiveOverlapChars = (overlapChars > 0) ? overlapChars : DEFAULT_OVERLAP_CHARS;

        SearchService service = getRetrofitInstance().create(SearchService.class);

        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
                queries,
                effectiveModelName,
                effectiveDelimiter,
                effectiveTopK,
                effectiveBlockSize,
                effectiveOverlapChars
        );

        final CountDownLatch latch = new CountDownLatch(1);
        final AtomicReference<List<List<Double>>> result = new AtomicReference<>();

        // Issue the asynchronous request; the latch turns it into a blocking call.
        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
            @Override
            public void onResponse(Call<LocalModelsSearchResponse> call,
                                   Response<LocalModelsSearchResponse> response) {
                try {
                    result.set(extractFirstResult(response));
                } catch (RuntimeException e) {
                    log.error("Failed to process the embedding response", e);
                } finally {
                    // Must run on every path: if an exception skipped this, the
                    // caller blocked in latch.await() would never wake up.
                    latch.countDown();
                }
            }

            @Override
            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
                log.error("Request failed: ", t);
                latch.countDown();
            }
        });

        try {
            latch.await();
        } catch (InterruptedException e) {
            // Restore the flag so callers further up the stack can observe the interrupt.
            Thread.currentThread().interrupt();
            log.warn("Interrupted while waiting for the embedding response", e);
        }

        return result.get();
    }

    /**
     * Pulls the first Top K embedding list out of a response, logging the
     * reason whenever nothing usable is present.
     *
     * @param response the Retrofit response to inspect
     * @return the first element of {@code topKEmbeddings}, or {@code null}
     */
    private static List<List<Double>> extractFirstResult(Response<LocalModelsSearchResponse> response) {
        if (!response.isSuccessful()) {
            log.error("Request failed. HTTP error code: {}", response.code());
            return null;
        }
        LocalModelsSearchResponse searchResponse = response.body();
        if (searchResponse == null) {
            log.error("Response body is null");
            return null;
        }
        List<List<List<Double>>> allEmbeddings = searchResponse.getTopKEmbeddings();
        if (allEmbeddings == null || allEmbeddings.isEmpty()) {
            log.error("Response contains no embeddings");
            return null;
        }
        log.info("Successfully retrieved embeddings");
        return allEmbeddings.get(0);
    }
}

// Next file in this diff listing:
// ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java
¶Ô±ÈÐÂÎļþ @@ -0,0 +1,25 @@ package org.ruoyi.common.chat.localModels; import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest; import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse; import retrofit2.Call; import retrofit2.http.Body; import retrofit2.http.POST; /** * @program: RUOYIAI * @ClassName SearchService * @description: è¯·æ±æ¨¡å * @author: hejh * @create: 2025-03-15 17:27 * @Version 1.0 **/ public interface SearchService { @POST("/vectorize") // ä¸ Flask æå¡ä¸çè·¯ç±å¹é Call<LocalModelsSearchResponse> vectorize(@Body LocalModelsSearchRequest request); } ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java
¶Ô±ÈÐÂÎļþ @@ -0,0 +1,92 @@ package org.ruoyi.knowledge.chain.vectorizer; import jakarta.annotation.Resource; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.ruoyi.common.chat.config.ChatConfig; import org.ruoyi.common.chat.localModels.LocalModelsofitClient; import org.ruoyi.common.chat.openai.OpenAiStreamClient; import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo; import org.ruoyi.knowledge.service.IKnowledgeInfoService; import org.springframework.stereotype.Component; import java.util.ArrayList; import java.util.List; @Component @Slf4j @RequiredArgsConstructor public class LocalModelsVectorization { @Resource private IKnowledgeInfoService knowledgeInfoService; @Resource private LocalModelsofitClient localModelsofitClient; @Getter private OpenAiStreamClient openAiStreamClient; private final ChatConfig chatConfig; /** * æ¹éåéå * * @param chunkList ææ¬åå表 * @param kid ç¥è¯ ID * @return åéåç»æ */ public List<List<Double>> batchVectorization(List<String> chunkList, String kid) { logVectorizationRequest(kid, chunkList); // å¨åéåå¼å§åè®°å½æ¥å¿ openAiStreamClient = chatConfig.getOpenAiStreamClient(); // è·å OpenAi 客æ·ç«¯ KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); // æ¥è¯¢ç¥è¯ä¿¡æ¯ // è°ç¨ localModelsofitClient è·å Top K åµå ¥åé try { return localModelsofitClient.getTopKEmbeddings( chunkList, knowledgeInfoVo.getVector(), knowledgeInfoVo.getKnowledgeSeparator(), knowledgeInfoVo.getRetrieveLimit(), knowledgeInfoVo.getTextBlockSize(), knowledgeInfoVo.getOverlapChar() ); } catch (Exception e) { log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e); throw new RuntimeException("Batch vectorization failed", e); } } /** * å䏿æ¬ååéå * * @param chunk å䏿æ¬å * @param kid ç¥è¯ ID * @return åéåç»æ */ public List<Double> singleVectorization(String chunk, String kid) { List<String> chunkList = new ArrayList<>(); chunkList.add(chunk); // è°ç¨æ¹éåéåæ¹æ³ 
List<List<Double>> vectorList = batchVectorization(chunkList, kid); if (vectorList.isEmpty()) { log.warn("Vectorization returned empty list for chunk: {}", chunk); return new ArrayList<>(); } return vectorList.get(0); // è¿å第ä¸ä¸ªåé } /** * æä¾æ´ç®æ´çæ¥å¿è®°å½æ¹æ³ * * @param kid ç¥è¯ ID * @param chunkList ææ¬åå表 */ private void logVectorizationRequest(String kid, List<String> chunkList) { log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size()); } } ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java
@@ -18,6 +18,7 @@ import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @Component @Slf4j @@ -27,6 +28,9 @@ @Lazy @Resource private IKnowledgeInfoService knowledgeInfoService; @Lazy @Resource private LocalModelsVectorization localModelsVectorization; @Getter private OpenAiStreamClient openAiStreamClient; @@ -35,25 +39,63 @@ @Override public List<List<Double>> batchVectorization(List<String> chunkList, String kid) { openAiStreamClient = chatConfig.getOpenAiStreamClient(); KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); Embedding embedding = Embedding.builder() .input(chunkList) .model(knowledgeInfoVo.getVectorModel()) .build(); EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding); List<List<Double>> vectorList = new ArrayList<>(); embeddings.getData().forEach(data -> { List<BigDecimal> vector = data.getEmbedding(); List<Double> doubleVector = new ArrayList<>(); for (BigDecimal bd : vector) { doubleVector.add(bd.doubleValue()); } vectorList.add(doubleVector); }); // è·åç¥è¯åºä¿¡æ¯ KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); // å¦æä½¿ç¨æ¬å°æ¨¡å try { return localModelsVectorization.batchVectorization(chunkList, kid); } catch (Exception e) { log.error("Local models vectorization failed, falling back to OpenAI embeddings", e); } // 妿æ¬å°æ¨¡å失败ï¼åè°ç¨ OpenAI æå¡è¿è¡åéå Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo); EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding); // å¤ç OpenAI è¿åçåµå ¥æ°æ® vectorList = processOpenAiEmbeddings(embeddings); return vectorList; } /** * æå»º Embedding 对象 */ private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) { return Embedding.builder() .input(chunkList) .model(knowledgeInfoVo.getVectorModel()) .build(); } /** * å¤ç OpenAI è¿åçåµå ¥æ°æ® */ private List<List<Double>> 
processOpenAiEmbeddings(EmbeddingResponse embeddings) { List<List<Double>> vectorList = new ArrayList<>(); embeddings.getData().forEach(data -> { List<BigDecimal> vector = data.getEmbedding(); List<Double> doubleVector = convertToDoubleList(vector); vectorList.add(doubleVector); }); return vectorList; } /** * å° BigDecimal 转æ¢ä¸º Double å表 */ private List<Double> convertToDoubleList(List<BigDecimal> vector) { return vector.stream() .map(BigDecimal::doubleValue) .collect(Collectors.toList()); } @Override public List<Double> singleVectorization(String chunk, String kid) { List<String> chunkList = new ArrayList<>(); ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java
/**
 * The available vectorization back ends.
 */
public enum VectorizationType {

    /** OpenAI vectorization. */
    OPENAI,

    /** Local model vectorization. */
    LOCAL;

    /**
     * Resolves a constant from its name, ignoring case.
     *
     * @param type name to look up; {@code null} matches nothing
     * @return the constant whose name equals {@code type} ignoring case
     * @throws IllegalArgumentException if no constant matches
     */
    public static VectorizationType fromString(String type) {
        VectorizationType match = null;
        for (VectorizationType candidate : values()) {
            if (candidate.name().equalsIgnoreCase(type)) {
                match = candidate;
                break;
            }
        }
        if (match == null) {
            throw new IllegalArgumentException("Unknown VectorizationType: " + type);
        }
        return match;
    }
}