Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jegao/label host fix with main3 #545

Draft
wants to merge 63 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
237c9c7
return max value while label is invalid
Sanhaoji2 Apr 12, 2023
013db57
hotfix universal label
Sanhaoji2 Apr 27, 2023
6811045
reflector label looping
Sanhaoji2 May 9, 2023
e402c8b
test bit mask
Sanhaoji2 May 10, 2023
b43e493
fix some issue
Sanhaoji2 May 10, 2023
77aded7
simple bitmask
Sanhaoji2 May 11, 2023
c815b8a
fix issue
Sanhaoji2 May 11, 2023
ad0b4ed
limit to one input label
Sanhaoji2 May 11, 2023
1bc3ff2
fix issue
Sanhaoji2 May 11, 2023
a944bd9
fix some issue
Sanhaoji2 May 11, 2023
31cd3c8
move id check first
Sanhaoji2 May 11, 2023
079881b
fix issue
Sanhaoji2 May 11, 2023
73a230f
test remove filter code
Sanhaoji2 May 11, 2023
2963c9d
update log
Sanhaoji2 May 11, 2023
5e87392
fix logging
Sanhaoji2 May 11, 2023
7432760
revert test code
Sanhaoji2 May 11, 2023
cd1ff23
insert filted node to visited map
Sanhaoji2 May 11, 2023
4823436
change to vector to hold bitmask
Sanhaoji2 May 12, 2023
81ff09c
pre-create bitmask object
Sanhaoji2 May 12, 2023
e0e7036
revert pre-create bitmask
Sanhaoji2 May 12, 2023
fe19ce6
support multiple label
Sanhaoji2 May 12, 2023
56e892a
use local bitmask in small size
Sanhaoji2 May 15, 2023
f32b047
add header file
Sanhaoji2 May 15, 2023
9072f16
fix some issue
Sanhaoji2 May 15, 2023
ec2bcf0
add prefetch
Sanhaoji2 May 15, 2023
47e430f
revert perfetch
Sanhaoji2 May 15, 2023
79c3bc9
cleanup code
Sanhaoji2 May 16, 2023
417a47a
apply bitmask in index build
Sanhaoji2 May 17, 2023
a2993df
fix compile issue
Sanhaoji2 May 17, 2023
a4855aa
Fix some issue
Sanhaoji2 May 17, 2023
2dcb482
use original label id
Sanhaoji2 May 17, 2023
94a5a12
return max value for invaild label
Sanhaoji2 May 18, 2023
3681b56
cleanup code
Sanhaoji2 May 18, 2023
07938b9
Jegao/label hot fix with main2 (#430)
Sanhaoji2 Aug 18, 2023
daa5a7b
Jegao/label hot fix test2 (#469)
Sanhaoji2 Oct 8, 2023
155f7bd
add label check API (#476)
Sanhaoji2 Oct 19, 2023
337d0d5
Fix parse issue while only one label in node (#488)
Sanhaoji2 Nov 9, 2023
9bb0cf0
Fix memory leak (#497)
Sanhaoji2 Nov 30, 2023
416c661
apply label parse improve to memory index (#524)
Sanhaoji2 Mar 7, 2024
d78970f
sync to latest code
Sanhaoji2 Apr 10, 2024
08333d8
fix compile issue
Sanhaoji2 Apr 10, 2024
4e63bfe
add interface
Sanhaoji2 Apr 12, 2024
0ad3ec2
add interface
Sanhaoji2 Apr 12, 2024
bb83da9
change inteface
Sanhaoji2 Apr 12, 2024
9a87dc1
move function to public
Sanhaoji2 Apr 15, 2024
0e83b89
remove hard code unv label num
Sanhaoji2 Apr 17, 2024
6c050a1
fix convert issue
Sanhaoji2 Apr 19, 2024
4b4bed5
fix some issue
Sanhaoji2 Apr 19, 2024
aed8d4b
fix issues
Sanhaoji2 Apr 22, 2024
25e3af6
fix issues
Sanhaoji2 Apr 22, 2024
abda8bb
tune perf
Sanhaoji2 Apr 24, 2024
469ec02
test remove lock
Sanhaoji2 Apr 24, 2024
7c05484
try shared lock
Sanhaoji2 Apr 25, 2024
363c59e
change to shared lock
Sanhaoji2 Apr 25, 2024
286db31
try perfetch
Sanhaoji2 Apr 25, 2024
7cfa7f6
fix some issues
Sanhaoji2 Apr 25, 2024
e9c5c44
fix issue
Sanhaoji2 Apr 25, 2024
ec577f1
skip unfilter search while Lindex = 1
Sanhaoji2 Apr 27, 2024
3fa9d42
reserve queue size with max search lsit
Sanhaoji2 Apr 28, 2024
6e36270
Merge branch 'main' of https://github.com/microsoft/DiskANN into jega…
Sanhaoji2 Apr 30, 2024
e2b3007
revert change
Sanhaoji2 Apr 30, 2024
209a15a
revert change
Sanhaoji2 Apr 30, 2024
6863ba6
clean up code
Sanhaoji2 Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/abstract_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ template <typename data_t> class AbstractDataStore
// streaming setting
virtual void get_vector(const location_t i, data_t *dest) const = 0;
virtual void set_vector(const location_t i, const data_t *const vector) = 0;
virtual void prefetch_vector(const location_t loc) = 0;
virtual void prefetch_vector(const location_t loc) const = 0;

// internal shuffle operations to move around vectors
// will bulk-move all the vectors in [old_start_loc, old_start_loc +
Expand Down
3 changes: 3 additions & 0 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ class AbstractIndex

template <typename label_type> void set_universal_label(const label_type universal_label);

virtual bool is_label_valid(const std::string &raw_label) const = 0;
virtual bool is_set_universal_label() const = 0;

private:
virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0;
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
Expand Down
2 changes: 1 addition & 1 deletion include/in_mem_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ template <typename data_t> class InMemDataStore : public AbstractDataStore<data_

virtual void get_vector(const location_t i, data_t *target) const override;
virtual void set_vector(const location_t i, const data_t *const vector) override;
virtual void prefetch_vector(const location_t loc) override;
virtual void prefetch_vector(const location_t loc) const override;

virtual void move_vectors(const location_t old_location_start, const location_t new_location_start,
const location_t num_points) override;
Expand Down
128 changes: 126 additions & 2 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "in_mem_data_store.h"
#include "in_mem_graph_store.h"
#include "abstract_index.h"
#include <bitset>

#include "quantized_distance.h"
#include "pq_data_store.h"
Expand All @@ -42,6 +43,117 @@ inline double estimate_ram_usage(size_t size, uint32_t dim, uint32_t datasize, u
return OVERHEAD_FACTOR * (size_of_data + size_of_graph + size_of_locks + size_of_outer_vector);
}

struct simple_bitmask_val
{
size_t _index = 0;
std::uint64_t _mask = 0;
};

struct simple_bitmask_full_val
{
simple_bitmask_full_val()
{
}

void merge_bitmask_val(simple_bitmask_val& bitmask_val)
{
_mask[bitmask_val._index] |= bitmask_val._mask;
}

std::uint64_t* _mask = nullptr;
};

struct simple_bitmask_buf
{
std::uint64_t* get_bitmask(std::uint64_t index)
{
return _buf.data() + index * _bitmask_size;
}

std::vector<std::uint64_t> _buf;
std::uint64_t _bitmask_size = 0;

};

class simple_bitmask
{
public:
simple_bitmask(std::uint64_t* bitsets, std::uint64_t bitmask_size)
: _bitsets(bitsets)
, _bitmask_size(bitmask_size)
{
}

bool test(size_t pos) const
{
std::uint64_t mask = (std::uint64_t)1 << (pos & (8 * sizeof(std::uint64_t) - 1));
size_t index = pos / 8 / sizeof(std::uint64_t);
std::uint64_t val = _bitsets[index];
return 0 != (val & mask);
}

static simple_bitmask_val get_bitmask_val(size_t pos)
{
simple_bitmask_val bitmask_val;
bitmask_val._mask = (std::uint64_t)1 << (pos & (8 * sizeof(std::uint64_t) - 1));
bitmask_val._index = pos / 8 / sizeof(std::uint64_t);

return bitmask_val;
}

static std::uint64_t get_bitmask_size(std::uint64_t totalBits)
{
std::uint64_t bytes = (totalBits + 7) / 8;
std::uint64_t aligned_bytes = bytes + sizeof(std::uint64_t) - 1;
aligned_bytes = aligned_bytes - (aligned_bytes % sizeof(std::uint64_t));
return aligned_bytes / sizeof(std::uint64_t);
}

bool test_mask_val(const simple_bitmask_val& bitmask_val) const
{
std::uint64_t val = _bitsets[bitmask_val._index];
return 0 != (val & bitmask_val._mask);
}

bool test_full_mask_val(const simple_bitmask_full_val& bitmask_full_val) const
{
for (size_t i = 0; i < _bitmask_size; i++)
{
if ((bitmask_full_val._mask[i] & _bitsets[i]) != 0)
{
return true;
}
}

return false;
}

bool test_full_mask_contain(const simple_bitmask& bitmask_full_val) const
{
for (size_t i = 0; i < _bitmask_size; i++)
{
auto mask = bitmask_full_val._bitsets[i];
if ((mask & _bitsets[i]) != mask)
{
return false;
}
}

return true;
}

void set(size_t pos)
{
std::uint64_t mask = (std::uint64_t)1 << (pos & (8 * sizeof(std::uint64_t) - 1));
size_t index = pos / 8 / sizeof(std::uint64_t);
_bitsets[index] |= mask;
}

private:
std::uint64_t* _bitsets;
std::uint64_t _bitmask_size;
};

template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> class Index : public AbstractIndex
{
/**************************************************************************
Expand Down Expand Up @@ -112,7 +224,11 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);

// Get converted integer label from string to int map (_label_map)
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label);
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label) const;

DISKANN_DLLEXPORT bool is_label_valid(const std::string& raw_label) const override;

DISKANN_DLLEXPORT bool is_set_universal_label() const override;

// Set starting point of an index before inserting any points incrementally.
// The data count should be equal to _num_frozen_pts * _aligned_dim.
Expand Down Expand Up @@ -249,6 +365,10 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

void parse_label_file(const std::string &label_file, size_t &num_pts_labels);

void parse_label_file_in_bitset(const std::string& label_file, size_t& num_points, size_t num_labels);

void convert_pts_label_to_bitmask(std::vector<std::vector<LabelT>>& pts_to_labels, simple_bitmask_buf& bitmask_buf, size_t num_labels);

std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);

// Returns the locations of start point and frozen points suitable for use
Expand Down Expand Up @@ -312,7 +432,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
const uint32_t maxc, const float alpha, InMemQueryScratch<T> *scratch);

void initialize_query_scratch(uint32_t num_threads, uint32_t search_l, uint32_t indexing_l, uint32_t r,
uint32_t maxc, size_t dim);
uint32_t maxc, size_t dim, size_t bitmask_size = 0);

// Do not call without acquiring appropriate locks
// call public member functions save and load to invoke these.
Expand All @@ -332,6 +452,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
DISKANN_DLLEXPORT size_t load_delete_set(const std::string &filename);
#endif

size_t search_string_range(const std::string& str, char ch, size_t start, size_t end);

private:
// Distance functions
Metric _dist_metric = diskann::L2;
Expand Down Expand Up @@ -443,6 +565,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// Per node lock, cardinality=_max_points + _num_frozen_points
std::vector<non_recursive_mutex> _locks;

simple_bitmask_buf _bitmask_buf;

static const float INDEX_GROWTH_FACTOR;
};
} // namespace diskann
2 changes: 1 addition & 1 deletion include/pq_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ template <typename data_t> class PQDataStore : public AbstractDataStore<data_t>

virtual void get_vector(const location_t i, data_t *target) const override;
virtual void set_vector(const location_t i, const data_t *const vector) override;
virtual void prefetch_vector(const location_t loc) override;
virtual void prefetch_vector(const location_t loc) const override;

virtual void move_vectors(const location_t old_location_start, const location_t new_location_start,
const location_t num_points) override;
Expand Down
17 changes: 12 additions & 5 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads,
const char *index_filepath, const char *pivots_filepath,
const char *compressed_filepath);
const char* index_filepath, const char* pivots_filepath,
const char* compressed_filepath, const char* labels_filepath, const char* labels_to_medoids_filepath,
const char* labels_map_filepath, const char* unv_label_filepath);
#else
DISKANN_DLLEXPORT int load_from_separate_paths(uint32_t num_threads, const char *index_filepath,
const char *pivots_filepath, const char *compressed_filepath);
const char* pivots_filepath, const char* compressed_filepath,
const char* labels_filepath, const char* labels_to_medoids_filepath,
const char* labels_map_filepath, const char* unv_label_filepath);
#endif

DISKANN_DLLEXPORT void load_cache_list(std::vector<uint32_t> &node_list);
Expand Down Expand Up @@ -83,6 +86,8 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label);

DISKANN_DLLEXPORT bool is_label_valid(const std::string& filter_label);

DISKANN_DLLEXPORT uint32_t range_search(const T *query1, const double range, const uint64_t min_l_search,
const uint64_t max_l_search, std::vector<uint64_t> &indices,
std::vector<float> &distances, const uint64_t min_beam_width,
Expand Down Expand Up @@ -116,8 +121,8 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

private:
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char>& infile);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char>& infile, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
Expand All @@ -136,6 +141,8 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
// returns region of `node_buf` containing [COORD(T)]
DISKANN_DLLEXPORT T *offset_to_node_coords(char *node_buf);

size_t search_string_range(const std::string& str, char ch, size_t start, size_t end);

// index info for multi-node sectors
// nhood of node `i` is in sector: [i / nnodes_per_sector]
// offset in sector: [(i % nnodes_per_sector) * max_node_len]
Expand Down
9 changes: 8 additions & 1 deletion include/scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ template <typename T> class InMemQueryScratch : public AbstractScratch<T>
public:
~InMemQueryScratch();
InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim,
size_t alignment_factor, bool init_pq_scratch = false);
size_t alignment_factor, bool init_pq_scratch = false, size_t bitmask_size = 0);
void resize_for_new_L(uint32_t new_search_l);
void clear();

Expand Down Expand Up @@ -94,6 +94,11 @@ template <typename T> class InMemQueryScratch : public AbstractScratch<T>
return _occlude_list_output;
}

inline std::vector<std::uint64_t>& query_label_bitmask()
{
return _query_label_bitmask;
}

private:
uint32_t _L;
uint32_t _R;
Expand Down Expand Up @@ -132,6 +137,8 @@ template <typename T> class InMemQueryScratch : public AbstractScratch<T>
tsl::robin_set<uint32_t> _expanded_nodes_set;
std::vector<Neighbor> _expanded_nghrs_vec;
std::vector<uint32_t> _occlude_list_output;
// bitmask buffer in searching time
std::vector<std::uint64_t> _query_label_bitmask;
};

//
Expand Down
12 changes: 9 additions & 3 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,14 @@ inline int delete_file(const std::string &fileName)

// generates formatted_label and _labels_map file.
inline void convert_labels_string_to_int(const std::string &inFileName, const std::string &outFileName,
const std::string &mapFileName, const std::string &unv_label)
const std::string &mapFileName, const std::string &unv_label,
uint32_t& unv_label_id)
{
std::unordered_map<std::string, uint32_t> string_int_map;
std::ofstream label_writer(outFileName);
std::ifstream label_reader(inFileName);
if (unv_label != "")
string_int_map[unv_label] = 0; // if universal label is provided map it to 0 always
//if (unv_label != "")
// string_int_map[unv_label] = 0;
std::string line, token;
while (std::getline(label_reader, line))
{
Expand Down Expand Up @@ -217,6 +218,11 @@ inline void convert_labels_string_to_int(const std::string &inFileName, const st
}
label_writer.close();

if (unv_label != "")
{
unv_label_id = string_int_map[unv_label];
}

std::ofstream map_writer(mapFileName);
for (auto mp : string_int_map)
{
Expand Down
10 changes: 10 additions & 0 deletions include/windows_slim_lock.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class windows_exclusive_slim_lock
return AcquireSRWLockExclusive(&_lock);
}

void lock_shared()
{
return AcquireSRWLockShared(&_lock);
}

bool try_lock()
{
return TryAcquireSRWLockExclusive(&_lock) != FALSE;
Expand All @@ -44,6 +49,11 @@ class windows_exclusive_slim_lock
return ReleaseSRWLockExclusive(&_lock);
}

void unlock_shared()
{
return ReleaseSRWLockShared(&_lock);
}

private:
SRWLOCK _lock;
};
Expand Down
15 changes: 8 additions & 7 deletions src/disk_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq,
uint32_t num_threads, bool use_filters, const std::string &label_file,
const std::string &labels_to_medoids_file, const std::string &universal_label,
const uint32_t Lf)
const uint32_t Lf, uint32_t universal_label_num = 0)
{
size_t base_num, base_dim;
diskann::get_bin_metadata(base_file, base_num, base_dim);
Expand Down Expand Up @@ -659,8 +659,8 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
{
if (universal_label != "")
{ // indicates no universal label
LabelT unv_label_as_num = 0;
_index.set_universal_label(unv_label_as_num);
// LabelT unv_label_as_num = 0;
_index.set_universal_label(universal_label_num);
}
_index.build_filtered_index(base_file.c_str(), label_file, base_num);
}
Expand Down Expand Up @@ -733,8 +733,8 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
diskann::extract_shard_labels(label_file, shard_ids_file, shard_labels_file);
if (universal_label != "")
{ // indicates no universal label
LabelT unv_label_as_num = 0;
_index.set_universal_label(unv_label_as_num);
// LabelT unv_label_as_num = 0;
_index.set_universal_label(universal_label_num);
}
_index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts);
}
Expand Down Expand Up @@ -1266,10 +1266,11 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
// into replica dummy points which evenly distribute the filters. The rest
// of index build happens on the augmented base and labels
std::string augmented_data_file, augmented_labels_file;
std::uint32_t universal_label_id = 0;
if (use_filters)
{
convert_labels_string_to_int(labels_file_original, labels_file_to_use, disk_labels_int_map_file,
universal_label);
universal_label, universal_label_id);
augmented_data_file = index_prefix_path + "_augmented_data.bin";
augmented_labels_file = index_prefix_path + "_augmented_labels.txt";
if (filter_threshold != 0)
Expand Down Expand Up @@ -1326,7 +1327,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
diskann::build_merged_vamana_index<T, LabelT>(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val,
indexing_ram_budget, mem_index_path, medoids_path, centroids_path,
build_pq_bytes, use_opq, num_threads, use_filters, labels_file_to_use,
labels_to_medoids_path, universal_label, Lf);
labels_to_medoids_path, universal_label, Lf, universal_label_id);
diskann::cout << timer.elapsed_seconds_for_step("building merged vamana index") << std::endl;

timer.reset();
Expand Down
Loading
Loading