# Source code for datasketch.lshforest

from collections import defaultdict


class MinHashLSHForest(object):
    '''
    The LSH Forest for MinHash. It supports top-k query in Jaccard
    similarity. Instead of using prefix trees as the
    `original paper <http://ilpubs.stanford.edu:8090/678/1/2005-14.pdf>`_,
    I use a sorted array to store the hash values in every hash table.

    Args:
        num_perm (int, optional): The number of permutation functions used
            by the MinHash to be indexed. For weighted MinHash, this is
            the sample size (`sample_size`).
        l (int, optional): The number of prefix trees as described in the
            paper.

    Note:
        The MinHash LSH Forest also works with weighted Jaccard similarity
        and weighted MinHash without modification.
    '''

    def __init__(self, num_perm=128, l=8):
        '''
        Set up `l` empty hash tables, each covering `int(num_perm / l)`
        consecutive hash values.

        Raises:
            ValueError: If `num_perm` or `l` is not positive, or if `l` is
                greater than `num_perm`.
        '''
        if l <= 0 or num_perm <= 0:
            raise ValueError("num_perm and l must be positive")
        if l > num_perm:
            raise ValueError("l cannot be greater than num_perm")
        # Number of prefix trees
        self.l = l
        # Maximum depth of the prefix tree
        self.k = int(num_perm / l)
        # One mapping per tree: concatenated hash bytes -> list of keys
        self.hashtables = [defaultdict(list) for _ in range(self.l)]
        # (start, end) slice of the hash values owned by each tree
        self.hashranges = [(i*self.k, (i+1)*self.k) for i in range(self.l)]
        # key -> list of the key's concatenated hash bytes, one per tree
        self.keys = dict()
        # This is the sorted array implementation for the prefix trees
        self.sorted_hashtables = [[] for _ in range(self.l)]

    def add(self, key, minhash):
        '''
        Add a unique key, together with a MinHash (or weighted MinHash) of
        the set referenced by the key.

        Note:
            The key won't be searchable until the
            :func:`datasketch.MinHashLSHForest.index` method is called.

        Args:
            key (hashable): The unique identifier of the set.
            minhash (datasketch.MinHash): The MinHash of the set.

        Raises:
            ValueError: If the MinHash has fewer than `k * l` hash values,
                or if the key has already been added.
        '''
        # Validate everything before mutating any state.
        if len(minhash) < self.k*self.l:
            raise ValueError("The num_perm of MinHash out of range")
        if key in self.keys:
            raise ValueError("The given key has already been added")
        self.keys[key] = [self._H(minhash.hashvalues[start:end])
                          for start, end in self.hashranges]
        for H, hashtable in zip(self.keys[key], self.hashtables):
            hashtable[H].append(key)

    def index(self):
        '''
        Index all the keys added so far and make them searchable.
        '''
        for i, hashtable in enumerate(self.hashtables):
            # Iterating a dict yields its keys; sorted() builds the new
            # sorted array in one step.
            self.sorted_hashtables[i] = sorted(hashtable)

    def _query(self, minhash, r, b):
        '''
        Yield the keys whose first `r` hash values match those of
        `minhash` in a hash table. A key may be yielded more than once.

        Args:
            minhash: The MinHash of the query set.
            r (int): The prefix length, in number of hash values; must be
                in `[1, self.k]`.
            b (int): Must be in `[1, self.l]`. It is only validated: every
                hash table is searched regardless of its value.

        Raises:
            ValueError: If `r` or `b` is out of range. Because this is a
                generator, the error is raised on first iteration.
        '''
        if r > self.k or r <= 0 or b > self.l or b <= 0:
            raise ValueError("parameter outside range")
        # Generate prefixes of concatenated hash values
        hps = [self._H(minhash.hashvalues[start:start+r])
               for start, _ in self.hashranges]
        # Set the prefix length for look-ups in the sorted hash values list
        prefix_size = len(hps[0])
        for ht, hp, hashtable in zip(self.sorted_hashtables, hps,
                                     self.hashtables):
            # Locate the first sorted entry whose prefix is >= the query's
            i = self._binary_search(
                len(ht), lambda x: ht[x][:prefix_size] >= hp)
            if i < len(ht) and ht[i][:prefix_size] == hp:
                # Entries sharing the prefix are contiguous in sorted order
                j = i
                while j < len(ht) and ht[j][:prefix_size] == hp:
                    for key in hashtable[ht[j]]:
                        yield key
                    j += 1

    def query(self, minhash, k):
        '''
        Return the approximate top-k keys that have the (approximately)
        highest Jaccard similarities to the query set.

        Args:
            minhash (datasketch.MinHash): The MinHash of the query set.
            k (int): The maximum number of keys to return.

        Returns:
            `list` of at most k keys.

        Raises:
            ValueError: If `k` is not positive, or if the MinHash has fewer
                than `self.k * self.l` hash values.

        Note:
            Tip for improving accuracy: you can use a multiple of `k`
            (e.g., `2*k`) in the argument, compute the exact (or
            approximate using MinHash) Jaccard similarities of the sets
            referenced by the returned keys, from which you then take the
            final top-k. This is often called "post-processing". Because
            the total number of similarity computations is still bounded
            by a constant multiple of `k`, the performance won't degrade
            too much -- however you do have to keep the original sets (or
            MinHashes) around somewhere so that you can make references to
            them.
        '''
        if k <= 0:
            raise ValueError("k must be positive")
        if len(minhash) < self.k*self.l:
            raise ValueError("The num_perm of MinHash out of range")
        results = set()
        # Start from the longest prefix and shorten it one hash value at a
        # time until enough distinct keys have been collected.
        r = self.k
        while r > 0:
            for key in self._query(minhash, r, self.l):
                results.add(key)
                if len(results) >= k:
                    return list(results)
            r -= 1
        return list(results)

    def _binary_search(self, n, func):
        '''
        Return the smallest index `i` in `[0, n)` for which `func(i)` is
        true, or `n` if there is none. `func` must be false for some
        (possibly empty) prefix of the range and true for the remainder.

        Adapted from
        https://golang.org/src/sort/search.go?s=2247:2287#L49
        '''
        i, j = 0, n
        while i < j:
            # Integer floor division keeps the midpoint an exact int
            h = i + (j - i) // 2
            if not func(h):
                i = h + 1
            else:
                j = h
        return i

    def is_empty(self):
        '''
        Check whether there is any searchable keys in the index.
        Note that keys won't be searchable until `index` is called.

        Returns:
            bool: True if there is no searchable key in the index.
        '''
        # Every key is added to every table, so the tables fill together.
        return any(len(t) == 0 for t in self.sorted_hashtables)

    def _H(self, hs):
        '''
        Return the byte-swapped contents of the hash value array `hs` as a
        `bytes` object.
        '''
        return bytes(hs.byteswap().data)

    def __contains__(self, key):
        '''
        Returns:
            bool: True only if the key has been added to the index.
        '''
        return key in self.keys