办学质量监测教学评价系统
jiahao.he@vtradex.com
2025-03-16 4967c3f906b32184aa2e617fea034900bb564aa2
本地向量化
已修改1个文件
已添加6个文件
460 ■■■■■ 文件已修改
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java 38 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java 20 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java 198 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java 25 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java 92 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java 72 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java
对比新文件
@@ -0,0 +1,38 @@
package org.ruoyi.common.chat.entity.models;
import lombok.Data;
import java.util.List;
/**
 * Request body for the local model vectorization call; it is the {@code @Body}
 * argument of {@code SearchService.vectorize}.
 *
 * NOTE(review): the snake_case field names presumably mirror the JSON keys the
 * service expects - confirm against the service before renaming any of them.
 *
 * @program: RUOYIAI
 * @ClassName LocalModelsSearchRequest
 * @description: request payload for local model vectorization
 * @author: hejh
 * @create: 2025-03-15 17:22
 * @Version 1.0
 **/
@Data
public class LocalModelsSearchRequest {
    // Query texts to vectorize.
    private List<String> text;
    // Name of the model to use.
    private String model_name;
    // Text delimiter.
    private String delimiter;
    // Number of results to return.
    private int k;
    // Text block size.
    private int block_size;
    // Number of overlapping characters.
    private int overlap_chars;
    // Constructor; getters and setters are generated by Lombok's @Data.
    public LocalModelsSearchRequest(List<String> text, String model_name, String delimiter, int k, int block_size, int overlap_chars) {
        this.text = text;
        this.model_name = model_name;
        this.delimiter = delimiter;
        this.k = k;
        this.block_size = block_size;
        this.overlap_chars = overlap_chars;
    }
}
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java
对比新文件
@@ -0,0 +1,20 @@
package org.ruoyi.common.chat.entity.models;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List;
/**
 * Response body returned by {@code SearchService.vectorize}.
 * JSON properties other than {@code topKEmbeddings} are ignored during deserialization.
 */
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class LocalModelsSearchResponse {
    // NOTE(review): the outermost list presumably holds one entry per query text - confirm against the service.
    @JsonProperty("topKEmbeddings")
    private List<List<List<Double>>> topKEmbeddings;  // handles the three-level nested array
    // Default constructor
    public LocalModelsSearchResponse() {}
}
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java
对比新文件
@@ -0,0 +1,198 @@
package org.ruoyi.common.chat.localModels;
import io.micrometer.common.util.StringUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
import org.springframework.stereotype.Service;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;
import retrofit2.Retrofit;
import retrofit2.converter.jackson.JacksonConverterFactory;
import java.util.List;
import java.util.concurrent.CountDownLatch;
@Slf4j
@Service
public class LocalModelsofitClient {
    private static final String BASE_URL = "http://127.0.0.1:5000"; // Flask æœåŠ¡çš„ URL
    private static Retrofit retrofit = null;
    // èŽ·å– Retrofit å®žä¾‹
    public static Retrofit getRetrofitInstance() {
        if (retrofit == null) {
            OkHttpClient client = new OkHttpClient.Builder()
                    .build();
            retrofit = new Retrofit.Builder()
                    .baseUrl(BASE_URL)
                    .client(client)
                    .addConverterFactory(JacksonConverterFactory.create()) // ä½¿ç”¨ Jackson å¤„理 JSON è½¬æ¢
                    .build();
        }
        return retrofit;
    }
    /**
     * å‘ Flask æœåŠ¡å‘é€æ–‡æœ¬å‘é‡åŒ–è¯·æ±‚
     *
     * @param queries æŸ¥è¯¢æ–‡æœ¬åˆ—表
     * @param modelName æ¨¡åž‹åç§°
     * @param delimiter æ–‡æœ¬åˆ†éš”符
     * @param topK è¿”回的结果数
     * @param blockSize æ–‡æœ¬å—大小
     * @param overlapChars é‡å å­—符数
     * @return è¿”回计算得到的 Top K åµŒå…¥å‘量列表
     */
    public static List<List<Double>> getTopKEmbeddings(
            List<String> queries,
            String modelName,
            String delimiter,
            int topK,
            int blockSize,
            int overlapChars) {
        modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // é»˜è®¤æ¨¡åž‹åç§°
        delimiter = (!StringUtils.isEmpty(delimiter) ) ? delimiter : ".";                             // é»˜è®¤åˆ†éš”符
        topK = (topK > 0) ? topK : 3;                                                  // é»˜è®¤è¿”回 3 ä¸ªç»“æžœ
        blockSize = (blockSize > 0) ? blockSize : 500;                                 // é»˜è®¤æ–‡æœ¬å—大小为 500
        overlapChars = (overlapChars > 0) ? overlapChars : 50;                         // é»˜è®¤é‡å å­—符数为 50
        // åˆ›å»º Retrofit å®žä¾‹
        Retrofit retrofit = getRetrofitInstance();
        // åˆ›å»º SearchService æŽ¥å£
        SearchService service = retrofit.create(SearchService.class);
        // åˆ›å»ºè¯·æ±‚对象 LocalModelsSearchRequest
        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
                queries,            // æŸ¥è¯¢æ–‡æœ¬åˆ—表
                modelName,          // æ¨¡åž‹åç§°
                delimiter,          // æ–‡æœ¬åˆ†éš”符
                topK,               // è¿”回的结果数
                blockSize,          // æ–‡æœ¬å—大小
                overlapChars        // é‡å å­—符数
        );
        final CountDownLatch latch = new CountDownLatch(1);  // åˆ›å»ºä¸€ä¸ª CountDownLatch
        final List<List<Double>>[] topKEmbeddings = new List[]{null}; // ä½¿ç”¨æ•°ç»„来存储结果(因为 Java ä¸æ”¯æŒç›´æŽ¥ä¿®æ”¹ List)
        // å‘起异步请求
        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
            @Override
            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
                if (response.isSuccessful()) {
                    LocalModelsSearchResponse searchResponse = response.body();
                    if (searchResponse != null) {
                        topKEmbeddings[0] = searchResponse.getTopKEmbeddings().get(0);  // èŽ·å–ç»“æžœ
                        log.info("Successfully retrieved embeddings");
                    } else {
                        log.error("Response body is null");
                    }
                } else {
                    log.error("Request failed. HTTP error code: " + response.code());
                }
                latch.countDown();  // è¯·æ±‚完成,减少计数
            }
            @Override
            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
                t.printStackTrace();
                log.error("Request failed: ", t);
                latch.countDown();  // è¯·æ±‚失败,减少计数
            }
        });
        try {
            latch.await();  // ç­‰å¾…请求完成
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        return topKEmbeddings[0];  // è¿”回结果
    }
//    public static void main(String[] args) {
//        // ç¤ºä¾‹è°ƒç”¨
//        List<String> queries = Arrays.asList("What is artificial intelligence?", "AI is transforming industries.");
//        String modelName = "msmarco-distilbert-base-tas-b";
//        String delimiter = ".";
//        int topK = 3;
//        int blockSize = 500;
//        int overlapChars = 50;
//
//        List<List<Double>> topKEmbeddings = getTopKEmbeddings(queries, modelName, delimiter, topK, blockSize, overlapChars);
//
//        // æ‰“印结果
//        if (topKEmbeddings != null) {
//            System.out.println("Top K embeddings: ");
//            for (List<Double> embedding : topKEmbeddings) {
//                System.out.println(embedding);
//            }
//        } else {
//            System.out.println("No embeddings returned.");
//        }
//    }
//    public static void main(String[] args) {
//        // åˆ›å»º Retrofit å®žä¾‹
//        Retrofit retrofit = LocalModelsofitClient.getRetrofitInstance();
//
//        // åˆ›å»º SearchService æŽ¥å£
//        SearchService service = retrofit.create(SearchService.class);
//
//        // åˆ›å»ºè¯·æ±‚对象 LocalModelsSearchRequest
//        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
//                Arrays.asList("What is artificial intelligence?", "AI is transforming industries."), // æŸ¥è¯¢æ–‡æœ¬åˆ—表
//                "msmarco-distilbert-base-tas-b",  // æ¨¡åž‹åç§°
//                ".",  // åˆ†éš”符
//                3,  // è¿”回的结果数
//                500,  // æ–‡æœ¬å—大小
//                50  // é‡å å­—符数
//        );
//
//        // å‘起请求
//        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
//            @Override
//            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
//                if (response.isSuccessful()) {
//                    LocalModelsSearchResponse searchResponse = response.body();
//                    System.out.println("Response Body: " + response.body());  // Print the whole response body for debugging
//
//                    if (searchResponse != null) {
//                        // If the response is not null, process it.
//                        // Example: Extract the embeddings and print them
//                        List<List<List<Double>>> topKEmbeddings = searchResponse.getTopKEmbeddings();
//                        if (topKEmbeddings != null) {
//                            // Print the Top K embeddings
//
//                        } else {
//                            System.err.println("Top K embeddings are null");
//                        }
//
//                        // If there is more information you want to process, handle it here
//
//                    } else {
//                        System.err.println("Response body is null");
//                    }
//                } else {
//                    System.err.println("Request failed. HTTP error code: " + response.code());
//                    log.error("Failed to retrieve data. HTTP error code: " + response.code());
//                }
//            }
//
//            @Override
//            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
//                // è¯·æ±‚失败,打印错误
//                t.printStackTrace();
//                log.error("Request failed: ", t);
//            }
//        });
//    }
}
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java
对比新文件
@@ -0,0 +1,25 @@
package org.ruoyi.common.chat.localModels;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.POST;
/**
 * Retrofit service interface for the local model service.
 *
 * @program: RUOYIAI
 * @ClassName SearchService
 * @description: request model
 * @author: hejh
 * @create: 2025-03-15 17:27
 * @Version 1.0
 **/
public interface SearchService {
    /**
     * Posts a vectorization request to the {@code /vectorize} route.
     *
     * @param request the request payload, sent as the HTTP body
     * @return a Retrofit call yielding the service's response
     */
    @POST("/vectorize") // matches the route in the Flask service
    Call<LocalModelsSearchResponse> vectorize(@Body LocalModelsSearchRequest request);
}
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java
对比新文件
@@ -0,0 +1,92 @@
package org.ruoyi.knowledge.chain.vectorizer;
import jakarta.annotation.Resource;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.chat.config.ChatConfig;
import org.ruoyi.common.chat.localModels.LocalModelsofitClient;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
import org.ruoyi.knowledge.service.IKnowledgeInfoService;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
@Component
@Slf4j
@RequiredArgsConstructor
public class LocalModelsVectorization   {
    @Resource
    private IKnowledgeInfoService knowledgeInfoService;
    @Resource
    private LocalModelsofitClient localModelsofitClient;
    @Getter
    private OpenAiStreamClient openAiStreamClient;
    private final ChatConfig chatConfig;
    /**
     * æ‰¹é‡å‘量化
     *
     * @param chunkList æ–‡æœ¬å—列表
     * @param kid çŸ¥è¯† ID
     * @return å‘量化结果
     */
    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
        logVectorizationRequest(kid, chunkList);  // åœ¨å‘量化开始前记录日志
        openAiStreamClient = chatConfig.getOpenAiStreamClient(); // èŽ·å– OpenAi å®¢æˆ·ç«¯
        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); // æŸ¥è¯¢çŸ¥è¯†ä¿¡æ¯
        // è°ƒç”¨ localModelsofitClient èŽ·å– Top K åµŒå…¥å‘量
        try {
            return localModelsofitClient.getTopKEmbeddings(
                    chunkList,
                    knowledgeInfoVo.getVector(),
                    knowledgeInfoVo.getKnowledgeSeparator(),
                    knowledgeInfoVo.getRetrieveLimit(),
                    knowledgeInfoVo.getTextBlockSize(),
                    knowledgeInfoVo.getOverlapChar()
            );
        } catch (Exception e) {
            log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e);
            throw new RuntimeException("Batch vectorization failed", e);
        }
    }
    /**
     * å•一文本块向量化
     *
     * @param chunk å•一文本块
     * @param kid çŸ¥è¯† ID
     * @return å‘量化结果
     */
    public List<Double> singleVectorization(String chunk, String kid) {
        List<String> chunkList = new ArrayList<>();
        chunkList.add(chunk);
        // è°ƒç”¨æ‰¹é‡å‘量化方法
        List<List<Double>> vectorList = batchVectorization(chunkList, kid);
        if (vectorList.isEmpty()) {
            log.warn("Vectorization returned empty list for chunk: {}", chunk);
            return new ArrayList<>();
        }
        return vectorList.get(0); // è¿”回第一个向量
    }
    /**
     * æä¾›æ›´ç®€æ´çš„æ—¥å¿—记录方法
     *
     * @param kid çŸ¥è¯† ID
     * @param chunkList æ–‡æœ¬å—列表
     */
    private void logVectorizationRequest(String kid, List<String> chunkList) {
        log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size());
    }
}
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java
@@ -18,6 +18,7 @@
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
@Component
@Slf4j
@@ -27,6 +28,9 @@
    // Service used to look up the knowledge base entry for a knowledge ID.
    @Lazy
    @Resource
    private IKnowledgeInfoService knowledgeInfoService;
    // Local model vectorizer, used by batchVectorization.
    @Lazy
    @Resource
    private LocalModelsVectorization localModelsVectorization;
    // OpenAI client; assigned from chatConfig inside batchVectorization.
    @Getter
    private OpenAiStreamClient openAiStreamClient;
@@ -35,25 +39,63 @@
    @Override
    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
        openAiStreamClient = chatConfig.getOpenAiStreamClient();
        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
        Embedding embedding = Embedding.builder()
            .input(chunkList)
            .model(knowledgeInfoVo.getVectorModel())
            .build();
        EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
        List<List<Double>> vectorList = new ArrayList<>();
        embeddings.getData().forEach(data -> {
            List<BigDecimal> vector = data.getEmbedding();
            List<Double> doubleVector = new ArrayList<>();
            for (BigDecimal bd : vector) {
                doubleVector.add(bd.doubleValue());
            }
            vectorList.add(doubleVector);
        });
        // èŽ·å–çŸ¥è¯†åº“ä¿¡æ¯
        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
        // å¦‚果使用本地模型
        try {
            return localModelsVectorization.batchVectorization(chunkList, kid);
        } catch (Exception e) {
            log.error("Local models vectorization failed, falling back to OpenAI embeddings", e);
        }
        // å¦‚果本地模型失败,则调用 OpenAI æœåŠ¡è¿›è¡Œå‘é‡åŒ–
        Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
        EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
        // å¤„理 OpenAI è¿”回的嵌入数据
        vectorList = processOpenAiEmbeddings(embeddings);
        return vectorList;
    }
    /**
     * æž„建 Embedding å¯¹è±¡
     */
    private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
        return Embedding.builder()
                .input(chunkList)
                .model(knowledgeInfoVo.getVectorModel())
                .build();
    }
    /**
     * å¤„理 OpenAI è¿”回的嵌入数据
     */
    private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
        List<List<Double>> vectorList = new ArrayList<>();
        embeddings.getData().forEach(data -> {
            List<BigDecimal> vector = data.getEmbedding();
            List<Double> doubleVector = convertToDoubleList(vector);
            vectorList.add(doubleVector);
        });
        return vectorList;
    }
    /**
     * å°† BigDecimal è½¬æ¢ä¸º Double åˆ—表
     */
    private List<Double> convertToDoubleList(List<BigDecimal> vector) {
        return vector.stream()
                .map(BigDecimal::doubleValue)
                .collect(Collectors.toList());
    }
    @Override
    public List<Double> singleVectorization(String chunk, String kid) {
        List<String> chunkList = new ArrayList<>();
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java
对比新文件
@@ -0,0 +1,15 @@
package org.ruoyi.knowledge.chain.vectorizer;
/**
 * Available vectorization back ends.
 */
public enum VectorizationType {
    /** OpenAI vectorization. */
    OPENAI,
    /** Local model vectorization. */
    LOCAL;

    /**
     * Resolves a constant from its name, ignoring case.
     *
     * @param type the constant name to look up; may be {@code null}
     * @return the matching constant
     * @throws IllegalArgumentException when no constant matches {@code type}
     */
    public static VectorizationType fromString(String type) {
        return java.util.Arrays.stream(values())
                .filter(candidate -> candidate.name().equalsIgnoreCase(type))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Unknown VectorizationType: " + type));
    }
}