# -*- coding: utf-8 -*-
# Copyright (c) Louis Brulé Naudet. All Rights Reserved.
# This software may be used and distributed according to the terms of the License Agreement.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple, Union

import faiss
import numpy as np
import torch

from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
from usearch.index import Index


class SimilaritySearch:
    """
    A class dedicated to encoding text data, quantizing embeddings, and
    managing the indices used for efficient similarity search.

    Attributes
    ----------
    model_name : str
        Name or identifier of the embedding model.
    device : str
        Computation device ('cpu' or 'cuda').
    ndim : int
        Dimensionality of the embeddings.
    metric : str
        Metric used for the index ('ip' for inner product, etc.).
    dtype : str
        Data type for the index ('i8' for int8, etc.).

    Methods
    -------
    encode(corpus, normalize_embeddings=True)
        Encodes text data into full-precision embeddings.
    quantize_embeddings(embeddings, quantization_type)
        Quantizes the embeddings for efficient storage and search.
    create_faiss_index(ubinary_embeddings, index_path, save=False)
        Creates and optionally saves a FAISS binary index.
    create_usearch_index(int8_embeddings, index_path, save=False)
        Creates and optionally saves a USEARCH integer index.
    load_usearch_index_view(index_path)
        Loads a USEARCH index as a view for memory-efficient operations.
    load_faiss_index(index_path)
        Loads a FAISS binary index for searching.
    search(query, top_k=10, rescore_multiplier=4)
        Performs a search operation against the indexed embeddings.
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        ndim: int = 1024,
        metric: str = "ip",
        dtype: str = "i8",
    ):
        """
        Initializes the SimilaritySearch instance with the specified model,
        device, and index configurations.

        Parameters
        ----------
        model_name : str
            The name or identifier of the SentenceTransformer model to use for embedding.
        device : str, optional
            The computation device to use ('cpu' or 'cuda'). Default is 'cuda'.
        ndim : int, optional
            The dimensionality of the embeddings. Default is 1024.
        metric : str, optional
            The metric used for the index ('ip' for inner product). Default is 'ip'.
        dtype : str, optional
            The data type for the USEARCH index ('i8' for 8-bit integer). Default is 'i8'.
        """
        self.model_name = model_name
        self.device = device
        self.ndim = ndim
        self.metric = metric
        self.dtype = dtype
        self.model = SentenceTransformer(self.model_name, device=self.device)
        self.binary_index = None
        self.int8_index = None

    def encode(
        self,
        corpus: Union[str, List[str]],
        normalize_embeddings: bool = True,
    ) -> np.ndarray:
        """
        Encodes the given corpus into full-precision embeddings.

        Parameters
        ----------
        corpus : Union[str, List[str]]
            A sentence or list of sentences to be encoded.
        normalize_embeddings : bool, optional
            Whether to normalize the returned vectors to unit length, in which
            case the faster dot-product (util.dot_score) can be used instead
            of cosine similarity. Default is True.

        Returns
        -------
        np.ndarray
            The full-precision embeddings of the corpus.

        Notes
        -----
        A progress bar is displayed during encoding; normalization is applied
        only when `normalize_embeddings` is True.
""" try: embeddings = self.model.encode( corpus, normalize_embeddings=normalize_embeddings, show_progress_bar=True ) return embeddings except Exception as e: print(f"An error occurred during encoding: {e}") def quantize_embeddings( self, embeddings: np.ndarray, quantization_type: str ) -> Union[np.ndarray, bytearray]: """ Quantizes the given embeddings based on the specified quantization type ('ubinary' or 'int8'). Parameters ---------- embeddings : np.ndarray The full-precision embeddings to be quantized. quantization_type : str The type of quantization ('ubinary' for unsigned binary, 'int8' for 8-bit integers). Returns ------- Union[np.ndarray, bytearray] The quantized embeddings. Raises ------ ValueError If an unsupported quantization type is provided. """ try: if quantization_type == "ubinary": return self._quantize_to_ubinary( embeddings=embeddings ) elif quantization_type == "int8": return self._quantize_to_int8( embeddings=embeddings ) else: raise ValueError(f"Unsupported quantization type: {quantization_type}") except Exception as e: print(f"An error occurred during quantization: {e}") def create_faiss_index( self, ubinary_embeddings: bytearray, index_path: str = None, save: bool = False ) -> None: """ Creates and saves a FAISS binary index from ubinary embeddings. Parameters ---------- ubinary_embeddings : bytearray The ubinary-quantized embeddings. index_path : str, optional The file path to save the FAISS binary index. Default is None. save : bool, optional Indicator for saving the index. Default is False. Notes ----- The dimensionality of the index is specified during the class initialization (default is 1024). """ try: self.binary_index = faiss.IndexBinaryFlat( self.ndim ) self.binary_index.add( ubinary_embeddings ) if save and index_path: self._save_faiss_index_binary( index_path=index_path ) except Exception as e: print(f"An error occurred during index creation: {e}") def create_usearch_index( self, int8_embeddings: np.ndarray, index_path: str = None, save: bool = False ) -> None: """ Creates and saves a USEARCH integer index from int8 embeddings. Parameters ---------- int8_embeddings : np.ndarray The int8-quantized embeddings. index_path : str, optional The file path to save the USEARCH integer index. Default is None. save : bool, optional Indicator for saving the index. Default is False. Returns ------- None Notes ----- The dimensionality and metric of the index are specified during class initialization. """ try: self.int8_index = Index( ndim=self.ndim, metric=self.metric, dtype=self.dtype ) self.int8_index.add( np.arange( len(int8_embeddings) ), int8_embeddings ) if save == True and index_path: self._save_int8_index( index_path=index_path ) return self.int8_index except Exception as e: print(f"An error occurred during USEARCH index creation: {e}") def load_usearch_index_view( self, index_path: str ) -> any: """ Loads a USEARCH index as a view for memory-efficient operations. Parameters ---------- index_path : str The file path to the USEARCH index to be loaded as a view. Returns ------- object A view of the USEARCH index for memory-efficient similarity search operations. Notes ----- Implementing this would depend on the specific USEARCH index handling library being used. """ try: self.int8_index = Index.restore( index_path, view=True ) return self.int8_index except Exception as e: print(f"An error occurred while loading USEARCH index: {e}") def load_faiss_index( self, index_path: str ) -> None: """ Loads a FAISS binary index from a specified file path. 

        This method loads a binary index created by FAISS into the class
        attribute `binary_index`, ready for performing similarity searches.

        Parameters
        ----------
        index_path : str
            The file path to the saved FAISS binary index.

        Returns
        -------
        None

        Notes
        -----
        The loaded index is stored in the `binary_index` attribute of the
        class. Ensure that the index at `index_path` is compatible with the
        configuration (e.g., dimensionality) used for this class instance.
        """
        try:
            self.binary_index = faiss.read_index_binary(index_path)

        except Exception as e:
            print(f"An error occurred while loading the FAISS index: {e}")

    def search(
        self,
        query: str,
        top_k: int = 10,
        rescore_multiplier: int = 4,
    ) -> Tuple[List[float], List[int]]:
        """
        Performs a search operation against the indexed embeddings.

        Parameters
        ----------
        query : str
            The query sentence/string to be searched.
        top_k : int, optional
            The number of top results to return. Default is 10.
        rescore_multiplier : int, optional
            The multiplier applied to `top_k` to enlarge the initial binary
            retrieval before re-scoring. Higher values can improve result
            quality at the cost of speed. Default is 4.

        Returns
        -------
        Tuple[List[float], List[int]]
            A tuple containing the scores and the indices of the top k results.

        Notes
        -----
        This method assumes that `binary_index` and `int8_index` have already
        been loaded or created.
        """
        try:
            if self.binary_index is None or self.int8_index is None:
                raise ValueError("Indices must be loaded or created before searching.")

            # Encode the query and quantize it to the binary representation
            # expected by the FAISS binary index.
            query_embedding = self.encode(
                corpus=query,
                normalize_embeddings=False,
            )
            query_embedding_ubinary = self.quantize_embeddings(
                embeddings=query_embedding.reshape(1, -1),
                quantization_type="ubinary",
            )

            # Coarse retrieval: fetch more candidates than requested so the
            # int8 re-scoring step can re-rank them.
            _scores, binary_ids = self.binary_index.search(
                query_embedding_ubinary,
                top_k * rescore_multiplier,
            )
            binary_ids = binary_ids[0]

            # Re-score the candidates against their int8 embeddings and keep
            # the top_k highest-scoring ones.
            int8_embeddings = self.int8_index[binary_ids].astype(int)
            scores = query_embedding @ int8_embeddings.T

            indices = (-scores).argsort()[:top_k]
            top_k_indices = binary_ids[indices]
            top_k_scores = scores[indices]

            return top_k_scores.tolist(), top_k_indices.tolist()

        except Exception as e:
            print(f"An error occurred while searching for semantically similar sentences: {e}")

    def _quantize_to_ubinary(
        self,
        embeddings: np.ndarray,
    ) -> np.ndarray:
        """
        Quantizes embeddings to the unsigned binary ('ubinary') format.

        Parameters
        ----------
        embeddings : np.ndarray
            The embeddings to quantize.

        Returns
        -------
        np.ndarray
            The quantized embeddings.
        """
        try:
            ubinary_embeddings = quantize_embeddings(embeddings, "ubinary")

            return ubinary_embeddings

        except Exception as e:
            print(f"An error occurred during ubinary quantization: {e}")

    def _quantize_to_int8(
        self,
        embeddings: np.ndarray,
    ) -> np.ndarray:
        """
        Quantizes embeddings to the 8-bit integer ('int8') format.

        Parameters
        ----------
        embeddings : np.ndarray
            The embeddings to quantize.

        Returns
        -------
        np.ndarray
            The quantized embeddings.
        """
        try:
            int8_embeddings = quantize_embeddings(embeddings, "int8")

            return int8_embeddings

        except Exception as e:
            print(f"An error occurred during int8 quantization: {e}")

    def _save_faiss_index_binary(
        self,
        index_path: str,
    ) -> None:
        """
        Saves the FAISS binary index to disk.

        This private method is called internally to save the constructed FAISS
        binary index to the specified file path.

        Parameters
        ----------
        index_path : str
            The path to the file where the binary index should be saved. This
            value is checked in the public method `create_faiss_index`.

        Returns
        -------
        None

        Notes
        -----
        This method should not be called directly. It is intended to be used
        internally by the `create_faiss_index` method.
""" try: faiss.write_index_binary( self.binary_index, index_path ) return None except Exception as e: print(f"An error occurred during FAISS binary index saving: {e}") def _save_int8_index( self, index_path: str ) -> None: """ Saves the int8_index to disk. This private method is called internally to save the constructed int8_index to the specified file path. Parameters ---------- index_path : str The path to the file where the int8_index should be saved. This value is checked in the public method `_save_int8_index`. Returns ------- None Notes ----- This method should not be called directly. It is intended to be used internally by the `_save_int8_index` method. """ try: self.int8_index.save( index_path ) return None except Exception as e: print(f"An error occurred during int8_index saving: {e}")