1 #ifndef BTLLIB_BLOOM_FILTER_HPP
2 #define BTLLIB_BLOOM_FILTER_HPP
7 #include "vendor/cpptoml.hpp"
// Insert one element, given its precomputed hash values. The vector overload
// forwards hashes.data(); the pointer overload reads hashes[0] through
// hashes[get_hash_num() - 1], so at least get_hash_num() values are required
// (no size check is made).
26 void insert(
const std::vector<uint64_t>& hashes);
27 void insert(
const uint64_t* hashes);
// Return true only if every bit addressed by the first get_hash_num() hashes
// is set. Same input requirements as insert().
29 bool contains(
const std::vector<uint64_t>& hashes);
30 bool contains(
const uint64_t* hashes);
// Byte count used for hash reduction and serialization (the `bytes` member).
32 size_t get_bytes()
const {
return bytes; }
// Number of set bits in the backing array.
33 uint64_t get_pop_cnt()
const;
// Number of hash values consumed per element.
34 unsigned get_hash_num()
const {
return hash_num; }
// Estimated false-positive rate: (get_pop_cnt() / bytes) ^ get_hash_num().
35 double get_fpr()
const;
// Serialize to `path`: a TOML header, a "[HeaderEnd]" line, then the raw
// array bytes (the layout the path constructor reads back).
37 void write(
const std::string& path);
// Backing bit array; bits are addressed with BIT_MASKS.
40 std::vector<unsigned char> bytearray;
// Hash values used per element (defaults to 0).
42 unsigned hash_num = 0;
// Record one occurrence of an element given its precomputed hashes. At least
// get_hash_num() values are read; the vector overload forwards hashes.data().
54 void insert(
const std::vector<uint64_t>& hashes);
55 void insert(
const uint64_t* hashes);
// Return the minimum of the element's get_hash_num() counters
// (counter index = hash % number of counters).
57 T contains(
const std::vector<uint64_t>& hashes);
58 T contains(
const uint64_t* hashes);
// Size of the counter storage in bytes.
60 size_t get_bytes()
const {
return bytes; }
// Number of non-zero counters.
61 uint64_t get_pop_cnt()
const;
// Number of hash values (counters) touched per element.
62 unsigned get_hash_num()
const {
return hash_num; }
// Estimated false-positive rate derived from get_pop_cnt() and get_hash_num().
63 double get_fpr()
const;
// Serialize to `path`: TOML header (including counter_bits), "[HeaderEnd]",
// then the raw counter bytes.
69 void write(
const std::string& path);
// Raw storage, accessed as an array of T counters.
72 std::vector<unsigned char> bytearray;
// Hash values used per element (defaults to 0).
75 unsigned hash_num = 0;
// Insert every k-mer of the sequence (hashed with NtHash) into the wrapped
// BloomFilter.
98 void insert(
const std::string& seq);
105 void insert(
const char* seq,
size_t seq_len);
// Query every k-mer of the sequence against the wrapped BloomFilter and return
// a count. The increment is not visible in this view; presumably this is the
// number of k-mers reported present.
113 unsigned contains(
const std::string& seq);
122 unsigned contains(
const char* seq,
size_t seq_len);
// Insert every k-mer of the sequence into the wrapped counting filter.
136 void insert(
const std::string& seq);
137 void insert(
const char* seq,
size_t seq_len);
// Return the sum, over the sequence's k-mers, of the counting filter's
// per-k-mer estimate (its minimum counter).
139 uint64_t contains(
const std::string& seq);
140 uint64_t contains(
const char* seq,
size_t seq_len);
// BIT_MASKS[i] is a byte with only bit i set (i = 0 is the least significant
// bit). Used to address a single bit within a byte of a filter's array.
155 static const unsigned char BIT_MASKS[CHAR_BIT] = {
157 0x01, 0x02, 0x04, 0x08,
158 0x10, 0x20, 0x40, 0x80
// Magic headers: each is used as the name of the TOML table written by write()
// and, wrapped in brackets, as the required first line of a serialized file.
// NOTE(review): any change to how hashes map to bits/counters changes the
// on-disk layout and should bump these version strings.
// NOTE(review): `static const char*` leaves the pointer itself reassignable and
// gives each including translation unit its own copy; `const char* const`
// would be stricter.
160 static const char* BLOOM_FILTER_MAGIC_HEADER =
"BTLBloomFilter_v2";
161 static const char* COUNTING_BLOOM_FILTER_MAGIC_HEADER =
162 "BTLCountingBloomFilter_v2";
/**
 * Count the set bits in one byte.
 *
 * The byte is split into its low and high nibbles. The bit count of each
 * nibble is looked up in a 16-entry table, and the two counts are added.
 * Only the low 8 bits of `x` are considered.
 *
 * @param x Byte to inspect.
 * @return Number of 1 bits in the low 8 bits of x (0..8).
 */
inline unsigned
pop_cnt_byte(unsigned char x)
{
  static const unsigned char NIBBLE_POP_CNT[16] = { 0, 1, 1, 2, 1, 2, 2, 3,
                                                    1, 2, 2, 3, 2, 3, 3, 4 };
  const unsigned low = NIBBLE_POP_CNT[x & 0xF];
  const unsigned high = NIBBLE_POP_CNT[(x >> 4) & 0xF];
  return low + high;
}
173 inline BloomFilter::BloomFilter(
size_t bytes,
unsigned hash_num)
174 : bytes(std::ceil(bytes / sizeof(uint64_t)) * sizeof(uint64_t))
177 bytearray.resize(bytes);
// Load a filter previously saved by write().
// The first line must equal "[" + BLOOM_FILTER_MAGIC_HEADER + "]". Lines are
// then collected into a TOML buffer up to and including "[HeaderEnd]". Next,
// `bytes` and `hash_num` are read from the table named by the magic header.
// Finally, `bytes` raw bytes are read into the bit array.
// On a magic-string mismatch or a missing "[HeaderEnd]", the process exits
// with EXIT_FAILURE.
// NOTE(review): no open/read checks are visible in this view. Presumably a
// missing file surfaces as a magic-string mismatch, and a truncated file
// leaves the tail of the array zero-filled by resize(). The table pointer and
// the get_as() results are dereferenced without null checks.
180 inline BloomFilter::BloomFilter(
const std::string& path)
182 std::ifstream file(path);
184 std::string magic_with_brackets =
185 std::string(
"[") + BLOOM_FILTER_MAGIC_HEADER +
"]";
188 std::getline(file, line);
189 if (line != magic_with_brackets) {
191 std::string(
"Magic string does not match (likely version mismatch)\n") +
192 "Your magic string:\t" + line +
"\n" +
"BloomFilter magic string:\t" +
193 magic_with_brackets);
194 std::exit(EXIT_FAILURE);
// Accumulate the header, starting with the magic line, until "[HeaderEnd]".
200 std::string toml_buffer(line +
'\n');
201 bool header_end_found =
false;
202 while (std::getline(file, line)) {
203 toml_buffer.append(line +
'\n');
204 if (line ==
"[HeaderEnd]") {
205 header_end_found =
true;
209 if (!header_end_found) {
210 log_error(
"Pre-built bloom filter does not have the correct header end.");
211 std::exit(EXIT_FAILURE);
// Parse the collected header as TOML and pull out the filter parameters.
215 std::istringstream toml_stream(toml_buffer);
216 cpptoml::parser toml_parser(toml_stream);
217 auto header_config = toml_parser.parse();
220 auto table = header_config->get_table(BLOOM_FILTER_MAGIC_HEADER);
221 bytes = *table->get_as<
size_t>(
"bytes");
222 hash_num = *table->get_as<
unsigned>(
"hash_num");
// The raw bit array follows the header directly in the stream.
224 bytearray.resize(bytes);
225 file.read((
char*)bytearray.data(), bytes);
229 BloomFilter::insert(
const std::vector<uint64_t>& hashes)
231 insert(hashes.data());
235 BloomFilter::insert(
const uint64_t* hashes)
237 for (
unsigned i = 0; i < hash_num; ++i) {
238 auto normalized = hashes[i] % bytes;
239 __sync_or_and_fetch(&(bytearray[normalized / CHAR_BIT]),
240 BIT_MASKS[normalized % CHAR_BIT]);
245 BloomFilter::contains(
const std::vector<uint64_t>& hashes)
247 return contains(hashes.data());
251 BloomFilter::contains(
const uint64_t* hashes)
253 for (
unsigned i = 0; i < hash_num; ++i) {
254 auto normalized = hashes[i] % bytes;
255 auto mask = BIT_MASKS[normalized % CHAR_BIT];
256 if (!
bool(bytearray[normalized / CHAR_BIT] & mask)) {
264 BloomFilter::get_pop_cnt()
const
266 uint64_t pop_cnt = 0;
267 #pragma omp parallel for reduction(+ : pop_cnt)
268 for (
size_t i = 0; i < bytes; ++i) {
269 pop_cnt += pop_cnt_byte(bytearray[i]);
275 BloomFilter::get_fpr()
const
277 return std::pow(
double(get_pop_cnt()) /
double(bytes),
double(hash_num));
// Serialize the filter to `path` in binary mode. The output is, in order:
// - a TOML document whose table (named by BLOOM_FILTER_MAGIC_HEADER) holds
//   `bytes` and `hash_num`;
// - the line "[HeaderEnd]";
// - the first `bytes` bytes of the bit array.
// This is the layout the BloomFilter(const std::string&) constructor parses.
// NOTE(review): no stream-state check is visible here (lines elided from this
// view may contain one); a failed write would go unnoticed.
281 BloomFilter::write(
const std::string& path)
283 std::ofstream file(path.c_str(), std::ios::out | std::ios::binary);
289 auto root = cpptoml::make_table();
293 auto header = cpptoml::make_table();
294 header->insert(
"bytes", bytes);
295 header->insert(
"hash_num", hash_num);
296 root->insert(BLOOM_FILTER_MAGIC_HEADER, header);
297 file << *root <<
"[HeaderEnd]\n";
// Raw bit array immediately after the header terminator.
299 file.write((
char*)bytearray.data(), bytes);
303 inline CountingBloomFilter<T>::CountingBloomFilter(
size_t bytes,
305 : bytes(std::ceil(bytes / sizeof(uint64_t)) * sizeof(uint64_t))
306 , counters(bytes / sizeof(T))
309 bytearray.resize(bytes);
// Load a counting filter previously saved by write().
// The first line must equal "[" + COUNTING_BLOOM_FILTER_MAGIC_HEADER + "]".
// Header lines are collected up to and including "[HeaderEnd]" and parsed as
// TOML. `bytes`, `hash_num` and `counter_bits` are read from the table named
// by the magic header, and `counters` is derived from `bytes`. Finally, the
// raw counter bytes are read.
// On a magic-string mismatch or a missing "[HeaderEnd]", the process exits
// with EXIT_FAILURE.
// NOTE(review): no open/read checks are visible in this view, and the table
// pointer and get_as() results are dereferenced without null checks.
313 inline CountingBloomFilter<T>::CountingBloomFilter(
const std::string& path)
315 std::ifstream file(path);
317 std::string magic_with_brackets =
318 std::string(
"[") + COUNTING_BLOOM_FILTER_MAGIC_HEADER +
"]";
321 std::getline(file, line);
// NOTE(review): the mismatch message below labels the expected value
// "BloomFilter magic string" although this is the counting filter's header.
322 if (line != magic_with_brackets) {
324 std::string(
"Magic string does not match (likely version mismatch)\n") +
325 "Your magic string:\t" + line +
"\n" +
"BloomFilter magic string:\t" +
326 magic_with_brackets);
327 std::exit(EXIT_FAILURE);
// Accumulate the header, starting with the magic line, until "[HeaderEnd]".
333 std::string toml_buffer(line +
'\n');
334 bool header_end_found =
false;
335 while (std::getline(file, line)) {
336 toml_buffer.append(line +
'\n');
337 if (line ==
"[HeaderEnd]") {
338 header_end_found =
true;
342 if (!header_end_found) {
343 log_error(
"Pre-built bloom filter does not have the correct header end.");
344 std::exit(EXIT_FAILURE);
// Parse the collected header as TOML and pull out the filter parameters.
348 std::istringstream toml_stream(toml_buffer);
349 cpptoml::parser toml_parser(toml_stream);
350 auto header_config = toml_parser.parse();
353 auto table = header_config->get_table(COUNTING_BLOOM_FILTER_MAGIC_HEADER);
354 bytes = *table->get_as<
size_t>(
"bytes");
355 hash_num = *table->get_as<
unsigned>(
"hash_num");
// Each counter occupies sizeof(T) bytes.
356 counters = bytes /
sizeof(T);
// Reject files written with a different counter width. check_error is defined
// elsewhere; presumably it reports the message and aborts when its condition
// is true (confirm).
357 check_error(
sizeof(T) * CHAR_BIT != *table->get_as<
size_t>(
"counter_bits"),
358 "CountingBloomFilter" + std::to_string(
sizeof(T) * CHAR_BIT) +
359 " tried to load a file of CountingBloomFilter" +
360 std::to_string(*table->get_as<
size_t>(
"counter_bits")));
// The raw counter bytes follow the header directly in the stream.
362 bytearray.resize(bytes);
363 file.read((
char*)bytearray.data(), bytes);
368 CountingBloomFilter<T>::insert(
const std::vector<uint64_t>& hashes)
370 insert(hashes.data());
// Record one occurrence of the element whose hashes are given.
// Visible algorithm:
// 1. Read the current minimum over the element's counters (contains()).
// 2. Compute min_val + 1.
// 3. Attempt an atomic compare-and-swap on each counter at hashes[i] % counters.
// 4. If the update did not happen, re-read the minimum and retry.
// The CAS expected/new arguments, the body of the overflow branch and the
// bookkeeping that sets update_done are on lines not visible here.
// Presumably each counter still equal to min_val is moved to new_val (a
// "conservative update"); confirm against the full source.
// NOTE(review): assuming new_val is of type T, `min_val > new_val` is true only
// when min_val + 1 wrapped, i.e. it guards counter saturation for unsigned T.
// NOTE(review): contains() reads counters with plain loads while other threads
// may be CAS-ing them; under the C++ memory model that is a data race.
375 CountingBloomFilter<T>::insert(
const uint64_t* hashes)
378 bool update_done =
false;
380 T min_val = contains(hashes);
381 while (!update_done) {
383 new_val = min_val + 1;
384 if (min_val > new_val) {
387 for (
size_t i = 0; i < hash_num; ++i) {
388 if (__sync_bool_compare_and_swap(
389 &(((T*)(bytearray.data()))[hashes[i] % counters]),
398 min_val = contains(hashes);
405 CountingBloomFilter<T>::contains(
const std::vector<uint64_t>& hashes)
407 return contains(hashes.data());
412 CountingBloomFilter<T>::contains(
const uint64_t* hashes)
414 T min = ((T*)(bytearray.data()))[hashes[0] % counters];
415 for (
size_t i = 1; i < hash_num; ++i) {
416 size_t idx = hashes[i] % counters;
417 if (((T*)(bytearray.data()))[idx] < min) {
418 min = ((T*)(bytearray.data()))[idx];
426 CountingBloomFilter<T>::get_pop_cnt()
const
428 uint64_t pop_cnt = 0;
429 #pragma omp parallel for reduction(+ : pop_cnt)
430 for (
size_t i = 0; i < counters; ++i) {
431 if (((T*)(bytearray.data()))[i]) {
440 CountingBloomFilter<T>::get_fpr()
const
442 return std::pow(
double(get_pop_cnt()) /
double(bytes),
double(hash_num));
// Body of (presumably) CountingBloomFilter<T>::write; its signature is not
// visible in this view. Writes, in binary mode:
// - a TOML document whose table is named by COUNTING_BLOOM_FILTER_MAGIC_HEADER
//   and holds `bytes`, `hash_num` and `counter_bits` (= sizeof(T) * CHAR_BIT);
// - the line "[HeaderEnd]";
// - the raw `bytes` bytes of counter storage.
// The path constructor compares `counter_bits` against its own T.
// NOTE(review): no stream-state check is visible here.
449 std::ofstream file(path.c_str(), std::ios::out | std::ios::binary);
455 auto root = cpptoml::make_table();
459 auto header = cpptoml::make_table();
460 header->insert(
"bytes", bytes);
461 header->insert(
"hash_num", hash_num);
462 header->insert(
"counter_bits",
size_t(
sizeof(T) * CHAR_BIT));
463 root->insert(COUNTING_BLOOM_FILTER_MAGIC_HEADER, header);
464 file << *root <<
"[HeaderEnd]\n";
// Raw counter bytes immediately after the header terminator.
466 file.write((
char*)bytearray.data(), bytes);
// Constructor member initializer (signature not visible here): the wrapped
// BloomFilter `bf` is built from `bytes` and `hash_num`.
473 , bf(bytes, hash_num)
// insert(const std::string&) body: forwards to the (pointer, length) overload.
479 insert(seq.c_str(), seq.size());
// insert(const char*, size_t) body: roll NtHash over the sequence (with `k`,
// presumably the k-mer length) and insert each position's get_hash_num()
// hashes into `bf`.
485 NtHash nthash(seq, seq_len, k, bf.get_hash_num());
486 while (nthash.
roll()) {
487 bf.insert(nthash.hashes());
// contains(const std::string&) body: forwards to the (pointer, length) overload.
494 return contains(seq.c_str(), seq.size());
// contains(const char*, size_t) body: roll NtHash the same way and test each
// position against `bf`. The hit counter and the return are not visible in
// this view; presumably this returns the number of positions found.
500 NtHash nthash(seq, seq_len, k, bf.get_hash_num());
501 while (nthash.
roll()) {
502 if (bf.contains(nthash.hashes())) {
// Constructor member initializer, presumably of KmerCountingBloomFilter (the
// signature is not visible here): the wrapped counting filter `bf` is built
// from `bytes` and `hash_num`.
514 , bf(bytes, hash_num)
519 KmerCountingBloomFilter<T>::insert(
const std::string& seq)
521 insert(seq.c_str(), seq.size());
526 KmerCountingBloomFilter<T>::insert(
const char* seq,
size_t seq_len)
528 NtHash nthash(seq, seq_len, k, bf.get_hash_num());
529 while (nthash.roll()) {
530 bf.insert(nthash.hashes());
536 KmerCountingBloomFilter<T>::contains(
const std::string& seq)
538 return contains(seq.c_str(), seq.size());
// Query every position NtHash rolls through and add the counting filter's
// estimate for it (bf.contains, i.e. its minimum counter) to `count`.
// The declaration of `count`, the closing braces and the return statement are
// not visible in this view.
543 KmerCountingBloomFilter<T>::contains(
const char* seq,
size_t seq_len)
546 NtHash nthash(seq, seq_len, k, bf.get_hash_num());
547 while (nthash.roll()) {
548 count += bf.contains(nthash.hashes());