package org.ruoyi.common.chat.localModels; import io.micrometer.common.util.StringUtils; import lombok.extern.slf4j.Slf4j; import okhttp3.OkHttpClient; import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest; import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse; import org.springframework.stereotype.Service; import retrofit2.Call; import retrofit2.Callback; import retrofit2.Response; import retrofit2.Retrofit; import retrofit2.converter.jackson.JacksonConverterFactory; import java.util.List; import java.util.concurrent.CountDownLatch; @Slf4j @Service public class LocalModelsofitClient { private static final String BASE_URL = "http://127.0.0.1:5000"; // Flask 服务的 URL private static Retrofit retrofit = null; // 获取 Retrofit 实例 public static Retrofit getRetrofitInstance() { if (retrofit == null) { OkHttpClient client = new OkHttpClient.Builder() .build(); retrofit = new Retrofit.Builder() .baseUrl(BASE_URL) .client(client) .addConverterFactory(JacksonConverterFactory.create()) // 使用 Jackson 处理 JSON 转换 .build(); } return retrofit; } /** * 向 Flask 服务发送文本向量化请求 * * @param queries 查询文本列表 * @param modelName 模型名称 * @param delimiter 文本分隔符 * @param topK 返回的结果数 * @param blockSize 文本块大小 * @param overlapChars 重叠字符数 * @return 返回计算得到的 Top K 嵌入向量列表 */ public static List> getTopKEmbeddings( List queries, String modelName, String delimiter, int topK, int blockSize, int overlapChars) { modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // 默认模型名称 delimiter = (!StringUtils.isEmpty(delimiter) ) ? delimiter : "."; // 默认分隔符 topK = (topK > 0) ? topK : 3; // 默认返回 3 个结果 blockSize = (blockSize > 0) ? blockSize : 500; // 默认文本块大小为 500 overlapChars = (overlapChars > 0) ? overlapChars : 50; // 默认重叠字符数为 50 // 创建 Retrofit 实例 Retrofit retrofit = getRetrofitInstance(); // 创建 SearchService 接口 SearchService service = retrofit.create(SearchService.class); // 创建请求对象 LocalModelsSearchRequest LocalModelsSearchRequest request = new LocalModelsSearchRequest( queries, // 查询文本列表 modelName, // 模型名称 delimiter, // 文本分隔符 topK, // 返回的结果数 blockSize, // 文本块大小 overlapChars // 重叠字符数 ); final CountDownLatch latch = new CountDownLatch(1); // 创建一个 CountDownLatch final List>[] topKEmbeddings = new List[]{null}; // 使用数组来存储结果(因为 Java 不支持直接修改 List) // 发起异步请求 service.vectorize(request).enqueue(new Callback() { @Override public void onResponse(Call call, Response response) { if (response.isSuccessful()) { LocalModelsSearchResponse searchResponse = response.body(); if (searchResponse != null) { topKEmbeddings[0] = searchResponse.getTopKEmbeddings().get(0); // 获取结果 log.info("Successfully retrieved embeddings"); } else { log.error("Response body is null"); } } else { log.error("Request failed. HTTP error code: " + response.code()); } latch.countDown(); // 请求完成,减少计数 } @Override public void onFailure(Call call, Throwable t) { t.printStackTrace(); log.error("Request failed: ", t); latch.countDown(); // 请求失败,减少计数 } }); try { latch.await(); // 等待请求完成 } catch (InterruptedException e) { e.printStackTrace(); } return topKEmbeddings[0]; // 返回结果 } // public static void main(String[] args) { // // 示例调用 // List queries = Arrays.asList("What is artificial intelligence?", "AI is transforming industries."); // String modelName = "msmarco-distilbert-base-tas-b"; // String delimiter = "."; // int topK = 3; // int blockSize = 500; // int overlapChars = 50; // // List> topKEmbeddings = getTopKEmbeddings(queries, modelName, delimiter, topK, blockSize, overlapChars); // // // 打印结果 // if (topKEmbeddings != null) { // System.out.println("Top K embeddings: "); // for (List embedding : topKEmbeddings) { // System.out.println(embedding); // } // } else { // System.out.println("No embeddings returned."); // } // } // public static void main(String[] args) { // // 创建 Retrofit 实例 // Retrofit retrofit = LocalModelsofitClient.getRetrofitInstance(); // // // 创建 SearchService 接口 // SearchService service = retrofit.create(SearchService.class); // // // 创建请求对象 LocalModelsSearchRequest // LocalModelsSearchRequest request = new LocalModelsSearchRequest( // Arrays.asList("What is artificial intelligence?", "AI is transforming industries."), // 查询文本列表 // "msmarco-distilbert-base-tas-b", // 模型名称 // ".", // 分隔符 // 3, // 返回的结果数 // 500, // 文本块大小 // 50 // 重叠字符数 // ); // // // 发起请求 // service.vectorize(request).enqueue(new Callback() { // @Override // public void onResponse(Call call, Response response) { // if (response.isSuccessful()) { // LocalModelsSearchResponse searchResponse = response.body(); // System.out.println("Response Body: " + response.body()); // Print the whole response body for debugging // // if (searchResponse != null) { // // If the response is not null, process it. // // Example: Extract the embeddings and print them // List>> topKEmbeddings = searchResponse.getTopKEmbeddings(); // if (topKEmbeddings != null) { // // Print the Top K embeddings // // } else { // System.err.println("Top K embeddings are null"); // } // // // If there is more information you want to process, handle it here // // } else { // System.err.println("Response body is null"); // } // } else { // System.err.println("Request failed. HTTP error code: " + response.code()); // log.error("Failed to retrieve data. HTTP error code: " + response.code()); // } // } // // @Override // public void onFailure(Call call, Throwable t) { // // 请求失败,打印错误 // t.printStackTrace(); // log.error("Request failed: ", t); // } // }); // } }