btllib
bloom_filter.hpp
1 #ifndef BTLLIB_BLOOM_FILTER_HPP
2 #define BTLLIB_BLOOM_FILTER_HPP
3 
#include "nthash.hpp"
#include "status.hpp"

#include "vendor/cpptoml.hpp"

#include <climits>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
15 
16 namespace btllib {
17 
19 {
20 
21 public:
22  BloomFilter() {}
23  BloomFilter(size_t bytes, unsigned hash_num);
24  BloomFilter(const std::string& path);
25 
26  void insert(const std::vector<uint64_t>& hashes);
27  void insert(const uint64_t* hashes);
28 
29  bool contains(const std::vector<uint64_t>& hashes);
30  bool contains(const uint64_t* hashes);
31 
32  size_t get_bytes() const { return bytes; }
33  uint64_t get_pop_cnt() const;
34  unsigned get_hash_num() const { return hash_num; }
35  double get_fpr() const;
36 
37  void write(const std::string& path);
38 
39 private:
40  std::vector<unsigned char> bytearray;
41  size_t bytes = 0;
42  unsigned hash_num = 0;
43 };
44 
45 template<typename T>
47 {
48 
49 public:
51  CountingBloomFilter(size_t bytes, unsigned hash_num);
52  CountingBloomFilter(const std::string& path);
53 
54  void insert(const std::vector<uint64_t>& hashes);
55  void insert(const uint64_t* hashes);
56 
57  T contains(const std::vector<uint64_t>& hashes);
58  T contains(const uint64_t* hashes);
59 
60  size_t get_bytes() const { return bytes; }
61  uint64_t get_pop_cnt() const;
62  unsigned get_hash_num() const { return hash_num; }
63  double get_fpr() const;
64 
69  void write(const std::string& path);
70 
71 private:
72  std::vector<unsigned char> bytearray;
73  size_t bytes = 0;
74  size_t counters = 0;
75  unsigned hash_num = 0;
76 };
77 
83 {
84 
85 public:
92  KmerBloomFilter(unsigned k, size_t bytes, unsigned hash_num = 4);
93 
98  void insert(const std::string& seq);
99 
105  void insert(const char* seq, size_t seq_len);
106 
113  unsigned contains(const std::string& seq);
114 
122  unsigned contains(const char* seq, size_t seq_len);
123 
124 private:
125  unsigned k;
126  BloomFilter bf;
127 };
128 
129 template<typename T>
131 {
132 
133 public:
134  KmerCountingBloomFilter(unsigned k, size_t bytes, unsigned hash_num = 4);
135 
136  void insert(const std::string& seq);
137  void insert(const char* seq, size_t seq_len);
138 
139  uint64_t contains(const std::string& seq);
140  uint64_t contains(const char* seq, size_t seq_len);
141 
142 private:
143  unsigned k;
145 };
146 
150 
154 
// BIT_MASKS[i] has exactly bit i set; used to address a single bit in a byte.
static const unsigned char BIT_MASKS[CHAR_BIT] = {
  // NOLINT
  0x01, // NOLINT
  0x02, // NOLINT
  0x04, // NOLINT
  0x08, // NOLINT
  0x10, // NOLINT
  0x20, // NOLINT
  0x40, // NOLINT
  0x80  // NOLINT
};

// Name of the TOML table that opens a stored BloomFilter file.
static const char* BLOOM_FILTER_MAGIC_HEADER = "BTLBloomFilter_v2";

// Name of the TOML table that opens a stored CountingBloomFilter file.
static const char* COUNTING_BLOOM_FILTER_MAGIC_HEADER =
  "BTLCountingBloomFilter_v2";
163 
/**
 * Count the bits set in a single byte.
 *
 * @param x Byte to inspect.
 * @return Number of 1 bits in the low 8 bits of x (0 through 8).
 */
inline unsigned
pop_cnt_byte(unsigned char x)
{
  // Set-bit count of every 4-bit value, indexed by that value.
  static const unsigned char NIBBLE_BITS[16] = { // NOLINT
    0, 1, 1, 2, 1, 2, 2, 3,                      // NOLINT
    1, 2, 2, 3, 2, 3, 3, 4                       // NOLINT
  };
  const unsigned low_nibble = x & 0xF;           // NOLINT
  const unsigned high_nibble = (x & 0xF0) >> 4;  // NOLINT
  return unsigned(NIBBLE_BITS[low_nibble]) + unsigned(NIBBLE_BITS[high_nibble]);
}
172 
173 inline BloomFilter::BloomFilter(size_t bytes, unsigned hash_num)
174  : bytes(std::ceil(bytes / sizeof(uint64_t)) * sizeof(uint64_t))
175  , hash_num(hash_num)
176 {
177  bytearray.resize(bytes);
178 }
179 
180 inline BloomFilter::BloomFilter(const std::string& path)
181 {
182  std::ifstream file(path);
183 
184  std::string magic_with_brackets =
185  std::string("[") + BLOOM_FILTER_MAGIC_HEADER + "]";
186 
187  std::string line;
188  std::getline(file, line);
189  if (line != magic_with_brackets) {
190  log_error(
191  std::string("Magic string does not match (likely version mismatch)\n") +
192  "Your magic string:\t" + line + "\n" + "BloomFilter magic string:\t" +
193  magic_with_brackets);
194  std::exit(EXIT_FAILURE);
195  }
196 
197  /* Read bloom filter line by line until it sees "[HeaderEnd]"
198  which is used to mark the end of the header section and
199  assigns the header to a char array*/
200  std::string toml_buffer(line + '\n');
201  bool header_end_found = false;
202  while (std::getline(file, line)) {
203  toml_buffer.append(line + '\n');
204  if (line == "[HeaderEnd]") {
205  header_end_found = true;
206  break;
207  }
208  }
209  if (!header_end_found) {
210  log_error("Pre-built bloom filter does not have the correct header end.");
211  std::exit(EXIT_FAILURE);
212  }
213 
214  // Send the char array to a stringstream for the cpptoml parser to parse
215  std::istringstream toml_stream(toml_buffer);
216  cpptoml::parser toml_parser(toml_stream);
217  auto header_config = toml_parser.parse();
218 
219  // Obtain header values from toml parser and assign them to class members
220  auto table = header_config->get_table(BLOOM_FILTER_MAGIC_HEADER);
221  bytes = *table->get_as<size_t>("bytes");
222  hash_num = *table->get_as<unsigned>("hash_num");
223 
224  bytearray.resize(bytes);
225  file.read((char*)bytearray.data(), bytes);
226 }
227 
228 inline void
229 BloomFilter::insert(const std::vector<uint64_t>& hashes)
230 {
231  insert(hashes.data());
232 }
233 
234 inline void
235 BloomFilter::insert(const uint64_t* hashes)
236 {
237  for (unsigned i = 0; i < hash_num; ++i) {
238  auto normalized = hashes[i] % bytes;
239  __sync_or_and_fetch(&(bytearray[normalized / CHAR_BIT]),
240  BIT_MASKS[normalized % CHAR_BIT]);
241  }
242 }
243 
244 inline bool
245 BloomFilter::contains(const std::vector<uint64_t>& hashes)
246 {
247  return contains(hashes.data());
248 }
249 
250 inline bool
251 BloomFilter::contains(const uint64_t* hashes)
252 {
253  for (unsigned i = 0; i < hash_num; ++i) {
254  auto normalized = hashes[i] % bytes;
255  auto mask = BIT_MASKS[normalized % CHAR_BIT];
256  if (!bool(bytearray[normalized / CHAR_BIT] & mask)) {
257  return false;
258  }
259  }
260  return true;
261 }
262 
263 inline uint64_t
264 BloomFilter::get_pop_cnt() const
265 {
266  uint64_t pop_cnt = 0;
267 #pragma omp parallel for reduction(+ : pop_cnt)
268  for (size_t i = 0; i < bytes; ++i) {
269  pop_cnt += pop_cnt_byte(bytearray[i]);
270  }
271  return pop_cnt;
272 }
273 
274 inline double
275 BloomFilter::get_fpr() const
276 {
277  return std::pow(double(get_pop_cnt()) / double(bytes), double(hash_num));
278 }
279 
280 inline void
281 BloomFilter::write(const std::string& path)
282 {
283  std::ofstream file(path.c_str(), std::ios::out | std::ios::binary);
284 
285  /* Initialize cpptoml root table
286  Note: Tables and fields are unordered
287  Ordering of table is maintained by directing the table
288  to the output stream immediately after completion */
289  auto root = cpptoml::make_table();
290 
291  /* Initialize bloom filter section and insert fields
292  and output to ostream */
293  auto header = cpptoml::make_table();
294  header->insert("bytes", bytes);
295  header->insert("hash_num", hash_num);
296  root->insert(BLOOM_FILTER_MAGIC_HEADER, header);
297  file << *root << "[HeaderEnd]\n";
298 
299  file.write((char*)bytearray.data(), bytes);
300 }
301 
302 template<typename T>
303 inline CountingBloomFilter<T>::CountingBloomFilter(size_t bytes,
304  unsigned hash_num)
305  : bytes(std::ceil(bytes / sizeof(uint64_t)) * sizeof(uint64_t))
306  , counters(bytes / sizeof(T))
307  , hash_num(hash_num)
308 {
309  bytearray.resize(bytes);
310 }
311 
312 template<typename T>
313 inline CountingBloomFilter<T>::CountingBloomFilter(const std::string& path)
314 {
315  std::ifstream file(path);
316 
317  std::string magic_with_brackets =
318  std::string("[") + COUNTING_BLOOM_FILTER_MAGIC_HEADER + "]";
319 
320  std::string line;
321  std::getline(file, line);
322  if (line != magic_with_brackets) {
323  log_error(
324  std::string("Magic string does not match (likely version mismatch)\n") +
325  "Your magic string:\t" + line + "\n" + "BloomFilter magic string:\t" +
326  magic_with_brackets);
327  std::exit(EXIT_FAILURE);
328  }
329 
330  /* Read bloom filter line by line until it sees "[HeaderEnd]"
331  which is used to mark the end of the header section and
332  assigns the header to a char array*/
333  std::string toml_buffer(line + '\n');
334  bool header_end_found = false;
335  while (std::getline(file, line)) {
336  toml_buffer.append(line + '\n');
337  if (line == "[HeaderEnd]") {
338  header_end_found = true;
339  break;
340  }
341  }
342  if (!header_end_found) {
343  log_error("Pre-built bloom filter does not have the correct header end.");
344  std::exit(EXIT_FAILURE);
345  }
346 
347  // Send the char array to a stringstream for the cpptoml parser to parse
348  std::istringstream toml_stream(toml_buffer);
349  cpptoml::parser toml_parser(toml_stream);
350  auto header_config = toml_parser.parse();
351 
352  // Obtain header values from toml parser and assign them to class members
353  auto table = header_config->get_table(COUNTING_BLOOM_FILTER_MAGIC_HEADER);
354  bytes = *table->get_as<size_t>("bytes");
355  hash_num = *table->get_as<unsigned>("hash_num");
356  counters = bytes / sizeof(T);
357  check_error(sizeof(T) * CHAR_BIT != *table->get_as<size_t>("counter_bits"),
358  "CountingBloomFilter" + std::to_string(sizeof(T) * CHAR_BIT) +
359  " tried to load a file of CountingBloomFilter" +
360  std::to_string(*table->get_as<size_t>("counter_bits")));
361 
362  bytearray.resize(bytes);
363  file.read((char*)bytearray.data(), bytes);
364 }
365 
366 template<typename T>
367 inline void
368 CountingBloomFilter<T>::insert(const std::vector<uint64_t>& hashes)
369 {
370  insert(hashes.data());
371 }
372 
373 template<typename T>
374 inline void
375 CountingBloomFilter<T>::insert(const uint64_t* hashes)
376 {
377  // Update flag to track if increment is done on at least one counter
378  bool update_done = false;
379  T new_val;
380  T min_val = contains(hashes);
381  while (!update_done) {
382  // Simple check to deal with overflow
383  new_val = min_val + 1;
384  if (min_val > new_val) {
385  return;
386  }
387  for (size_t i = 0; i < hash_num; ++i) {
388  if (__sync_bool_compare_and_swap( // NOLINT
389  &(((T*)(bytearray.data()))[hashes[i] % counters]),
390  min_val,
391  new_val)) { // NOLINT
392  update_done = true;
393  }
394  }
395  // Recalculate minval because if increment fails, it needs a new minval to
396  // use and if it doesnt hava a new one, the while loop runs forever.
397  if (!update_done) {
398  min_val = contains(hashes);
399  }
400  }
401 }
402 
403 template<typename T>
404 inline T
405 CountingBloomFilter<T>::contains(const std::vector<uint64_t>& hashes)
406 {
407  return contains(hashes.data());
408 }
409 
410 template<typename T>
411 inline T
412 CountingBloomFilter<T>::contains(const uint64_t* hashes)
413 {
414  T min = ((T*)(bytearray.data()))[hashes[0] % counters];
415  for (size_t i = 1; i < hash_num; ++i) {
416  size_t idx = hashes[i] % counters;
417  if (((T*)(bytearray.data()))[idx] < min) {
418  min = ((T*)(bytearray.data()))[idx];
419  }
420  }
421  return min;
422 }
423 
424 template<typename T>
425 inline uint64_t
426 CountingBloomFilter<T>::get_pop_cnt() const
427 {
428  uint64_t pop_cnt = 0;
429 #pragma omp parallel for reduction(+ : pop_cnt)
430  for (size_t i = 0; i < counters; ++i) {
431  if (((T*)(bytearray.data()))[i]) {
432  ++pop_cnt;
433  }
434  }
435  return pop_cnt;
436 }
437 
438 template<typename T>
439 inline double
440 CountingBloomFilter<T>::get_fpr() const
441 {
442  return std::pow(double(get_pop_cnt()) / double(bytes), double(hash_num));
443 }
444 
445 template<typename T>
446 inline void
447 CountingBloomFilter<T>::write(const std::string& path)
448 {
449  std::ofstream file(path.c_str(), std::ios::out | std::ios::binary);
450 
451  /* Initialize cpptoml root table
452  Note: Tables and fields are unordered
453  Ordering of table is maintained by directing the table
454  to the output stream immediately after completion */
455  auto root = cpptoml::make_table();
456 
457  /* Initialize bloom filter section and insert fields
458  and output to ostream */
459  auto header = cpptoml::make_table();
460  header->insert("bytes", bytes);
461  header->insert("hash_num", hash_num);
462  header->insert("counter_bits", size_t(sizeof(T) * CHAR_BIT));
463  root->insert(COUNTING_BLOOM_FILTER_MAGIC_HEADER, header);
464  file << *root << "[HeaderEnd]\n";
465 
466  file.write((char*)bytearray.data(), bytes);
467 }
468 
470  size_t bytes,
471  unsigned hash_num)
472  : k(k)
473  , bf(bytes, hash_num)
474 {}
475 
476 inline void
477 KmerBloomFilter::insert(const std::string& seq)
478 {
479  insert(seq.c_str(), seq.size());
480 }
481 
482 inline void
483 KmerBloomFilter::insert(const char* seq, size_t seq_len)
484 {
485  NtHash nthash(seq, seq_len, k, bf.get_hash_num());
486  while (nthash.roll()) {
487  bf.insert(nthash.hashes());
488  }
489 }
490 
491 inline unsigned
492 KmerBloomFilter::contains(const std::string& seq)
493 {
494  return contains(seq.c_str(), seq.size());
495 }
496 inline unsigned
497 KmerBloomFilter::contains(const char* seq, size_t seq_len)
498 {
499  unsigned count = 0;
500  NtHash nthash(seq, seq_len, k, bf.get_hash_num());
501  while (nthash.roll()) {
502  if (bf.contains(nthash.hashes())) {
503  count++;
504  }
505  }
506  return count;
507 }
508 
509 template<typename T>
511  size_t bytes,
512  unsigned hash_num)
513  : k(k)
514  , bf(bytes, hash_num)
515 {}
516 
517 template<typename T>
518 inline void
519 KmerCountingBloomFilter<T>::insert(const std::string& seq)
520 {
521  insert(seq.c_str(), seq.size());
522 }
523 
524 template<typename T>
525 inline void
526 KmerCountingBloomFilter<T>::insert(const char* seq, size_t seq_len)
527 {
528  NtHash nthash(seq, seq_len, k, bf.get_hash_num());
529  while (nthash.roll()) {
530  bf.insert(nthash.hashes());
531  }
532 }
533 
534 template<typename T>
535 inline uint64_t
536 KmerCountingBloomFilter<T>::contains(const std::string& seq)
537 {
538  return contains(seq.c_str(), seq.size());
539 }
540 
541 template<typename T>
542 inline uint64_t
543 KmerCountingBloomFilter<T>::contains(const char* seq, size_t seq_len)
544 {
545  uint64_t count = 0;
546  NtHash nthash(seq, seq_len, k, bf.get_hash_num());
547  while (nthash.roll()) {
548  count += bf.contains(nthash.hashes());
549  }
550  return count;
551 }
552 
553 } // namespace btllib
554 
555 #endif
btllib::KmerBloomFilter
Definition: bloom_filter.hpp:83
btllib::KmerBloomFilter::KmerBloomFilter
KmerBloomFilter(unsigned k, size_t bytes, unsigned hash_num=4)
Definition: bloom_filter.hpp:469
btllib::CountingBloomFilter::write
void write(const std::string &path)
Definition: bloom_filter.hpp:447
btllib::CountingBloomFilter
Definition: bloom_filter.hpp:47
btllib::KmerBloomFilter::insert
void insert(const std::string &seq)
Definition: bloom_filter.hpp:477
btllib::NtHash::roll
bool roll()
btllib::BloomFilter
Definition: bloom_filter.hpp:19
btllib::KmerCountingBloomFilter
Definition: bloom_filter.hpp:131
btllib::KmerBloomFilter::contains
unsigned contains(const std::string &seq)
Definition: bloom_filter.hpp:492
btllib::NtHash
Definition: nthash.hpp:977