From a63cc487894bb6517b706655dbb2adc1f6f0e002 Mon Sep 17 00:00:00 2001
From: ageerle <32251822+ageerle@users.noreply.github.com>
Date: 星期一, 17 三月 2025 09:22:53 +0800
Subject: [PATCH] Merge pull request #16 from h794629435/main

---
 ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java  |   96 +++++++++
 ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java         |   15 +
 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java |   20 ++
 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java       |  128 +++++++++++++
 script/docker/localModels/app.py                                                                                |  117 +++++++++++
 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java  |   38 +++
 ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java       |   73 +++++-
 script/docker/localModels/requirements.txt                                                                      |    3 
 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java               |   25 ++
 script/docker/localModels/Dockerfile                                                                            |   21 ++
 10 files changed, 521 insertions(+), 15 deletions(-)

diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java
new file mode 100644
index 0000000..4ca71ba
--- /dev/null
+++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java
@@ -0,0 +1,38 @@
+package org.ruoyi.common.chat.entity.models;
+
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * @program: RUOYIAI
+ * @ClassName LocalModelsSearchRequest
+ * @description: Request body posted to the local model service's /vectorize endpoint
+ * @author: hejh
+ * @create: 2025-03-15 17:22
+ * @Version 1.0
+ **/
+@Data
+public class LocalModelsSearchRequest {
+    // Field names are snake_case on purpose: they match the JSON keys read by the Flask service (app.py).
+    private List<String> text;
+    private String model_name;
+    private String delimiter;
+    private int k;
+    private int block_size;
+    private int overlap_chars;
+
+    // All-args constructor; getters and setters are generated by Lombok's @Data
+    public LocalModelsSearchRequest(List<String> text, String model_name, String delimiter, int k, int block_size, int overlap_chars) {
+        this.text = text;
+        this.model_name = model_name;
+        this.delimiter = delimiter;
+        this.k = k;
+        this.block_size = block_size;
+        this.overlap_chars = overlap_chars;
+    }
+
+
+}
+
+
diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java
new file mode 100644
index 0000000..12025d5
--- /dev/null
+++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java
@@ -0,0 +1,20 @@
+package org.ruoyi.common.chat.entity.models;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
+
+import java.util.List;
+
+@Data
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class LocalModelsSearchResponse {
+    @JsonProperty("topKEmbeddings")
+
+    private List<List<List<Double>>> topKEmbeddings; // three-level nested array: per query -> per top-K block -> vector components
+
+    // Default constructor
+    public LocalModelsSearchResponse() {}
+
+
+
+}
diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java
new file mode 100644
index 0000000..606a7c2
--- /dev/null
+++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java
@@ -0,0 +1,128 @@
+package org.ruoyi.common.chat.localModels;
+
+import io.micrometer.common.util.StringUtils;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.OkHttpClient;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
+import org.springframework.stereotype.Service;
+import retrofit2.Call;
+import retrofit2.Callback;
+import retrofit2.Response;
+import retrofit2.Retrofit;
+import retrofit2.converter.jackson.JacksonConverterFactory;
+
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
+
+@Slf4j
+@Service
+public class LocalModelsofitClient {
+    private static final String BASE_URL = "http://127.0.0.1:5000"; // URL of the Flask service
+    private static Retrofit retrofit = null;
+
+    // Returns the shared Retrofit instance, creating it on first use
+    public static Retrofit getRetrofitInstance() {
+        if (retrofit == null) {
+            OkHttpClient client = new OkHttpClient.Builder()
+                    .build();
+
+            retrofit = new Retrofit.Builder()
+                    .baseUrl(BASE_URL)
+                    .client(client)
+                    .addConverterFactory(JacksonConverterFactory.create()) // use Jackson for JSON conversion
+                    .build();
+        }
+        return retrofit;
+    }
+
+    /**
+     * Sends a text vectorization request to the Flask service and blocks until it completes.
+     *
+     * @param queries      list of query texts
+     * @param modelName    model name; "msmarco-distilbert-base-tas-b" is used when empty
+     * @param delimiter    text delimiter; "." is used when empty
+     * @param topK         number of results to return; 3 is used when not positive
+     * @param blockSize    text block size; 500 is used when not positive
+     * @param overlapChars number of overlapping characters; 50 is used when not positive
+     * @return the top-K embedding vectors of the first query, or {@code null} when the request fails
+     */
+    public static List<List<Double>> getTopKEmbeddings(
+            List<String> queries,
+            String modelName,
+            String delimiter,
+            int topK,
+            int blockSize,
+            int overlapChars) {
+
+        modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // default model name
+        delimiter = (!StringUtils.isEmpty(delimiter)) ? delimiter : "."; // default delimiter
+        topK = (topK > 0) ? topK : 3; // default: return 3 results
+        blockSize = (blockSize > 0) ? blockSize : 500; // default text block size: 500
+        overlapChars = (overlapChars > 0) ? overlapChars : 50; // default overlap: 50 characters
+
+        // Obtain the Retrofit instance
+        Retrofit retrofit = getRetrofitInstance();
+
+        // Create the SearchService API proxy
+        SearchService service = retrofit.create(SearchService.class);
+
+        // Build the LocalModelsSearchRequest payload
+        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
+                queries,       // list of query texts
+                modelName,     // model name
+                delimiter,     // text delimiter
+                topK,          // number of results to return
+                blockSize,     // text block size
+                overlapChars   // number of overlapping characters
+        );
+
+        final CountDownLatch latch = new CountDownLatch(1); // released once the asynchronous call has finished
+        final AtomicReference<List<List<Double>>> topKEmbeddings = new AtomicReference<>(); // result holder written by the callback
+
+        // Issue the asynchronous request
+        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
+            @Override
+            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
+                // Count down in finally: an exception thrown here must never leave the caller blocked in await().
+                try {
+                    if (!response.isSuccessful()) {
+                        log.error("Request failed. HTTP error code: {}", response.code());
+                        return;
+                    }
+                    LocalModelsSearchResponse searchResponse = response.body();
+                    if (searchResponse == null) {
+                        log.error("Response body is null");
+                        return;
+                    }
+                    List<List<List<Double>>> embeddings = searchResponse.getTopKEmbeddings();
+                    if (embeddings == null || embeddings.isEmpty()) {
+                        log.error("Response contains no embeddings");
+                        return;
+                    }
+                    topKEmbeddings.set(embeddings.get(0)); // keep the embeddings of the first query
+                    log.info("Successfully retrieved embeddings");
+                } finally {
+                    latch.countDown(); // request finished
+                }
+            }
+
+            @Override
+            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
+                log.error("Request failed: ", t);
+                latch.countDown(); // request failed
+            }
+        });
+
+        try {
+            latch.await(); // wait for the request to complete
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt(); // restore the interrupt flag for callers
+            log.error("Interrupted while waiting for the embedding response", e);
+        }
+
+        return topKEmbeddings.get(); // null when the request failed or was interrupted
+    }
+
+}
diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java
new file mode 100644
index 0000000..3fa131e
--- /dev/null
+++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java
@@ -0,0 +1,25 @@
+package org.ruoyi.common.chat.localModels;
+
+
+
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
+import retrofit2.Call;
+import retrofit2.http.Body;
+import retrofit2.http.POST;
+/**
+ * @program: RUOYIAI
+ * @ClassName SearchService
+ * @description: Retrofit API for calling the local model service
+ * @author: hejh
+ * @create: 2025-03-15 17:27
+ * @Version 1.0
+ **/
+
+
+public interface SearchService {
+    @POST("/vectorize") // matches the route exposed by the Flask service
+    Call<LocalModelsSearchResponse> vectorize(@Body LocalModelsSearchRequest request);
+}
+
+
diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java
new file mode 100644
index 0000000..d7dff25
--- /dev/null
+++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java
@@ -0,0 +1,96 @@
+package org.ruoyi.knowledge.chain.vectorizer;
+
+import jakarta.annotation.Resource;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.ruoyi.common.chat.config.ChatConfig;
+import org.ruoyi.common.chat.localModels.LocalModelsofitClient;
+import org.ruoyi.common.chat.openai.OpenAiStreamClient;
+import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
+import org.ruoyi.knowledge.service.IKnowledgeInfoService;
+import org.springframework.stereotype.Component;
+
+import java.util.ArrayList;
+import java.util.List;
+
+@Component
+@Slf4j
+@RequiredArgsConstructor
+public class LocalModelsVectorization {
+    @Resource
+    private IKnowledgeInfoService knowledgeInfoService;
+
+    @Resource
+    private LocalModelsofitClient localModelsofitClient;
+
+    @Getter
+    private OpenAiStreamClient openAiStreamClient;
+
+    private final ChatConfig chatConfig;
+
+    /**
+     * Batch vectorization through the local model service.
+     *
+     * @param chunkList list of text chunks
+     * @param kid       knowledge base ID
+     * @return vectorization result; never null (a failed local call raises a RuntimeException)
+     */
+
+    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
+        logVectorizationRequest(kid, chunkList); // log before vectorization starts
+        openAiStreamClient = chatConfig.getOpenAiStreamClient(); // obtain the OpenAI client
+        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); // load the knowledge base settings
+        // Ask localModelsofitClient for the top-K embedding vectors
+        try {
+            List<List<Double>> embeddings = localModelsofitClient.getTopKEmbeddings(
+                    chunkList,
+                    knowledgeInfoVo.getVector(),
+                    knowledgeInfoVo.getKnowledgeSeparator(),
+                    knowledgeInfoVo.getRetrieveLimit(),
+                    knowledgeInfoVo.getTextBlockSize(),
+                    knowledgeInfoVo.getOverlapChar()
+            );
+            if (embeddings == null) { // the client reports a failed request as null; surface it so callers can fall back
+                throw new IllegalStateException("Local model service returned no embeddings");
+            }
+            return embeddings;
+        } catch (Exception e) {
+            log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e);
+            throw new RuntimeException("Batch vectorization failed", e);
+        }
+    }
+
+    /**
+     * Vectorization of a single text chunk.
+     *
+     * @param chunk the text chunk
+     * @param kid   knowledge base ID
+     * @return vectorization result; an empty list when the service returned no vectors
+     */
+
+    public List<Double> singleVectorization(String chunk, String kid) {
+        List<String> chunkList = new ArrayList<>();
+        chunkList.add(chunk);
+
+        // Delegate to the batch method
+        List<List<Double>> vectorList = batchVectorization(chunkList, kid);
+
+        if (vectorList.isEmpty()) {
+            log.warn("Vectorization returned empty list for chunk: {}", chunk);
+            return new ArrayList<>();
+        }
+
+        return vectorList.get(0); // return the first vector
+    }
+
+    /**
+     * Logs the start of a vectorization request.
+     *
+     * @param kid       knowledge base ID
+     * @param chunkList list of text chunks
+     */
+    private void logVectorizationRequest(String kid, List<String> chunkList) {
+        log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size());
+    }
+}
diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java
index 0f2d0ba..764c2c1 100644
--- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java
+++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java
@@ -18,6 +18,7 @@
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
 
 @Component
 @Slf4j
@@ -27,6 +28,9 @@
     @Lazy
     @Resource
     private IKnowledgeInfoService knowledgeInfoService;
+    @Lazy
+    @Resource
+    private LocalModelsVectorization localModelsVectorization;
 
     @Getter
     private OpenAiStreamClient openAiStreamClient;
@@ -35,25 +39,64 @@
 
     @Override
     public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
-        openAiStreamClient = chatConfig.getOpenAiStreamClient();
-        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
-        Embedding embedding = Embedding.builder()
-                .input(chunkList)
-                .model(knowledgeInfoVo.getVectorModel())
-                .build();
-        EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
         List<List<Double>> vectorList = new ArrayList<>();
-        embeddings.getData().forEach(data -> {
-            List<BigDecimal> vector = data.getEmbedding();
-            List<Double> doubleVector = new ArrayList<>();
-            for (BigDecimal bd : vector) {
-                doubleVector.add(bd.doubleValue());
-            }
-            vectorList.add(doubleVector);
-        });
+
+        // Load the knowledge base settings
+        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
+
+        // Try the local model first. NOTE(review): VectorizationType is not consulted here - confirm this is intended.
+        try {
+            return localModelsVectorization.batchVectorization(chunkList, kid);
+        } catch (Exception e) {
+            log.error("Local models vectorization failed, falling back to OpenAI embeddings", e);
+        }
+
+        // The local model failed: fall back to the OpenAI embeddings service
+        openAiStreamClient = chatConfig.getOpenAiStreamClient(); // restore the client assignment this patch removed
+        Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
+        EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
+
+        // Convert the embedding data returned by OpenAI
+        vectorList = processOpenAiEmbeddings(embeddings);
+
         return vectorList;
     }
 
+    /**
+     * Builds the Embedding request object.
+     */
+    private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
+        return Embedding.builder()
+                .input(chunkList)
+                .model(knowledgeInfoVo.getVectorModel())
+                .build();
+    }
+
+    /**
+     * Converts the embedding data returned by OpenAI into lists of doubles.
+     */
+    private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
+        List<List<Double>> vectorList = new ArrayList<>();
+
+        embeddings.getData().forEach(data -> {
+            List<BigDecimal> vector = data.getEmbedding();
+            List<Double> doubleVector = convertToDoubleList(vector);
+            vectorList.add(doubleVector);
+        });
+
+        return vectorList;
+    }
+
+    /**
+     * Converts a list of BigDecimal values into a list of Double values.
+     */
+    private List<Double> convertToDoubleList(List<BigDecimal> vector) {
+        return vector.stream()
+                .map(BigDecimal::doubleValue)
+                .collect(Collectors.toList());
+    }
+
+
     @Override
     public List<Double> singleVectorization(String chunk, String kid) {
         List<String> chunkList = new ArrayList<>();
diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java
new file mode 100644
index 0000000..a9d370d
--- /dev/null
+++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java
@@ -0,0 +1,15 @@
+package org.ruoyi.knowledge.chain.vectorizer;
+
+public enum VectorizationType {
+    OPENAI, // OpenAI vectorization
+    LOCAL;  // local model vectorization
+
+    public static VectorizationType fromString(String type) {
+        for (VectorizationType v : values()) {
+            if (v.name().equalsIgnoreCase(type)) {
+                return v;
+            }
+        }
+        throw new IllegalArgumentException("Unknown VectorizationType: " + type);
+    }
+}
diff --git a/script/docker/localModels/Dockerfile b/script/docker/localModels/Dockerfile
new file mode 100644
index 0000000..c988920
--- /dev/null
+++ b/script/docker/localModels/Dockerfile
@@ -0,0 +1,21 @@
+# Use the official Python image as the base image
+FROM python:3.8-slim
+
+# Set the working directory to /app
+WORKDIR /app
+
+# Copy everything in the current directory into /app inside the container
+COPY . /app
+
+# Install the application dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Expose the port used by the Flask application
+EXPOSE 5000
+
+# Set environment variables
+ENV FLASK_APP=app.py
+ENV FLASK_RUN_HOST=0.0.0.0
+
+# Start the Flask application
+CMD ["flask", "run", "--host=0.0.0.0"]
diff --git a/script/docker/localModels/app.py b/script/docker/localModels/app.py
new file mode 100644
index 0000000..645a9b4
--- /dev/null
+++ b/script/docker/localModels/app.py
@@ -0,0 +1,117 @@
+from flask import Flask, request, jsonify
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+import json
+
+app = Flask(__name__)
+
+# Global cache of loaded models, keyed by model name
+model_cache = {}
+
+# Split text into blocks
+def split_text(text, block_size, overlap_chars, delimiter):
+    chunks = text.split(delimiter)
+    text_blocks = []
+    current_block = ""
+
+    for chunk in chunks:
+        if len(current_block) + len(chunk) + 1 <= block_size:
+            if current_block:
+                current_block += " " + chunk
+            else:
+                current_block = chunk
+        else:
+            if current_block:  # never emit an empty block (e.g. when an oversized chunk comes first)
+                text_blocks.append(current_block)
+            current_block = chunk
+    if current_block:
+        text_blocks.append(current_block)
+
+    overlap_blocks = []
+    for i in range(len(text_blocks)):
+        if i > 0:
+            overlap_block = text_blocks[i - 1][-overlap_chars:] + text_blocks[i]
+            overlap_blocks.append(overlap_block)
+        overlap_blocks.append(text_blocks[i])
+
+    return overlap_blocks
+
+# Vectorize text blocks
+def vectorize_text_blocks(text_blocks, model):
+    return model.encode(text_blocks)
+
+# Text retrieval
+def retrieve_top_k(query, knowledge_base, k, block_size, overlap_chars, delimiter, model):
+    # Split the knowledge base into text blocks
+    text_blocks = split_text(knowledge_base, block_size, overlap_chars, delimiter)
+    # Vectorize the text blocks
+    knowledge_vectors = vectorize_text_blocks(text_blocks, model)
+    # Vectorize the query text
+    query_vector = model.encode([query]).reshape(1, -1)
+    # Compute similarities
+    similarities = cosine_similarity(query_vector, knowledge_vectors)
+    # Indices of the k most similar text blocks, best first
+    top_k_indices = similarities[0].argsort()[-k:][::-1]
+
+    # Return the text blocks together with their vectors
+    top_k_texts = [text_blocks[i] for i in top_k_indices]
+    top_k_embeddings = [knowledge_vectors[i] for i in top_k_indices]
+
+    return top_k_texts, top_k_embeddings
+
+@app.route('/vectorize', methods=['POST'])
+def vectorize_text():
+    # Read the JSON payload; a missing or invalid body becomes {} and is answered with the 400 below
+    data = request.get_json(silent=True) or {}
+    print(f"Received request data: {data}")  # debug output of the request data
+
+    text_list = data.get("text", [])
+    model_name = data.get("model_name", "msmarco-distilbert-base-tas-b")  # default model
+
+    delimiter = data.get("delimiter", "\n")  # default delimiter
+    k = int(data.get("k", 3))  # default number of results
+    block_size = int(data.get("block_size", 500))  # default text block size
+    overlap_chars = int(data.get("overlap_chars", 50))  # default number of overlapping characters
+
+    if not text_list:
+        return jsonify({"error": "Text is required."}), 400
+
+    # Load the model unless it is already cached
+    if model_name not in model_cache:
+        try:
+            model = SentenceTransformer(model_name)
+            model_cache[model_name] = model  # cache the model
+        except Exception as e:
+            return jsonify({"error": f"Failed to load model: {e}"}), 500
+
+    model = model_cache[model_name]
+
+    top_k_texts_all = []
+    top_k_embeddings_all = []
+
+    # Only one query text
+    if len(text_list) == 1:
+        top_k_texts, top_k_embeddings = retrieve_top_k(text_list[0], text_list[0], k, block_size, overlap_chars, delimiter, model)
+        top_k_texts_all.append(top_k_texts)
+        top_k_embeddings_all.append(top_k_embeddings)
+    elif len(text_list) > 1:
+        # Several query texts: process them in turn. NOTE(review): text_list[0] is always the knowledge base - confirm.
+        for query in text_list:
+            top_k_texts, top_k_embeddings = retrieve_top_k(query, text_list[0], k, block_size, overlap_chars, delimiter, model)
+            top_k_texts_all.append(top_k_texts)
+            top_k_embeddings_all.append(top_k_embeddings)
+
+    # Convert the embedding vectors (ndarray) into JSON-serializable lists
+    top_k_embeddings_all = [[embedding.tolist() for embedding in embeddings] for embeddings in top_k_embeddings_all]
+
+    print(f"Top K texts: {top_k_texts_all}")  # print the retrieved texts
+    print(f"Top K embeddings: {top_k_embeddings_all}")  # print the retrieved vectors
+
+    # Return the data as JSON
+    return jsonify({
+
+        "topKEmbeddings": top_k_embeddings_all  # the embedding vectors
+    })
+
+if __name__ == '__main__':
+    app.run(host="0.0.0.0", port=5000, debug=False)  # debug mode must stay off on a publicly bound host
diff --git a/script/docker/localModels/requirements.txt b/script/docker/localModels/requirements.txt
new file mode 100644
index 0000000..c1e1b50
--- /dev/null
+++ b/script/docker/localModels/requirements.txt
@@ -0,0 +1,3 @@
+Flask==2.0.3
+sentence-transformers==2.2.0
+scikit-learn==0.24.2
--
Gitblit v1.9.3