from __future__ import annotations
from collections import OrderedDict
import heapq
from typing import (
Hashable,
Callable,
Dict,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
Optional,
Set,
Tuple,
Union,
)
import numpy as np
class _Layer(object):
"""A graph layer in the HNSW index. This is a dictionary-like object
that maps a key to a dictionary of neighbors.
Args:
key (Hashable): The first key to insert into the graph.
"""
def __init__(self, key: Hashable) -> None:
# self._graph[key] contains a {j: dist} dictionary,
# where j is a neighbor of key and dist is distance.
self._graph: Dict[Hashable, Dict[Hashable, float]] = {key: {}}
def __contains__(self, key: Hashable) -> bool:
return key in self._graph
def __getitem__(self, key: Hashable) -> Dict[Hashable, float]:
return self._graph[key]
def __setitem__(self, key: Hashable, value: Dict[Hashable, float]) -> None:
self._graph[key] = value
def __delitem__(self, key: Hashable) -> None:
del self._graph[key]
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, _Layer):
return False
return self._graph == __value._graph
def __len__(self) -> int:
return len(self._graph)
    def __iter__(self) -> Iterator[Hashable]:
return iter(self._graph)
def copy(self) -> _Layer:
"""Create a copy of the layer."""
new_layer = _Layer(None)
new_layer._graph = {k: dict(v) for k, v in self._graph.items()}
return new_layer
def get_reverse_edges(self, key: Hashable) -> Set[Hashable]:
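        # Linear scan over every node's neighbor dict (O(n)); the subclass
        # _LayerWithReversedEdges below answers this in O(1) at extra memory cost.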
reverse_edges = set()
for neighbor, neighbors in self._graph.items():
if key in neighbors:
reverse_edges.add(neighbor)
return reverse_edges
class _LayerWithReversedEdges(_Layer):
"""A graph layer in the HNSW index that also maintains reverse edges.
Args:
key (Hashable): The first key to insert into the graph.
"""
def __init__(self, key: Hashable) -> None:
# self._graph[key] contains a {j: dist} dictionary,
# where j is a neighbor of key and dist is distance.
self._graph: Dict[Hashable, Dict[Hashable, float]] = {key: {}}
# self._reverse_edges[key] contains a set of neighbors of key.
        self._reverse_edges: Dict[Hashable, Set[Hashable]] = {}
def __setitem__(self, key: Hashable, value: Dict[Hashable, float]) -> None:
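        # Keep the reverse-edge map in sync: unregister this key from its old
        # neighbors, then register it with its new neighbors.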
old_neighbors = self._graph.get(key, {})
self._graph[key] = value
for neighbor in old_neighbors:
self._reverse_edges[neighbor].discard(key)
for neighbor in value:
self._reverse_edges.setdefault(neighbor, set()).add(key)
if key not in self._reverse_edges:
self._reverse_edges[key] = set()
def __delitem__(self, key: Hashable) -> None:
old_neighbors = self._graph.get(key, {})
del self._graph[key]
for neighbor in old_neighbors:
self._reverse_edges[neighbor].discard(key)
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, _LayerWithReversedEdges):
return False
return (
self._graph == __value._graph
and self._reverse_edges == __value._reverse_edges
)
def __len__(self) -> int:
return len(self._graph)
    def __iter__(self) -> Iterator[Hashable]:
return iter(self._graph)
def copy(self) -> _LayerWithReversedEdges:
"""Create a copy of the layer."""
new_layer = _LayerWithReversedEdges(None)
new_layer._graph = {k: dict(v) for k, v in self._graph.items()}
new_layer._reverse_edges = {k: set(v) for k, v in self._reverse_edges.items()}
return new_layer
def get_reverse_edges(self, key: Hashable) -> Set[Hashable]:
return self._reverse_edges[key]
class _Node(object):
"""A node in the HNSW graph."""
    def __init__(
        self, key: Hashable, point: np.ndarray, is_deleted: bool = False
    ) -> None:
self.key = key
self.point = point
self.is_deleted = is_deleted
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, _Node):
return False
return (
self.key == __value.key
and np.array_equal(self.point, __value.point)
and self.is_deleted == __value.is_deleted
)
def __hash__(self) -> int:
return hash(self.key)
def __repr__(self) -> str:
return (
f"_Node(key={self.key}, point={self.point}, is_deleted={self.is_deleted})"
)
def copy(self) -> _Node:
return _Node(self.key, self.point, self.is_deleted)
class HNSW(MutableMapping):
"""Hierarchical Navigable Small World (HNSW) graph index for approximate
nearest neighbor search. This implementation is based on the paper
"Efficient and robust approximate nearest neighbor search using Hierarchical
Navigable Small World graphs" by Yu. A. Malkov, D. A. Yashunin (2016),
`<https://arxiv.org/abs/1603.09320>`_.
Args:
distance_func: A function that takes two vectors and returns a float
representing the distance between them.
m (int): The number of neighbors to keep for each node.
ef_construction (int): The number of neighbors to consider during
construction.
m0 (Optional[int]): The number of neighbors to keep for each node at
the 0th level. If None, defaults to 2 * m.
seed (Optional[int]): The random seed to use for the random number
generator.
        reversed_edges (bool): Whether to maintain reversed edges in the graph.
            This speeds up hard removes (:meth:`remove`) but increases memory
            usage and slows down :meth:`insert`.
Examples:
Create an HNSW index with Euclidean distance and insert 1000 random
vectors of dimension 10.
.. code-block:: python
from datasketch.hnsw import HNSW
import numpy as np
data = np.random.random_sample((1000, 10))
index = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
for i, d in enumerate(data):
index.insert(i, d)
# Query the index for the 10 nearest neighbors of the first vector.
index.query(data[0], k=10)
Create an HNSW index with Jaccard distance and insert 1000 random
sets of 10 elements each.
.. code-block:: python
from datasketch.hnsw import HNSW
import numpy as np
# Each set is represented as a 10-element vector of random integers
# between 0 and 100.
# Deduplication is handled by the distance function.
data = np.random.randint(0, 100, size=(1000, 10))
jaccard_distance = lambda x, y: (
1.0 - float(len(np.intersect1d(x, y, assume_unique=False)))
/ float(len(np.union1d(x, y)))
)
index = HNSW(distance_func=jaccard_distance)
for i, d in enumerate(data):
index[i] = d
# Query the index for the 10 nearest neighbors of the first set.
index.query(data[0], k=10)
"""
def __init__(
self,
distance_func: Callable[[np.ndarray, np.ndarray], float],
m: int = 16,
ef_construction: int = 200,
m0: Optional[int] = None,
seed: Optional[int] = None,
reversed_edges: bool = False,
) -> None:
self._nodes: OrderedDict[Hashable, _Node] = OrderedDict()
self._distance_func = distance_func
self._m = m
self._ef_construction = ef_construction
self._m0 = 2 * m if m0 is None else m0
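        # Level multiplier mL = 1 / ln(m): a new point's top layer is drawn as
        # floor(-ln(uniform(0, 1)) * mL), as in the HNSW paper.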
self._level_mult = 1 / np.log(m)
self._graphs: List[_Layer] = []
self._entry_point = None
self._random = np.random.RandomState(seed)
self._layer_class = _LayerWithReversedEdges if reversed_edges else _Layer
def __len__(self) -> int:
"""Return the number of points in the index excluding those
that were soft-removed."""
return sum(not node.is_deleted for node in self._nodes.values())
def __contains__(self, key: Hashable) -> bool:
"""Return ``True`` if the index contains the key and it was
not soft-removed, else ``False``."""
return key in self._nodes and not self._nodes[key].is_deleted
def __getitem__(self, key: Hashable) -> np.ndarray:
"""Get the point associated with the key. Raises KeyError if the key
does not exist in the index or it was soft-removed."""
if key not in self:
raise KeyError(key)
return self._nodes[key].point
def __setitem__(self, key: Hashable, value: np.ndarray) -> None:
"""Set the point associated with the key and update the index.
This is equivalent to calling :meth:`insert` with the key and point."""
self.insert(key, value)
def __delitem__(self, key: Hashable) -> None:
"""Soft remove the point associated with the key. Raises a KeyError if the
key does not exist in the index. This is equivalent to calling
:meth:`remove` with the key.
"""
self.remove(key)
def __iter__(self) -> Iterator[Hashable]:
"""Return an iterator over the keys of the index that were not
soft-removed."""
return (key for key in self._nodes if not self._nodes[key].is_deleted)
def reversed(self) -> Iterator[Hashable]:
"""Return a reverse iterator over the keys of the index that were not
soft-removed."""
return (key for key in reversed(self._nodes) if not self._nodes[key].is_deleted)
def __eq__(self, __value: object) -> bool:
"""Return True only if the index parameters, random states, keys, points
points, and index structures are equal, including deleted points."""
if not isinstance(__value, HNSW):
return False
# Check if the index parameters are equal.
if (
self._distance_func != __value._distance_func
or self._m != __value._m
or self._ef_construction != __value._ef_construction
or self._m0 != __value._m0
or self._level_mult != __value._level_mult
or self._entry_point != __value._entry_point
):
return False
# Check if the random states are equal.
rand_state_1 = self._random.get_state()
rand_state_2 = __value._random.get_state()
for i in range(len(rand_state_1)):
if isinstance(rand_state_1[i], np.ndarray):
if not np.array_equal(rand_state_1[i], rand_state_2[i]):
return False
else:
if rand_state_1[i] != rand_state_2[i]:
return False
# Check if keys and points are equal.
# Note that deleted points are compared too.
return (
all(key in self._nodes for key in __value._nodes)
and all(key in __value._nodes for key in self._nodes)
and all(self._nodes[key] == __value._nodes[key] for key in self._nodes)
and self._graphs == __value._graphs
)
def get(
self, key: Hashable, default: Optional[np.ndarray] = None
) -> Union[np.ndarray, None]:
"""Return the point for key in the index, else default. If default is not
given and key is not in the index or it was soft-removed, return None."""
if key not in self:
return default
return self._nodes[key].point
def items(self) -> Iterator[Tuple[Hashable, np.ndarray]]:
"""Return an iterator of the indexed points that were not soft-removed
as (key, point) pairs."""
return (
(key, node.point)
for key, node in self._nodes.items()
if not node.is_deleted
)
def keys(self) -> Iterator[Hashable]:
"""Return an iterator of the keys of the index points that were not
soft-removed."""
return (key for key in self._nodes if not self._nodes[key].is_deleted)
def values(self) -> Iterator[np.ndarray]:
"""Return an iterator of the index points that were not soft-removed."""
return (node.point for node in self._nodes.values() if not node.is_deleted)
def pop(
self, key: Hashable, default: Optional[np.ndarray] = None, hard: bool = False
) -> np.ndarray:
"""If key is in the index, remove it and return its associated point,
else return default. If default is not given and key is not in the index
or it was soft-removed, raise KeyError.
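
        Example:
            A minimal sketch, assuming ``index`` holds a point under key ``0``
            as in the class example:

            .. code-block:: python

                point = index.pop(0)  # soft remove key 0 and return its point
                # Key 0 was removed, so the given default is returned instead.
                same_point = index.pop(0, default=point)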
"""
if key not in self:
if default is None:
raise KeyError(key)
return default
point = self._nodes[key].point
self.remove(key, hard=hard)
return point
def popitem(
self, last: bool = True, hard: bool = False
) -> Tuple[Hashable, np.ndarray]:
"""Remove and return a (key, point) pair from the index. Pairs are
        returned in LIFO order if ``last`` is true or FIFO order if false.
If the index is empty or all points are soft-removed, raise KeyError.
Note:
In versions of Python before 3.7, the order of items in the index
is not guaranteed. This method will remove and return an arbitrary
(key, point) pair.
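
        Example:
            An illustrative sketch, assuming ``index`` was populated with
            keys ``0..999`` in insertion order as in the class example:

            .. code-block:: python

                key, point = index.popitem()            # LIFO: highest key, 999
                key, point = index.popitem(last=False)  # FIFO: lowest key, 0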
"""
if not self._nodes:
raise KeyError("popitem(): index is empty")
if last:
key = next(
(
key
for key in reversed(self._nodes)
if not self._nodes[key].is_deleted
),
None,
)
else:
key = next(
(key for key in self._nodes if not self._nodes[key].is_deleted), None
)
if key is None:
raise KeyError("popitem(): index is empty")
point = self._nodes[key].point
self.remove(key, hard=hard)
return key, point
def clear(self) -> None:
"""Clear the index of all data points. This will not reset the random
number generator."""
        self._nodes = OrderedDict()
self._graphs = []
self._entry_point = None
def copy(self) -> HNSW:
"""Create a copy of the index. The copy will have the same parameters
as the original index and the same keys and points, but will not share
any index data structures (i.e., graphs) with the original index.
The new index's random state will start from a copy of the original
index's."""
        new_index = HNSW(
            self._distance_func,
            m=self._m,
            ef_construction=self._ef_construction,
            m0=self._m0,
        )
        # Preserve the layer class so that future inserts into the copy keep
        # the same reverse-edge behavior as the original index.
        new_index._layer_class = self._layer_class
new_index._nodes = OrderedDict(
(key, node.copy()) for key, node in self._nodes.items()
)
new_index._graphs = [layer.copy() for layer in self._graphs]
new_index._entry_point = self._entry_point
new_index._random.set_state(self._random.get_state())
return new_index
def update(self, other: Union[Mapping, HNSW]) -> None:
"""Update the index with the points from the other Mapping or HNSW object,
overwriting existing keys.
Args:
other (Union[Mapping, HNSW]): The other Mapping or HNSW object.
Examples:
Create an HNSW index with a dictionary of points.
.. code-block:: python
from datasketch.hnsw import HNSW
import numpy as np
data = np.random.random_sample((1000, 10))
index = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
# Batch insert 1000 points.
index.update({i: d for i, d in enumerate(data)})
Create an HNSW index with another HNSW index.
.. code-block:: python
from datasketch.hnsw import HNSW
import numpy as np
data = np.random.random_sample((1000, 10))
index1 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
index2 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
# Batch insert 1000 points.
index1.update({i: d for i, d in enumerate(data)})
# Update index2 with the points from index1.
index2.update(index1)
"""
for key, point in other.items():
self.insert(key, point)
def setdefault(self, key: Hashable, default: np.ndarray) -> np.ndarray:
"""If key is in the index and it was not soft-removed, return
its associated point. If not, insert
key with a value of default and return default. default cannot be None."""
if default is None:
raise ValueError("Default value cannot be None.")
if key not in self._nodes or self._nodes[key].is_deleted:
self.insert(key, default)
        return self._nodes[key].point
def insert(
self,
key: Hashable,
new_point: np.ndarray,
ef: Optional[int] = None,
level: Optional[int] = None,
) -> None:
"""Add a new point to the index.
Args:
key (Hashable): The key of the new point. If the key already exists in the
index, the point will be updated and the index will be repaired.
new_point (np.ndarray): The new point to add to the index.
ef (Optional[int]): The number of neighbors to consider during insertion.
If None, use the construction ef.
level (Optional[int]): The level at which to insert the new point.
If None, the level will be chosen automatically.
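
        Example:
            A minimal sketch, assuming a Euclidean index over 10-dimensional
            points built as in the class example:

            .. code-block:: python

                import numpy as np

                index.insert(1000, np.random.random_sample(10))
                # Use a wider candidate search for this insertion only.
                index.insert(1001, np.random.random_sample(10), ef=500)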
"""
if ef is None:
ef = self._ef_construction
if key in self._nodes:
if self._nodes[key].is_deleted:
self._nodes[key].is_deleted = False
self._update(key, new_point, ef)
return
# level is the level at which we insert the element.
if level is None:
level = int(-np.log(self._random.random_sample()) * self._level_mult)
self._nodes[key] = _Node(key, new_point)
        # The HNSW is not empty; we have an entry point.
        if self._entry_point is not None:
dist = self._distance_func(new_point, self._nodes[self._entry_point].point)
point = self._entry_point
            # For all levels at which we don't have to insert the element,
            # we search for the closest neighbor using greedy search.
for layer in reversed(self._graphs[level + 1 :]):
point, dist = self._search_ef1(
new_point, point, dist, layer, allow_soft_deleted=True
)
# Entry points for search at each level to insert.
entry_points = [(-dist, point)]
for layer in reversed(self._graphs[: level + 1]):
# Maximum number of neighbors to keep at this level.
level_m = self._m if layer is not self._graphs[0] else self._m0
# Search this layer for neighbors to insert, and update entry points
# for the next level.
entry_points = self._search_base_layer(
new_point, entry_points, layer, ef, allow_soft_deleted=True
)
# Insert the new node into the graph with out-going edges.
# We prune the out-going edges to keep only the top level_m neighbors.
layer[key] = {
p: d
for d, p in self._heuristic_prune(
[(-mdist, p) for mdist, p in entry_points], level_m
)
}
# For each neighbor of the new node, we insert the new node as a neighbor.
for neighbor_key, dist in layer[key].items():
layer[neighbor_key] = {
p: d
# We prune the edges to keep only the top level_m neighbors
# based on heuristic.
for d, p in self._heuristic_prune(
[(d, p) for p, d in layer[neighbor_key].items()]
+ [(dist, key)],
level_m,
)
}
        # For all levels above the current top layer (up to the insertion
        # level), we create a new layer containing only the new node.
for _ in range(len(self._graphs), level + 1):
self._graphs.append(self._layer_class(key))
# We set the entry point for each new level to be the new node.
self._entry_point = key
def _update(self, key: Hashable, new_point: np.ndarray, ef: int) -> None:
"""Update the point associated with the key in the index.
Args:
key (Hashable): The key of the point.
new_point (np.ndarray): The new point to update to.
ef (int): The number of neighbors to consider during insertion.
Raises:
KeyError: If the key does not exist in the index.
"""
if key not in self._nodes:
raise KeyError(key)
# Update the point.
self._nodes[key].point = new_point
# If the entry point is the only point in the index, we do not need to
# update the index.
if self._entry_point == key and len(self._nodes) == 1:
return
for layer in self._graphs:
if key not in layer:
break
layer_m = self._m if layer is not self._graphs[0] else self._m0
# Create a set of points in the 2nd-degree neighborhood of the key.
neighborhood_keys = set([key])
for p in layer[key].keys():
neighborhood_keys.add(p)
for p2 in layer[p].keys():
neighborhood_keys.add(p2)
for p in layer[key].keys():
            # For each neighbor of the key, we connect it to the top ef
            # neighbors in the 2nd-degree neighborhood of the key.
cands = []
elem_to_keep = min(ef, len(neighborhood_keys) - 1)
for candidate_key in neighborhood_keys:
if candidate_key == p:
continue
dist = self._distance_func(
self._nodes[candidate_key].point, self._nodes[p].point
)
if len(cands) < elem_to_keep:
heapq.heappush(cands, (-dist, candidate_key))
elif dist < -cands[0][0]:
heapq.heappushpop(cands, (-dist, candidate_key))
layer[p] = {
p2: d2
for d2, p2 in self._heuristic_prune(
[(-md, p) for md, p in cands], layer_m
)
}
self._repair_connections(key, new_point, ef)
def _repair_connections(
self,
key: Hashable,
new_point: np.ndarray,
ef: int,
key_to_delete: Optional[Hashable] = None,
) -> None:
entry_point = self._entry_point
entry_point_dist = self._distance_func(
new_point, self._nodes[entry_point].point
)
entry_points = [(-entry_point_dist, entry_point)]
for layer in reversed(self._graphs):
if key not in layer:
# Greedy search for the closest neighbor from the highest layer down.
entry_point, entry_point_dist = self._search_ef1(
new_point,
entry_point,
entry_point_dist,
layer,
# We allow soft-deleted points to be returned and used as entry point.
allow_soft_deleted=True,
key_to_hard_delete=key_to_delete,
)
entry_points = [(-entry_point_dist, entry_point)]
else:
# Search for the neighbors at this layer using ef search.
level_m = self._m if layer is not self._graphs[0] else self._m0
entry_points = self._search_base_layer(
new_point,
entry_points,
layer,
ef + 1, # We add 1 to ef to account for the point itself.
# We allow soft-deleted points to be returned and used as entry point
# and neighbor candidates.
allow_soft_deleted=True,
key_to_hard_delete=key_to_delete,
)
# Filter out the updated node itself.
filtered_candidates = [(-md, p) for md, p in entry_points if p != key]
# Update the out-going edges of the updated node at this level.
layer[key] = {
p: d for d, p in self._heuristic_prune(filtered_candidates, level_m)
}
def query(
self,
query_point: np.ndarray,
k: Optional[int] = None,
ef: Optional[int] = None,
) -> List[Tuple[Hashable, float]]:
"""Search for the k nearest neighbors of the query point.
Args:
query_point (np.ndarray): The query point.
k (Optional[int]): The number of neighbors to return. If None, return
all neighbors found.
ef (Optional[int]): The number of neighbors to consider during search.
If None, use the construction ef.
Returns:
List[Tuple[Hashable, float]]: A list of (key, distance) pairs for the k
nearest neighbors of the query point.
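
        Example:
            A minimal sketch, assuming a populated Euclidean ``index`` as in
            the class example:

            .. code-block:: python

                import numpy as np

                # 10 approximate nearest neighbors of a random query point.
                for key, dist in index.query(np.random.random_sample(10), k=10):
                    print(key, dist)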
Raises:
ValueError: If the entry point is not found.
"""
if ef is None:
ef = self._ef_construction
if self._entry_point is None:
raise ValueError("Entry point not found.")
entry_point_dist = self._distance_func(
query_point, self._nodes[self._entry_point].point
)
entry_point = self._entry_point
# Search for the closest neighbor from the highest level to the 2nd
# level using greedy search.
for layer in reversed(self._graphs[1:]):
entry_point, entry_point_dist = self._search_ef1(
query_point, entry_point, entry_point_dist, layer
)
# Search for the neighbors at the base layer using ef search.
candidates = self._search_base_layer(
query_point, [(-entry_point_dist, entry_point)], self._graphs[0], ef
)
if k is not None:
# If k is specified, we return the k nearest neighbors.
candidates = heapq.nlargest(k, candidates)
else:
# Otherwise, we return all neighbors found.
candidates.sort(reverse=True)
        # Return the neighbors as a list of (key, distance) pairs.
return [(key, -mdist) for mdist, key in candidates]
def _search_ef1(
self,
query_point: np.ndarray,
entry_point: Hashable,
entry_point_dist: float,
layer: _Layer,
allow_soft_deleted: bool = False,
key_to_hard_delete: Optional[Hashable] = None,
) -> Tuple[Hashable, float]:
"""The greedy search algorithm for finding the closest neighbor only.
Args:
            query_point (np.ndarray): The query point.
entry_point (Hashable): The entry point for the search.
entry_point_dist (float): The distance from the query point to the
entry point.
layer (_Layer): The graph for the layer.
allow_soft_deleted (bool): Whether to allow soft-deleted points to
be returned.
key_to_hard_delete (Optional[Hashable]): The key of the point to be
hard-deleted, if any. This point should never be returned.
Returns:
Tuple[Hashable, float]: A tuple of (key, distance) representing the closest
neighbor found.
"""
candidates = [(entry_point_dist, entry_point)]
visited = set([entry_point])
best = entry_point
best_dist = entry_point_dist
while candidates:
# Pop the closest node from the heap.
dist, curr = heapq.heappop(candidates)
if dist > best_dist:
# Terminate the search if the distance to the current closest node
# is larger than the distance to the best node.
break
# Find the neighbors of the current node
neighbors = [p for p in layer[curr] if p not in visited]
visited.update(neighbors)
dists = [
self._distance_func(query_point, self._nodes[p].point)
for p in neighbors
]
for p, d in zip(neighbors, dists):
# Update the best node if we find a closer node.
if d < best_dist:
if (not allow_soft_deleted and self._nodes[p].is_deleted) or (
p == key_to_hard_delete
):
# If the neighbor has been deleted or to be hard-deleted,
# we don't update the best node but we continue to
# explore the neighbor's neighbors.
pass
else:
best, best_dist = p, d
# Add the neighbor to the heap.
heapq.heappush(candidates, (d, p))
return best, best_dist
def _search_base_layer(
self,
query_point: np.ndarray,
entry_points: List[Tuple[float, Hashable]],
layer: _Layer,
ef: int,
allow_soft_deleted: bool = False,
key_to_hard_delete: Optional[Hashable] = None,
) -> List[Tuple[float, Hashable]]:
"""The ef search algorithm for finding neighbors in a given layer.
Args:
            query_point (np.ndarray): The query point.
entry_points (List[Tuple[float, Hashable]]): A list of (-distance, key) pairs
representing the entry points for the search.
layer (_Layer): The graph for the layer.
ef (int): The number of neighbors to consider during search.
allow_soft_deleted (bool): Whether to allow soft-deleted points to
be returned.
key_to_hard_delete (Optional[Hashable]): The key of the point to be
hard-deleted, if any. This point should never be returned.
Returns:
List[Tuple[float, Hashable]]: A heap of (-distance, key) pairs representing
the neighbors found.
        Note:
            When used together with :meth:`_search_ef1`, the input entry_points
            may contain soft-deleted points, depending on the flag used in
            :meth:`_search_ef1`. The output may therefore contain soft-deleted
            points even if allow_soft_deleted is False, so the caller should
            check the returned entry_points for soft-deleted points if
            necessary.
"""
# candidates is a heap of (distance, key) pairs.
candidates = [(-mdist, p) for mdist, p in entry_points]
heapq.heapify(candidates)
visited = set(p for _, p in entry_points)
while candidates:
# Pop the closest node from the heap.
dist, curr_key = heapq.heappop(candidates)
            # Terminate the search if the current candidate is farther than
            # the farthest point among the results found so far.
            farthest_dist = -entry_points[0][0]
            if dist > farthest_dist:
                break
            # Find the unvisited neighbors of the current node.
            neighbors = [p for p in layer[curr_key] if p not in visited]
            visited.update(neighbors)
            dists = [
                self._distance_func(query_point, self._nodes[p].point)
                for p in neighbors
            ]
            for p, dist in zip(neighbors, dists):
                if (not allow_soft_deleted and self._nodes[p].is_deleted) or (
                    p == key_to_hard_delete
                ):
                    if dist <= farthest_dist:
                        # The neighbor is soft-deleted or slated for hard
                        # deletion: add it to the heap to explore its
                        # neighbors, but do not add it to the results.
                        heapq.heappush(candidates, (dist, p))
                elif len(entry_points) < ef:
                    # We have not yet found ef results, so add the neighbor
                    # to both the candidate heap and the results.
                    heapq.heappush(candidates, (dist, p))
                    heapq.heappush(entry_points, (-dist, p))
                    farthest_dist = -entry_points[0][0]
                elif dist <= farthest_dist:
                    # We already have ef results; replace the farthest result
                    # with the neighbor, which is closer.
                    heapq.heappush(candidates, (dist, p))
                    heapq.heapreplace(entry_points, (-dist, p))
                    farthest_dist = -entry_points[0][0]
return entry_points
def _heuristic_prune(
self, candidates: List[Tuple[float, Hashable]], max_size: int
) -> List[Tuple[float, Hashable]]:
"""Prune the potential neigbors to keep only the top max_size neighbors.
This algorithm is based on hnswlib's heuristic pruning algorithm:
<https://github.com/nmslib/hnswlib/blob/978f7137bc9555a1b61920f05d9d0d8252ca9169/hnswlib/hnswalg.h#L382>`_.
Args:
candidates (List[Tuple[float, Hashable]]): A list of (distance, key) pairs
representing the potential neighbors.
max_size (int): The maximum number of neighbors to keep.
Returns:
List[Tuple[float, Hashable]]: A list of (distance, key) pairs representing
the pruned neighbors.
"""
if len(candidates) < max_size:
            # If the number of candidates is smaller than max_size, we return
            # all candidates.
return candidates
# candidates is a heap of (distance, key) pairs.
heapq.heapify(candidates)
pruned = []
while candidates:
if len(pruned) >= max_size:
break
# Pop the closest node from the heap.
candidate_dist, candidate_key = heapq.heappop(candidates)
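            # Keep the candidate only if it is closer to the base point than
            # to every neighbor already selected; this favors diverse edges.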
good = True
for _, selected_key in pruned:
dist_to_selected_neighbor = self._distance_func(
self._nodes[selected_key].point, self._nodes[candidate_key].point
)
if dist_to_selected_neighbor < candidate_dist:
good = False
break
if good:
pruned.append((candidate_dist, candidate_key))
return pruned
def remove(
self,
key: Hashable,
hard: bool = False,
ef: Optional[int] = None,
) -> None:
"""Remove a point from the index. This removal algorithm is based on
the discussion on `hnswlib issue #4`_. There are two versions:
* *soft remove*: the point is marked as removed from the index, but its
data and out-going edges are kept. Future queries will not return
the point and no new edge will direct to this point,
but the point will still be used for graph traversal.
This is the default behavior.
* *hard remove*: the point is removed from the index and its data and
out-going edges are also removed. Points with out-going edges pointing
to the deleted point will have their out-going edges
re-assigned using the same pruning algorithm as :meth:`insert` during
point update.
        In both versions, if the deleted point is the current entry point,
        the entry point will be re-assigned to the next point in the highest
        layer that has points besides the current entry point.
        Subsequent soft removes of the same point without an intervening hard
        remove will not affect the index, **unless the point was the only point
        in the index, since removing it clears the index**. This is different
        from :meth:`pop`, which always raises a KeyError if the key was removed.
Subsequent hard removes of the same point will
raise a KeyError. If the point is soft removed and then hard removed,
the point will be removed from the index.
Use :meth:`clean` for removing all soft removed points from the index.
Args:
key (Hashable): The key of the point to remove.
hard (bool): If True, perform a hard remove. Otherwise, perform a
soft remove.
ef (Optional[int]): The number of neighbors to consider during
re-assignment. If None, use the construction ef. This argument
is only used when hard is True.
Raises:
KeyError: If the index is empty or the key does not exist in the
index and was not soft removed.
Example:
.. code-block:: python
from datasketch.hnsw import HNSW
import numpy as np
data = np.random.random_sample((1000, 10))
index = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
index.update({i: d for i, d in enumerate(data)})
# Soft remove a point with key = 0.
index.remove(0)
                # Soft removing the same point again will not change the index
                # because the index is not empty.
                index.remove(0)
                print(0 in index) # False
                # Hard remove the point.
                index.remove(0, hard=True)
                # Hard removing the same point again would raise a KeyError.
                # index.remove(0, hard=True)
                # Soft remove the rest of the points from the index.
for i in range(1, 1000):
index.remove(i)
print(len(index)) # 0
# Clean the index to hard remove all soft removed points.
index.clean()
.. _hnswlib issue #4: https://github.com/nmslib/hnswlib/issues/4
"""
if not self._nodes or key not in self._nodes:
raise KeyError(key)
if self._entry_point == key:
# If the point is the entry point, we re-assign the entry point
# to the next point in the highest layer besides the point to be
# deleted.
new_entry_point = None
for layer in reversed(list(self._graphs)):
new_entry_point = next(
(p for p in layer if p != key and not self._nodes[p].is_deleted),
None,
)
if new_entry_point is not None:
break
else:
# As the layer is going to be empty after deletion, we remove it.
self._graphs.pop()
if new_entry_point is None:
# If the point to be deleted is the only point in the index,
# we clear the index.
self.clear()
return
# Update the entry point.
self._entry_point = new_entry_point
if ef is None:
ef = self._ef_construction
# Soft remove.
self._nodes[key].is_deleted = True
if not hard:
return
# Hard remove.
# Find all points that have out-going edges pointing to the deleted point.
keys_to_update = set()
for layer in self._graphs:
if key not in layer:
break
keys_to_update.update(layer.get_reverse_edges(key))
# Re-assign edges for these points.
for key_to_update in keys_to_update:
self._repair_connections(
key_to_update,
self._nodes[key_to_update].point,
ef,
key_to_delete=key,
)
        # Remove the point to be deleted from the graph.
for layer in self._graphs:
if key not in layer:
break
del layer[key]
# Remove the point from the index.
del self._nodes[key]
def clean(self, ef: Optional[int] = None) -> None:
"""Remove all soft removed points from the index.
Args:
ef (Optional[int]): The number of neighbors to consider during
re-assignment. If None, use the construction ef.
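
        Example:
            A minimal sketch, assuming ``index`` was built as in the class
            example:

            .. code-block:: python

                index.remove(0)  # soft remove; the point is only marked deleted
                index.clean()    # hard remove all soft-removed points
                print(0 in index)  # False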
"""
        keys_to_remove = [
            key for key in self._nodes if self._nodes[key].is_deleted
        ]
for key in keys_to_remove:
self.remove(key, ef=ef, hard=True)
def merge(self, other: HNSW) -> HNSW:
"""Create a new index by merging the current index with another index.
The new index will contain all points from both indexes.
If a point exists in both, the point from the other index will be used.
The new index will have the same parameters as the current index and
a copy of the current index's random state.
Args:
other (HNSW): The other index to merge with.
Returns:
HNSW: A new index containing all points from both indexes.
Example:
.. code-block:: python
from datasketch.hnsw import HNSW
import numpy as np
data1 = np.random.random_sample((1000, 10))
data2 = np.random.random_sample((1000, 10))
index1 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
index2 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
# Batch insert data into the indexes.
index1.update({i: d for i, d in enumerate(data1)})
index2.update({i + len(data1): d for i, d in enumerate(data2)})
# Merge the indexes.
index = index1.merge(index2)
"""
new_index = self.copy()
new_index.update(other)
return new_index