对比新文件 |
| | |
| | | from flask import Flask, request, jsonify |
| | | from sentence_transformers import SentenceTransformer |
| | | from sklearn.metrics.pairwise import cosine_similarity |
| | | import json |
| | | |
# Flask application object; the /vectorize route in this file is registered on it.
app = Flask(__name__)

# Global model cache: maps a model name to its loaded model object so that a
# model requested more than once is not loaded again.
model_cache = {}
| | | |
# Split text into size-bounded blocks, plus overlapping variants
def split_text(text, block_size, overlap_chars, delimiter):
    """Split *text* into blocks of at most roughly *block_size* characters.

    The text is cut on *delimiter* and consecutive pieces are re-joined with a
    single space for as long as the joined block stays within *block_size*.
    A single piece that is longer than *block_size* is kept whole as its own
    block (pieces are never cut in the middle).

    When *overlap_chars* is positive, every block after the first is preceded
    in the result by an "overlap" variant: the last *overlap_chars* characters
    of the previous block followed by the block itself. The plain block is
    emitted as well, right after its overlap variant.

    Args:
        text: The string to split.
        block_size: Maximum length used when merging pieces into one block.
        overlap_chars: Number of trailing characters of the previous block to
            prepend to the overlap variant; values <= 0 disable overlap.
        delimiter: Separator passed to ``str.split``.

    Returns:
        A list of non-empty strings (empty list for text without content).

    Raises:
        ValueError: If *delimiter* is an empty string (raised by ``str.split``).
    """
    text_blocks = []
    current_block = ""

    for chunk in text.split(delimiter):
        if len(current_block) + len(chunk) + 1 <= block_size:
            current_block = f"{current_block} {chunk}" if current_block else chunk
        else:
            # Flush only a non-empty block; otherwise an oversized first
            # chunk would put an empty string into the result.
            if current_block:
                text_blocks.append(current_block)
            current_block = chunk
    if current_block:
        text_blocks.append(current_block)

    # s[-0:] is the whole string, not an empty tail, so a non-positive overlap
    # has to be handled explicitly: no overlap means no overlap variants.
    if overlap_chars <= 0:
        return text_blocks

    overlap_blocks = text_blocks[:1]
    for previous, block in zip(text_blocks, text_blocks[1:]):
        overlap_blocks.append(previous[-overlap_chars:] + block)
        overlap_blocks.append(block)

    return overlap_blocks
| | | |
# Text vectorization
def vectorize_text_blocks(text_blocks, model):
    """Encode *text_blocks* with *model* and return the result.

    This is a thin wrapper around ``model.encode``: the blocks are passed in
    as given and whatever ``encode`` returns is handed back unchanged.
    """
    embeddings = model.encode(text_blocks)
    return embeddings
| | | |
# Text retrieval
def retrieve_top_k(query, knowledge_base, k, block_size, overlap_chars, delimiter, model):
    """Return the *k* knowledge-base blocks most similar to *query*.

    The knowledge base is split with ``split_text``, every block and the query
    are encoded with *model*, and blocks are ranked by cosine similarity to
    the query, best match first.

    Args:
        query: Query text to encode and compare against the blocks.
        knowledge_base: Text that is split into candidate blocks.
        k: Maximum number of blocks to return; values <= 0 yield no results.
        block_size, overlap_chars, delimiter: Forwarded to ``split_text``.
        model: Object whose ``encode`` method turns a list of strings into
            vectors (``.reshape`` is called on the encoded query, so an
            array-like return value is assumed).

    Returns:
        A ``(texts, embeddings)`` tuple of two equally long lists, ordered
        from most to least similar. Both are empty when *k* <= 0 or the
        knowledge base produces no blocks.
    """
    # A slice of [-0:] selects everything, so k == 0 would otherwise return
    # every block instead of none (and a negative k an arbitrary subset).
    if k <= 0:
        return [], []

    # Split the knowledge base into text blocks
    text_blocks = split_text(knowledge_base, block_size, overlap_chars, delimiter)
    if not text_blocks:
        # Nothing to rank against
        return [], []

    # Vectorize the text blocks
    knowledge_vectors = vectorize_text_blocks(text_blocks, model)
    # Vectorize the query text as a single-row matrix
    query_vector = model.encode([query]).reshape(1, -1)
    # Compute similarity of the query against every block
    similarities = cosine_similarity(query_vector, knowledge_vectors)
    # Indices of the k most similar blocks, highest similarity first
    top_k_indices = similarities[0].argsort()[-k:][::-1]

    # Return the text blocks together with their vectors
    top_k_texts = [text_blocks[i] for i in top_k_indices]
    top_k_embeddings = [knowledge_vectors[i] for i in top_k_indices]

    return top_k_texts, top_k_embeddings
| | | |
@app.route('/vectorize', methods=['POST'])
def vectorize_text():
    """Handle ``POST /vectorize``: rank blocks of the first text per query.

    Expected JSON object fields:
        text: A string or a list of strings (required). The first entry is
            used as the knowledge base; every entry, including the first,
            is used as a query against it.
        model_name: Name passed to ``SentenceTransformer``
            (default ``"msmarco-distilbert-base-tas-b"``).
        delimiter: Separator used to split the knowledge base (default newline).
        k: Number of blocks to retrieve per query (default 3).
        block_size: Target block size in characters (default 500).
        overlap_chars: Overlap between neighbouring blocks (default 50).

    Returns:
        JSON ``{"topKEmbeddings": [...]}`` with one list of embedding vectors
        per query, or ``{"error": ...}`` with status 400 for invalid input
        and 500 when the model cannot be loaded.
    """
    # silent=True yields None instead of raising on a missing or malformed body
    data = request.get_json(silent=True)
    print(f"Received request data: {data}")  # debug output of the request data

    # Valid JSON can still be a list, number or null; .get() needs an object
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    text_list = data.get("text", [])
    # A bare string would otherwise be iterated character by character
    if isinstance(text_list, str):
        text_list = [text_list]
    if not text_list:
        return jsonify({"error": "Text is required."}), 400
    if not isinstance(text_list, list) or not all(isinstance(t, str) for t in text_list):
        return jsonify({"error": "Text must be a string or a list of strings."}), 400

    # NOTE(review): the model name comes straight from the client; consider
    # restricting it to an allow-list of known models.
    model_name = data.get("model_name", "msmarco-distilbert-base-tas-b")  # default model
    if not isinstance(model_name, str) or not model_name:
        return jsonify({"error": "model_name must be a non-empty string."}), 400

    delimiter = data.get("delimiter", "\n")  # default delimiter
    # str.split raises on an empty separator
    if not isinstance(delimiter, str) or not delimiter:
        return jsonify({"error": "delimiter must be a non-empty string."}), 400

    try:
        k = int(data.get("k", 3))  # default number of results
        block_size = int(data.get("block_size", 500))  # default block size
        overlap_chars = int(data.get("overlap_chars", 50))  # default overlap
    except (TypeError, ValueError):
        return jsonify({"error": "k, block_size and overlap_chars must be integers."}), 400

    # Load the model only if it is not cached yet
    model = model_cache.get(model_name)
    if model is None:
        try:
            model = SentenceTransformer(model_name)
        except Exception as e:
            return jsonify({"error": f"Failed to load model: {e}"}), 500
        model_cache[model_name] = model  # cache the model

    top_k_texts_all = []
    top_k_embeddings_all = []

    # One or many queries are handled the same way: each is matched against
    # the blocks of the first text.
    knowledge_base = text_list[0]
    for query in text_list:
        top_k_texts, top_k_embeddings = retrieve_top_k(
            query, knowledge_base, k, block_size, overlap_chars, delimiter, model
        )
        top_k_texts_all.append(top_k_texts)
        top_k_embeddings_all.append(top_k_embeddings)

    # Convert the embedding vectors (ndarray) into JSON-serializable lists
    top_k_embeddings_all = [
        [embedding.tolist() for embedding in embeddings]
        for embeddings in top_k_embeddings_all
    ]

    print(f"Top K texts: {top_k_texts_all}")  # print the retrieved texts
    print(f"Top K embeddings: {top_k_embeddings_all}")  # print the retrieved vectors

    # Return the data as JSON
    return jsonify({
        "topKEmbeddings": top_k_embeddings_all  # the embedding vectors
    })
| | | |
if __name__ == '__main__':
    import os

    # The server listens on all interfaces, and the Werkzeug debugger lets
    # anyone who can reach it execute arbitrary code. Debug mode is therefore
    # off unless explicitly enabled for local development via FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(host="0.0.0.0", port=5000, debug=debug_enabled)