/*
 * Copyright 2023 - 2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.ai.vectorstore;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.json.Path2;
import redis.clients.jedis.search.*;
import redis.clients.jedis.search.Schema.FieldType;
import redis.clients.jedis.search.schemafields.*;
import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm;

import java.text.MessageFormat;
import java.util.*;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * The RedisVectorStore is for managing and querying vector data in a Redis database. It
 * offers functionalities like adding, deleting, and performing similarity searches on
 * documents.
 *
 * The store utilizes RedisJSON and RedisSearch to handle JSON documents and to index and
 * search vector data. It supports various vector algorithms (e.g., FLAT, HNSW) for
 * efficient similarity searches. Additionally, it allows for custom metadata fields in
 * the documents to be stored alongside the vector and content data.
 *
 * This class requires a RedisVectorStoreConfig configuration object for initialization,
 * which includes settings like index name, key prefix, field names, and vector
 * algorithm. It also requires an EmbeddingModel to convert documents into embeddings
 * before storing them.
 *
 * @author Julien Ruaux
 * @author Christian Tzolov
 * @author Eddú Meléndez
 * @see VectorStore
 * @see RedisVectorStoreConfig
 * @see EmbeddingModel
 */
public class RedisVectorStore implements VectorStore, InitializingBean {

    /**
     * Vector index algorithms selectable through the configuration.
     * <p>
     * Note: the constant {@code HSNW} is a legacy spelling that is mapped to the HNSW
     * algorithm when the index schema is built; the name is kept for backward
     * compatibility.
     */
    public enum Algorithm {

        FLAT, HSNW

    }

    /**
     * Describes a metadata field that is indexed and returned by searches.
     *
     * @param name the metadata key, also used as the index field alias
     * @param fieldType the RediSearch field type used to index the value
     */
    public record MetadataField(String name, FieldType fieldType) {

        /**
         * Creates a metadata field indexed as {@link FieldType#TEXT}.
         * @param name the metadata field name
         * @return the metadata field
         */
        public static MetadataField text(String name) {
            return new MetadataField(name, FieldType.TEXT);
        }

        /**
         * Creates a metadata field indexed as {@link FieldType#NUMERIC}.
         * @param name the metadata field name
         * @return the metadata field
         */
        public static MetadataField numeric(String name) {
            return new MetadataField(name, FieldType.NUMERIC);
        }

        /**
         * Creates a metadata field indexed as {@link FieldType#TAG}.
         * @param name the metadata field name
         * @return the metadata field
         */
        public static MetadataField tag(String name) {
            return new MetadataField(name, FieldType.TAG);
        }

    }

    /**
     * Configuration for the Redis vector store.
     */
    public static final class RedisVectorStoreConfig {

        private final String indexName;

        private final String prefix;

        private final String contentFieldName;

        private final String embeddingFieldName;

        private final Algorithm vectorAlgorithm;

        private final List<MetadataField> metadataFields;

        private RedisVectorStoreConfig() {
            this(builder());
        }

        private RedisVectorStoreConfig(Builder builder) {
            this.indexName = builder.indexName;
            this.prefix = builder.prefix;
            this.contentFieldName = builder.contentFieldName;
            this.embeddingFieldName = builder.embeddingFieldName;
            this.vectorAlgorithm = builder.vectorAlgorithm;
            this.metadataFields = builder.metadataFields;
        }

        /**
         * Start building a new configuration.
         * @return The entry point for creating a new configuration.
         */
        public static Builder builder() {
            return new Builder();
        }

        /**
         * {@return the default config}
         */
        public static RedisVectorStoreConfig defaultConfig() {
            return builder().build();
        }

        /**
         * Fluent builder for {@link RedisVectorStoreConfig}. Every setting starts at the
         * corresponding {@code DEFAULT_*} constant of {@link RedisVectorStore}; the
         * metadata field list starts empty.
         */
        public static class Builder {

            private String indexName = DEFAULT_INDEX_NAME;

            private String prefix = DEFAULT_PREFIX;

            private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME;

            private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME;

            private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM;

            private List<MetadataField> metadataFields = new ArrayList<>();

            private Builder() {
            }

            /**
             * Configures the Redis index name to use.
             * @param name the index name to use
             * @return this builder
             */
            public Builder withIndexName(String name) {
                this.indexName = name;
                return this;
            }

            /**
             * Configures the Redis key prefix to use (default: "embedding:").
             * @param prefix the prefix to use
             * @return this builder
             */
            public Builder withPrefix(String prefix) {
                this.prefix = prefix;
                return this;
            }

            /**
             * Configures the Redis content field name to use.
             * @param name the content field name to use
             * @return this builder
             */
            public Builder withContentFieldName(String name) {
                this.contentFieldName = name;
                return this;
            }

            /**
             * Configures the Redis embedding field name to use.
             * @param name the embedding field name to use
             * @return this builder
             */
            public Builder withEmbeddingFieldName(String name) {
                this.embeddingFieldName = name;
                return this;
            }

            /**
             * Configures the Redis vector algorithm to use.
             * @param algorithm the vector algorithm to use
             * @return this builder
             */
            public Builder withVectorAlgorithm(Algorithm algorithm) {
                this.vectorAlgorithm = algorithm;
                return this;
            }

            /**
             * Configures the metadata fields to index and return from searches.
             * @param fields the metadata fields
             * @return this builder
             */
            public Builder withMetadataFields(MetadataField... fields) {
                return withMetadataFields(Arrays.asList(fields));
            }

            /**
             * Configures the metadata fields to index and return from searches. The
             * given list replaces any previously configured fields and is stored by
             * reference (no copy is made).
             * @param fields the metadata fields
             * @return this builder
             */
            public Builder withMetadataFields(List<MetadataField> fields) {
                this.metadataFields = fields;
                return this;
            }

            /**
             * {@return the immutable configuration}
             */
            public RedisVectorStoreConfig build() {
                return new RedisVectorStoreConfig(this);
            }

        }

    }

    private final boolean initializeSchema;

    public static final String DEFAULT_INDEX_NAME = "spring-ai-index";

    public static final String DEFAULT_CONTENT_FIELD_NAME = "content";

    public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding";

    public static final String DEFAULT_PREFIX = "embedding:";

    public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW;

    // Arguments: filter, topK, embedding field, query parameter name, score alias.
    private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]";

    private static final Path2 JSON_SET_PATH = Path2.of("$");

    private static final String JSON_PATH_PREFIX = "$.";

    private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class);

    private static final Predicate<Object> RESPONSE_OK = Predicate.isEqual("OK");

    private static final Predicate<Object> RESPONSE_DEL_OK = Predicate.isEqual(1L);

    private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32";

    private static final String EMBEDDING_PARAM_NAME = "BLOB";

    public static final String DISTANCE_FIELD_NAME = "vector_score";

    private static final String DEFAULT_DISTANCE_METRIC = "COSINE";

    private final JedisPooled jedis;

    private final EmbeddingModel embeddingModel;

    private final RedisVectorStoreConfig config;

    private FilterExpressionConverter filterExpressionConverter;

    /**
     * Creates a new vector store.
     * @param config the store configuration, must not be {@code null}
     * @param embeddingModel the model used to embed documents and queries, must not be
     * {@code null}
     * @param jedis the Redis client used for all commands
     * @param initializeSchema whether {@link #afterPropertiesSet()} should create the
     * search index when it does not exist yet
     */
    public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis,
            boolean initializeSchema) {

        Assert.notNull(config, "Config must not be null");
        Assert.notNull(embeddingModel, "Embedding model must not be null");
        this.initializeSchema = initializeSchema;

        this.jedis = jedis;
        this.embeddingModel = embeddingModel;
        this.config = config;
        this.filterExpressionConverter = new RedisFilterExpressionConverter(this.config.metadataFields);
    }

    /**
     * {@return the underlying Redis client}
     */
    public JedisPooled getJedis() {
        return this.jedis;
    }

    /**
     * Embeds each document and stores it as a JSON value under the configured key
     * prefix followed by the document id. The JSON value holds the embedding, the
     * content and all document metadata entries. All writes are sent in one pipeline.
     * @param documents the documents to store
     * @throws RuntimeException if any pipelined write does not answer {@code "OK"}
     */
    @Override
    public void add(List<Document> documents) {
        try (Pipeline pipeline = this.jedis.pipelined()) {
            for (Document document : documents) {
                var embedding = this.embeddingModel.embed(document);
                document.setEmbedding(embedding);

                var fields = new HashMap<String, Object>();
                fields.put(this.config.embeddingFieldName, embedding);
                fields.put(this.config.contentFieldName, document.getContent());
                // Metadata is added last, so a metadata key equal to the embedding or
                // content field name overrides that entry.
                fields.putAll(document.getMetadata());
                pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields);
            }
            List<Object> responses = pipeline.syncAndReturnAll();
            Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_OK)).findAny();
            if (errResponse.isPresent()) {
                String message = MessageFormat.format("Could not add document: {0}", errResponse.get());
                logger.error(message);
                throw new RuntimeException(message);
            }
        }
    }

    /**
     * Builds the Redis key for a document id.
     * @param id the document id
     * @return the configured prefix followed by the id
     */
    private String key(String id) {
        return this.config.prefix + id;
    }

    /**
     * Deletes the JSON values stored for the given document ids in one pipeline.
     * @param idList the ids of the documents to delete
     * @return an Optional holding {@code true} when every delete answered {@code 1},
     * otherwise an Optional holding {@code false}
     */
    @Override
    public Optional<Boolean> delete(List<String> idList) {
        try (Pipeline pipeline = this.jedis.pipelined()) {
            for (String id : idList) {
                pipeline.jsonDel(key(id));
            }
            List<Object> responses = pipeline.syncAndReturnAll();
            Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny();
            if (errResponse.isPresent()) {
                logger.error("Could not delete document: {}", errResponse.get());
                return Optional.of(false);
            }
            return Optional.of(true);
        }
    }

    /**
     * Runs a KNN query for the embedded request query, optionally restricted by the
     * request's filter expression, and returns the hits whose similarity score is at
     * least the requested threshold.
     * @param request the search request; topK must be positive and the similarity
     * threshold must lie in [0, 1]
     * @return the matching documents, in the order returned by the search (sorted
     * ascending by {@link #DISTANCE_FIELD_NAME})
     * @throws IllegalArgumentException if topK or the similarity threshold is out of
     * range
     */
    @Override
    public List<Document> similaritySearch(SearchRequest request) {

        Assert.isTrue(request.getTopK() > 0, "The number of documents to returned must be greater than zero");
        Assert.isTrue(request.getSimilarityThreshold() >= 0 && request.getSimilarityThreshold() <= 1,
                "The similarity score is bounded between 0 and 1; least to most similar respectively.");

        String filter = nativeExpressionFilter(request);

        String queryString = String.format(QUERY_FORMAT, filter, request.getTopK(), this.config.embeddingFieldName,
                EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME);

        List<String> returnFields = new ArrayList<>();
        for (MetadataField field : this.config.metadataFields) {
            returnFields.add(field.name());
        }
        returnFields.add(this.config.embeddingFieldName);
        returnFields.add(this.config.contentFieldName);
        returnFields.add(DISTANCE_FIELD_NAME);

        var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery()));
        Query query = new Query(queryString).addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(embedding))
            .returnFields(returnFields.toArray(new String[0]))
            .setSortBy(DISTANCE_FIELD_NAME, true)
            .dialect(2);

        SearchResult result = this.jedis.ftSearch(this.config.indexName, query);
        return result.getDocuments()
            .stream()
            .filter(d -> similarityScore(d) >= request.getSimilarityThreshold())
            .map(this::toDocument)
            .toList();
    }

    /**
     * Converts a search hit into a {@link Document}. The id is the Redis key without
     * the configured prefix; metadata contains every configured metadata field present
     * on the hit plus {@link #DISTANCE_FIELD_NAME} set to {@code 1 - similarityScore}.
     * @param doc the search hit
     * @return the converted document
     */
    private Document toDocument(redis.clients.jedis.search.Document doc) {
        var id = doc.getId().substring(this.config.prefix.length());
        var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName)
                : null;
        Map<String, Object> metadata = this.config.metadataFields.stream()
            .map(MetadataField::name)
            .filter(doc::hasProperty)
            .collect(Collectors.toMap(Function.identity(), doc::getString));
        metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc));
        return new Document(id, content, metadata);
    }

    /**
     * Computes the similarity score of a search hit as {@code (2 - score) / 2}, where
     * {@code score} is the hit's {@link #DISTANCE_FIELD_NAME} value.
     * @param doc the search hit
     * @return the similarity score
     */
    private float similarityScore(redis.clients.jedis.search.Document doc) {
        return (2 - Float.parseFloat(doc.getString(DISTANCE_FIELD_NAME))) / 2;
    }

    /**
     * Builds the filter part of the search query.
     * @param request the search request
     * @return {@code "*"} when the request has no filter expression, otherwise the
     * converted expression wrapped in parentheses
     */
    private String nativeExpressionFilter(SearchRequest request) {
        if (request.getFilterExpression() == null) {
            return "*";
        }
        return "(" + this.filterExpressionConverter.convertExpression(request.getFilterExpression()) + ")";
    }

    /**
     * Creates the search index when schema initialization is enabled and no index with
     * the configured name exists yet.
     * @throws RuntimeException if the index creation does not answer {@code "OK"}
     */
    @Override
    public void afterPropertiesSet() {

        if (!this.initializeSchema) {
            return;
        }

        // If index already exists don't do anything
        if (this.jedis.ftList().contains(this.config.indexName)) {
            return;
        }

        String response = this.jedis.ftCreate(this.config.indexName,
                FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.config.prefix), schemaFields());
        if (!RESPONSE_OK.test(response)) {
            String message = MessageFormat.format("Could not create index: {0}", response);
            throw new RuntimeException(message);
        }
    }

    /**
     * Builds the index schema: a text field for the content, a vector field for the
     * embedding, and one field per configured metadata field.
     * @return the schema fields
     */
    private Iterable<SchemaField> schemaFields() {
        Map<String, Object> vectorAttrs = new HashMap<>();
        vectorAttrs.put("DIM", this.embeddingModel.dimensions());
        vectorAttrs.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC);
        vectorAttrs.put("TYPE", VECTOR_TYPE_FLOAT32);
        List<SchemaField> fields = new ArrayList<>();
        fields.add(TextField.of(jsonPath(this.config.contentFieldName)).as(this.config.contentFieldName).weight(1.0));
        fields.add(VectorField.builder()
            .fieldName(jsonPath(this.config.embeddingFieldName))
            .algorithm(vectorAlgorithm())
            .attributes(vectorAttrs)
            .as(this.config.embeddingFieldName)
            .build());

        if (!CollectionUtils.isEmpty(this.config.metadataFields)) {
            for (MetadataField field : this.config.metadataFields) {
                fields.add(schemaField(field));
            }
        }
        return fields;
    }

    /**
     * Maps a metadata field to the schema field of the matching type, aliased to the
     * metadata field name.
     * @param field the metadata field
     * @return the schema field
     * @throws IllegalArgumentException if the field type is not NUMERIC, TAG or TEXT
     */
    private SchemaField schemaField(MetadataField field) {
        String fieldName = jsonPath(field.name());
        return switch (field.fieldType()) {
            case NUMERIC -> NumericField.of(fieldName).as(field.name());
            case TAG -> TagField.of(fieldName).as(field.name());
            case TEXT -> TextField.of(fieldName).as(field.name());
            default -> throw new IllegalArgumentException(
                    MessageFormat.format("Field {0} has unsupported type {1}", field.name(), field.fieldType()));
        };
    }

    /**
     * {@return the Jedis vector algorithm for the configured algorithm; anything other
     * than HSNW maps to FLAT}
     */
    private VectorAlgorithm vectorAlgorithm() {
        if (this.config.vectorAlgorithm == Algorithm.HSNW) {
            return VectorAlgorithm.HNSW;
        }
        return VectorAlgorithm.FLAT;
    }

    /**
     * Builds the JSON path of a top-level field.
     * @param field the field name
     * @return the field name prefixed with {@code "$."}
     */
    private String jsonPath(String field) {
        return JSON_PATH_PREFIX + field;
    }

    /**
     * Converts a list of doubles into a float array of the same length and order.
     * @param embeddingDouble the embedding values
     * @return the values narrowed to float
     */
    private static float[] toFloatArray(List<Double> embeddingDouble) {
        float[] embeddingFloat = new float[embeddingDouble.size()];
        int i = 0;
        for (Double d : embeddingDouble) {
            embeddingFloat[i++] = d.floatValue();
        }
        return embeddingFloat;
    }

}