办学质量监测教学评价系统
ageerle
2025-03-17 a63cc487894bb6517b706655dbb2adc1f6f0e002
Merge pull request #16 from h794629435/main

Local vectorization
1 file modified
9 files added
584 lines changed
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java (38 lines)
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java (20 lines)
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java (198 lines)
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java (25 lines)
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java (92 lines)
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java (56 lines)
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java (15 lines)
script/docker/localModels/Dockerfile (21 lines)
script/docker/localModels/app.py (116 lines)
script/docker/localModels/requirements.txt (3 lines)
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java
New file
@@ -0,0 +1,38 @@
package org.ruoyi.common.chat.entity.models;
import lombok.Data;
import java.util.List;
/**
 * @program: RUOYIAI
 * @ClassName LocalModelsSearchRequest
 * @description:
 * @author: hejh
 * @create: 2025-03-15 17:22
 * @Version 1.0
 **/
@Data
public class LocalModelsSearchRequest {
    private List<String> text;
    private String model_name;
    private String delimiter;
    private int k;
    private int block_size;
    private int overlap_chars;
    // Constructor (getters and setters are generated by Lombok @Data)
    public LocalModelsSearchRequest(List<String> text, String model_name, String delimiter, int k, int block_size, int overlap_chars) {
        this.text = text;
        this.model_name = model_name;
        this.delimiter = delimiter;
        this.k = k;
        this.block_size = block_size;
        this.overlap_chars = overlap_chars;
    }
}
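For reference, the snake_case field names serialize one-to-one onto the JSON keys the Flask /vectorize endpoint reads (text, model_name, delimiter, k, block_size, overlap_chars). A minimal sketch of the payload, not part of this commit (the demo class and literal values are illustrative only):

import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;

class LocalModelsRequestJsonDemo {
    public static void main(String[] args) throws Exception {
        // Build a sample request using the same defaults the Java client falls back to
        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
                List.of("What is artificial intelligence?"),
                "msmarco-distilbert-base-tas-b", ".", 3, 500, 50);
        // Serialize it the way Retrofit's Jackson converter would; key order may vary:
        // {"text":["What is artificial intelligence?"],"model_name":"msmarco-distilbert-base-tas-b",
        //  "delimiter":".","k":3,"block_size":500,"overlap_chars":50}
        System.out.println(new ObjectMapper().writeValueAsString(request));
    }
}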
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java
New file
@@ -0,0 +1,20 @@
package org.ruoyi.common.chat.entity.models;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List;
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class LocalModelsSearchResponse {
    @JsonProperty("topKEmbeddings")
    private List<List<List<Double>>> topKEmbeddings;  // Triple-nested array: [query index][top-k rank][embedding dimension]
    // Default no-args constructor (required by Jackson)
    public LocalModelsSearchResponse() {}
}
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java
New file
@@ -0,0 +1,198 @@
package org.ruoyi.common.chat.localModels;
import io.micrometer.common.util.StringUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
import org.springframework.stereotype.Service;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;
import retrofit2.Retrofit;
import retrofit2.converter.jackson.JacksonConverterFactory;
import java.util.List;
import java.util.concurrent.CountDownLatch;
@Slf4j
@Service
public class LocalModelsofitClient {
    private static final String BASE_URL = "http://127.0.0.1:5000"; // URL of the Flask vectorization service
    private static Retrofit retrofit = null;
    // Obtain (and lazily create) the Retrofit instance
    public static Retrofit getRetrofitInstance() {
        if (retrofit == null) {
            OkHttpClient client = new OkHttpClient.Builder()
                    .build();
            retrofit = new Retrofit.Builder()
                    .baseUrl(BASE_URL)
                    .client(client)
                    .addConverterFactory(JacksonConverterFactory.create()) // Use Jackson for JSON (de)serialization
                    .build();
        }
        return retrofit;
    }
    /**
     * Send a text vectorization request to the Flask service.
     *
     * @param queries list of query texts
     * @param modelName model name
     * @param delimiter text delimiter
     * @param topK number of results to return
     * @param blockSize text block size
     * @param overlapChars number of overlapping characters
     * @return the computed top-K embedding vectors
     */
    public static List<List<Double>> getTopKEmbeddings(
            List<String> queries,
            String modelName,
            String delimiter,
            int topK,
            int blockSize,
            int overlapChars) {
        modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // default model name
        delimiter = (!StringUtils.isEmpty(delimiter)) ? delimiter : ".";                             // default delimiter
        topK = (topK > 0) ? topK : 3;                                                  // default: return 3 results
        blockSize = (blockSize > 0) ? blockSize : 500;                                 // default text block size: 500
        overlapChars = (overlapChars > 0) ? overlapChars : 50;                         // default overlap: 50 characters
        // Create the Retrofit instance
        Retrofit retrofit = getRetrofitInstance();
        // Create the SearchService interface
        SearchService service = retrofit.create(SearchService.class);
        // Build the LocalModelsSearchRequest
        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
                queries,            // list of query texts
                modelName,          // model name
                delimiter,          // text delimiter
                topK,               // number of results to return
                blockSize,          // text block size
                overlapChars        // number of overlapping characters
        );
        final CountDownLatch latch = new CountDownLatch(1);  // used to block until the async call completes
        final List<List<Double>>[] topKEmbeddings = new List[]{null}; // single-element array so the callback can write the result (captured locals must be effectively final)
        // Fire the asynchronous request
        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
            @Override
            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
                if (response.isSuccessful()) {
                    LocalModelsSearchResponse searchResponse = response.body();
                    if (searchResponse != null) {
                        topKEmbeddings[0] = searchResponse.getTopKEmbeddings().get(0);  // take the result for the first query
                        log.info("Successfully retrieved embeddings");
                    } else {
                        log.error("Response body is null");
                    }
                } else {
                    log.error("Request failed. HTTP error code: " + response.code());
                }
                latch.countDown();  // request completed, release the latch
            }
            @Override
            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
                t.printStackTrace();
                log.error("Request failed: ", t);
                latch.countDown();  // request failed, release the latch
            }
        });
        try {
            latch.await();  // wait for the request to finish
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        return topKEmbeddings[0];  // return the result
    }
//    public static void main(String[] args) {
//        // Example call
//        List<String> queries = Arrays.asList("What is artificial intelligence?", "AI is transforming industries.");
//        String modelName = "msmarco-distilbert-base-tas-b";
//        String delimiter = ".";
//        int topK = 3;
//        int blockSize = 500;
//        int overlapChars = 50;
//
//        List<List<Double>> topKEmbeddings = getTopKEmbeddings(queries, modelName, delimiter, topK, blockSize, overlapChars);
//
//        // Print the results
//        if (topKEmbeddings != null) {
//            System.out.println("Top K embeddings: ");
//            for (List<Double> embedding : topKEmbeddings) {
//                System.out.println(embedding);
//            }
//        } else {
//            System.out.println("No embeddings returned.");
//        }
//    }
//    public static void main(String[] args) {
//        // Create the Retrofit instance
//        Retrofit retrofit = LocalModelsofitClient.getRetrofitInstance();
//
//        // Create the SearchService interface
//        SearchService service = retrofit.create(SearchService.class);
//
//        // Build the LocalModelsSearchRequest
//        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
//                Arrays.asList("What is artificial intelligence?", "AI is transforming industries."), // list of query texts
//                "msmarco-distilbert-base-tas-b",  // model name
//                ".",  // delimiter
//                3,  // number of results to return
//                500,  // text block size
//                50  // number of overlapping characters
//        );
//
//        // å‘起请求
//        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
//            @Override
//            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
//                if (response.isSuccessful()) {
//                    LocalModelsSearchResponse searchResponse = response.body();
//                    System.out.println("Response Body: " + response.body());  // Print the whole response body for debugging
//
//                    if (searchResponse != null) {
//                        // If the response is not null, process it.
//                        // Example: Extract the embeddings and print them
//                        List<List<List<Double>>> topKEmbeddings = searchResponse.getTopKEmbeddings();
//                        if (topKEmbeddings != null) {
//                            // Print the Top K embeddings
//
//                        } else {
//                            System.err.println("Top K embeddings are null");
//                        }
//
//                        // If there is more information you want to process, handle it here
//
//                    } else {
//                        System.err.println("Response body is null");
//                    }
//                } else {
//                    System.err.println("Request failed. HTTP error code: " + response.code());
//                    log.error("Failed to retrieve data. HTTP error code: " + response.code());
//                }
//            }
//
//            @Override
//            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
//                // Request failed; print the error
//                t.printStackTrace();
//                log.error("Request failed: ", t);
//            }
//        });
//    }
}
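Since getTopKEmbeddings blocks on the CountDownLatch until the callback fires, the same round trip could also be written with Retrofit's synchronous execute(). A minimal alternative sketch, not part of this commit (the method name and its placement inside this class are assumptions):

    // Illustrative synchronous variant of the same request
    private static List<List<Double>> getTopKEmbeddingsSync(LocalModelsSearchRequest request) {
        SearchService service = getRetrofitInstance().create(SearchService.class);
        try {
            Response<LocalModelsSearchResponse> response = service.vectorize(request).execute();
            if (response.isSuccessful() && response.body() != null) {
                // Take the top-K embeddings for the first query, as the async version does
                return response.body().getTopKEmbeddings().get(0);
            }
            log.error("Request failed. HTTP error code: {}", response.code());
        } catch (java.io.IOException e) {
            log.error("Request failed: ", e);
        }
        return null;
    }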
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java
New file
@@ -0,0 +1,25 @@
package org.ruoyi.common.chat.localModels;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.POST;
/**
 * @program: RUOYIAI
 * @ClassName SearchService
 * @description: Request interface for the Flask vectorization service
 * @author: hejh
 * @create: 2025-03-15 17:27
 * @Version 1.0
 **/
public interface SearchService {
    @POST("/vectorize") // ä¸Ž Flask æœåŠ¡ä¸­çš„è·¯ç”±åŒ¹é…
    Call<LocalModelsSearchResponse> vectorize(@Body LocalModelsSearchRequest request);
}
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java
New file
@@ -0,0 +1,92 @@
package org.ruoyi.knowledge.chain.vectorizer;
import jakarta.annotation.Resource;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.chat.config.ChatConfig;
import org.ruoyi.common.chat.localModels.LocalModelsofitClient;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
import org.ruoyi.knowledge.service.IKnowledgeInfoService;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
@Component
@Slf4j
@RequiredArgsConstructor
public class LocalModelsVectorization   {
    @Resource
    private IKnowledgeInfoService knowledgeInfoService;
    @Resource
    private LocalModelsofitClient localModelsofitClient;
    @Getter
    private OpenAiStreamClient openAiStreamClient;
    private final ChatConfig chatConfig;
    /**
     * Batch vectorization.
     *
     * @param chunkList list of text chunks
     * @param kid knowledge base ID
     * @return vectorization result
     */
    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
        logVectorizationRequest(kid, chunkList);  // log before vectorization starts
        openAiStreamClient = chatConfig.getOpenAiStreamClient(); // obtain the OpenAI client
        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); // look up the knowledge base configuration
        // Call localModelsofitClient to fetch the top-K embeddings
        try {
            return localModelsofitClient.getTopKEmbeddings(
                    chunkList,
                    knowledgeInfoVo.getVector(),
                    knowledgeInfoVo.getKnowledgeSeparator(),
                    knowledgeInfoVo.getRetrieveLimit(),
                    knowledgeInfoVo.getTextBlockSize(),
                    knowledgeInfoVo.getOverlapChar()
            );
        } catch (Exception e) {
            log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e);
            throw new RuntimeException("Batch vectorization failed", e);
        }
    }
    /**
     * å•一文本块向量化
     *
     * @param chunk å•一文本块
     * @param kid çŸ¥è¯† ID
     * @return å‘量化结果
     */
    public List<Double> singleVectorization(String chunk, String kid) {
        List<String> chunkList = new ArrayList<>();
        chunkList.add(chunk);
        // Delegate to the batch vectorization method
        List<List<Double>> vectorList = batchVectorization(chunkList, kid);
        if (vectorList.isEmpty()) {
            log.warn("Vectorization returned empty list for chunk: {}", chunk);
            return new ArrayList<>();
        }
        return vectorList.get(0); // return the first vector
    }
    /**
     * æä¾›æ›´ç®€æ´çš„æ—¥å¿—记录方法
     *
     * @param kid çŸ¥è¯† ID
     * @param chunkList æ–‡æœ¬å—列表
     */
    private void logVectorizationRequest(String kid, List<String> chunkList) {
        log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size());
    }
}
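singleVectorization simply wraps batchVectorization with a one-element list, so callers can obtain one embedding directly. A hedged usage sketch, not part of this commit (the demo component and the knowledge id are hypothetical):

import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.List;

@Component
@Slf4j
class LocalVectorizationDemo {
    @Resource
    private LocalModelsVectorization localModelsVectorization;

    public void logSampleEmbedding() {
        // "1001" is a hypothetical knowledge base id; the text is an arbitrary sample
        List<Double> vector = localModelsVectorization.singleVectorization("教学质量评估指标", "1001");
        log.info("local embedding size: {}", vector.size());
    }
}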
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java
@@ -18,6 +18,7 @@
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
@Component
@Slf4j
@@ -27,6 +28,9 @@
    @Lazy
    @Resource
    private IKnowledgeInfoService knowledgeInfoService;
    @Lazy
    @Resource
    private LocalModelsVectorization localModelsVectorization;
    @Getter
    private OpenAiStreamClient openAiStreamClient;
@@ -35,25 +39,63 @@
    @Override
    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
        openAiStreamClient = chatConfig.getOpenAiStreamClient();
        List<List<Double>> vectorList = new ArrayList<>();
        // Look up the knowledge base configuration
        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
        // Try the local model first
        try {
            return localModelsVectorization.batchVectorization(chunkList, kid);
        } catch (Exception e) {
            log.error("Local models vectorization failed, falling back to OpenAI embeddings", e);
        }
        // If the local model fails, fall back to the OpenAI embeddings service
        Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
        EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
        // Process the embeddings returned by OpenAI
        vectorList = processOpenAiEmbeddings(embeddings);
        return vectorList;
    }
    /**
     * Build the Embedding request object
     */
    private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
        return Embedding.builder()
            .input(chunkList)
            .model(knowledgeInfoVo.getVectorModel())
            .build();
    }
    /**
     * Process the embeddings returned by OpenAI
     */
    private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
        List<List<Double>> vectorList = new ArrayList<>();
        embeddings.getData().forEach(data -> {
            List<BigDecimal> vector = data.getEmbedding();
            List<Double> doubleVector = convertToDoubleList(vector);
            vectorList.add(doubleVector);
        });
        return vectorList;
    }
    /**
     * Convert a list of BigDecimal to a list of Double
     */
    private List<Double> convertToDoubleList(List<BigDecimal> vector) {
        return vector.stream()
                .map(BigDecimal::doubleValue)
                .collect(Collectors.toList());
    }
    @Override
    public List<Double> singleVectorization(String chunk, String kid) {
        List<String> chunkList = new ArrayList<>();
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java
New file
@@ -0,0 +1,15 @@
package org.ruoyi.knowledge.chain.vectorizer;
public enum VectorizationType {
    OPENAI,    // OpenAI vectorization
    LOCAL;     // local model vectorization
    public static VectorizationType fromString(String type) {
        for (VectorizationType v : values()) {
            if (v.name().equalsIgnoreCase(type)) {
                return v;
            }
        }
        throw new IllegalArgumentException("Unknown VectorizationType: " + type);
    }
}
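Nothing in this commit consumes the enum yet; a hypothetical dispatch sketch showing how it could select between the two vectorizers, assuming a component that has both localModelsVectorization and openAiVectorization injected (the method and those fields are assumptions, not part of the commit):

    // Hypothetical selection logic based on a configured type string such as "local" or "openai"
    public List<List<Double>> vectorize(String type, List<String> chunkList, String kid) {
        return switch (VectorizationType.fromString(type)) {
            case LOCAL -> localModelsVectorization.batchVectorization(chunkList, kid);
            case OPENAI -> openAiVectorization.batchVectorization(chunkList, kid);
        };
    }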
script/docker/localModels/Dockerfile
New file
@@ -0,0 +1,21 @@
# Use the official Python image as the base
FROM python:3.8-slim
# Set /app as the working directory
WORKDIR /app
# Copy everything in the current directory into /app inside the container
COPY . /app
# Install the application dependencies
RUN pip install --no-cache-dir -r requirements.txt
# Expose the port used by the Flask app
EXPOSE 5000
# Set environment variables
ENV FLASK_APP=app.py
ENV FLASK_RUN_HOST=0.0.0.0
# Start the Flask app
CMD ["flask", "run", "--host=0.0.0.0"]
script/docker/localModels/app.py
New file
@@ -0,0 +1,116 @@
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import json
app = Flask(__name__)
# Global cache of loaded SentenceTransformer models
model_cache = {}
# Split the text into blocks of at most block_size characters; for every block after
# the first, an extra variant prefixed with the last overlap_chars characters of the
# previous block is also emitted, so overlapped and plain blocks are interleaved
def split_text(text, block_size, overlap_chars, delimiter):
    chunks = text.split(delimiter)
    text_blocks = []
    current_block = ""
    for chunk in chunks:
        if len(current_block) + len(chunk) + 1 <= block_size:
            if current_block:
                current_block += " " + chunk
            else:
                current_block = chunk
        else:
            text_blocks.append(current_block)
            current_block = chunk
    if current_block:
        text_blocks.append(current_block)
    overlap_blocks = []
    for i in range(len(text_blocks)):
        if i > 0:
            overlap_block = text_blocks[i - 1][-overlap_chars:] + text_blocks[i]
            overlap_blocks.append(overlap_block)
        overlap_blocks.append(text_blocks[i])
    return overlap_blocks
# Vectorize the text blocks with the given model
def vectorize_text_blocks(text_blocks, model):
    return model.encode(text_blocks)
# Retrieve the top-k most similar text blocks for a query
def retrieve_top_k(query, knowledge_base, k, block_size, overlap_chars, delimiter, model):
    # Split the knowledge base into text blocks
    text_blocks = split_text(knowledge_base, block_size, overlap_chars, delimiter)
    # Vectorize the text blocks
    knowledge_vectors = vectorize_text_blocks(text_blocks, model)
    # Vectorize the query text
    query_vector = model.encode([query]).reshape(1, -1)
    # Compute cosine similarity between the query and every block
    similarities = cosine_similarity(query_vector, knowledge_vectors)
    # Indices of the k most similar text blocks
    top_k_indices = similarities[0].argsort()[-k:][::-1]
    # Return the text blocks and their vectors
    top_k_texts = [text_blocks[i] for i in top_k_indices]
    top_k_embeddings = [knowledge_vectors[i] for i in top_k_indices]
    return top_k_texts, top_k_embeddings
@app.route('/vectorize', methods=['POST'])
def vectorize_text():
    # Read the JSON payload from the request
    data = request.json
    print(f"Received request data: {data}")  # debug output of the request data
    text_list = data.get("text", [])
    model_name = data.get("model_name", "msmarco-distilbert-base-tas-b")  # default model
    delimiter = data.get("delimiter", "\n")  # default delimiter
    k = int(data.get("k", 3))  # default number of results
    block_size = int(data.get("block_size", 500))  # default text block size
    overlap_chars = int(data.get("overlap_chars", 50))  # default number of overlapping characters
    if not text_list:
        return jsonify({"error": "Text is required."}), 400
    # Load the model if it is not cached yet
    if model_name not in model_cache:
        try:
            model = SentenceTransformer(model_name)
            model_cache[model_name] = model  # cache the model
        except Exception as e:
            return jsonify({"error": f"Failed to load model: {e}"}), 500
    model = model_cache[model_name]
    top_k_texts_all = []
    top_k_embeddings_all = []
    # A single text: use it both as the query and as the knowledge base
    if len(text_list) == 1:
        top_k_texts, top_k_embeddings = retrieve_top_k(text_list[0], text_list[0], k, block_size, overlap_chars, delimiter, model)
        top_k_texts_all.append(top_k_texts)
        top_k_embeddings_all.append(top_k_embeddings)
    elif len(text_list) > 1:
        # Multiple texts: use the first text as the knowledge base and run each text against it as a query
        for query in text_list:
            top_k_texts, top_k_embeddings = retrieve_top_k(query, text_list[0], k, block_size, overlap_chars, delimiter, model)
            top_k_texts_all.append(top_k_texts)
            top_k_embeddings_all.append(top_k_embeddings)
    # Convert the embedding ndarrays into JSON-serializable lists
    top_k_embeddings_all = [[embedding.tolist() for embedding in embeddings] for embeddings in top_k_embeddings_all]
    print(f"Top K texts: {top_k_texts_all}")  # debug: retrieved texts
    print(f"Top K embeddings: {top_k_embeddings_all}")  # debug: retrieved vectors
    # Return the data as JSON
    return jsonify({
        "topKEmbeddings": top_k_embeddings_all  # the embedding vectors
    })
    })
if __name__ == '__main__':
    app.run(host="0.0.0.0", port=5000, debug=True)
script/docker/localModels/requirements.txt
New file
@@ -0,0 +1,3 @@
Flask==2.0.3
sentence-transformers==2.2.0
scikit-learn==0.24.2