HIBF 1.0.0-rc.1
All Classes Namespaces Files Functions Variables Typedefs Friends Macros Modules Pages Concepts
counting_vector.hpp
Go to the documentation of this file.
1// SPDX-FileCopyrightText: 2006-2025, Knut Reinert & Freie Universität Berlin
2// SPDX-FileCopyrightText: 2016-2025, Knut Reinert & MPI für molekulare Genetik
3// SPDX-License-Identifier: BSD-3-Clause
4
10#pragma once
11
12#include <algorithm> // for transform
13#include <bit> // for countr_zero
14#include <cassert> // for assert
15#include <climits> // for CHAR_BIT
16#include <concepts> // for integral
17#include <cstdint> // for uint64_t, uint8_t
18#include <cstring> // for size_t
19#include <functional> // for minus, plus
20#include <type_traits> // for conditional, conditional_t
21#include <vector> // for vector
22
23#include <hibf/contrib/aligned_allocator.hpp> // for aligned_allocator
24#include <hibf/misc/bit_vector.hpp> // for bit_vector
25#include <hibf/misc/divide_and_ceil.hpp> // for divide_and_ceil
26#include <hibf/misc/next_multiple_of_64.hpp> // for next_multiple_of_64
27#include <hibf/platform.hpp> // for HIBF_HAS_AVX512
28
29#if HIBF_HAS_AVX512
30# include <simde/x86/avx512/add.h> // for simde_mm512_add_epi16, simde_mm512_add_epi32, simde_mm512_add_...
31# include <simde/x86/avx512/load.h> // for simde_mm512_load_si512
32# include <simde/x86/avx512/mov.h> // for simde_mm512_maskz_mov_epi16, simde_mm512_maskz_mov_epi32, simd...
33# include <simde/x86/avx512/set1.h> // for simde_mm512_set1_epi16, simde_mm512_set1_epi32, simde_mm512_se...
34# include <simde/x86/avx512/store.h> // for simde_mm512_store_si512
35# include <simde/x86/avx512/sub.h> // for simde_mm512_sub_epi16, simde_mm512_sub_epi32, simde_mm512_sub_...
36# include <simde/x86/avx512/types.h> // for simde__m512i
37#endif
38
39namespace seqan::hibf
40{
41
42#if HIBF_HAS_AVX512
44// Since the counting_vector can have different value types, we need specific SIMD instructions for each value type.
45template <std::integral integral_t>
46struct simd_mapping
47{};
48
49// CRTP base class for the simd_mapping, containing common functionality.
// `derived_t` is the concrete mapping; `expand_bits` below calls its static members `mm512_maskz_mov_epi` and
// `mm512_set1_epi`. `integral_t` is the counter type whose size determines how many bits are handled per step.
50template <typename derived_t, std::integral integral_t>
52struct simd_mapping_crtp
53{
54 // Let `B = sizeof(integral_t) * CHAR_BIT`, e.g. 8 for (u)int8_t, and 16 for (u)int16_t.
55 // A 512-bit register holds `512 / B` counters, so `512 / B` bits of the bit_vector are processed at once.
56 static inline constexpr size_t bits_per_iterations = 512u / (sizeof(integral_t) * CHAR_BIT);
57 // clang-format off
58 // The type that is needed to represent `bits_per_iterations` bits.
 // NOTE(review): the last branch of this conditional chain is not visible in this listing. The static_assert
 // below suggests that unsupported sizes resolve to `void` -- confirm against the original header.
59 using bits_type = std::conditional_t<bits_per_iterations == 64, uint64_t,
60 std::conditional_t<bits_per_iterations == 32, uint32_t,
61 std::conditional_t<bits_per_iterations == 16, uint16_t,
63 // clang-format on
64 static_assert(!std::same_as<bits_type, void>, "Unsupported bits_type.");
65
66 // Reads `bits_per_iterations` bits of the bit_vector through `bits` and uses them as a mask: the result is
 // what `derived_t::mm512_maskz_mov_epi` returns for that mask and a vector with every element set to 1.
 // Presumably (zero-masking move semantics) each element is 1 where its bit is set and 0 otherwise.
67 // E.g., 8 bits per iteration (64-bit integral_t):
 // [1,0,1,1,0,0,1,0] -> [0...01, 0...00, 0...01, 0...01, 0...00, 0...00, 0...01, 0...00], where
68 // each element is 64 bits wide.
69 static inline constexpr auto expand_bits(bits_type const * const bits)
70 {
71 return derived_t::mm512_maskz_mov_epi(*bits, derived_t::mm512_set1_epi(1));
72 }
73};
74
75// SIMD instructions for int8_t and uint8_t.
76template <std::integral integral_t>
77 requires (sizeof(integral_t) == 1)
78struct simd_mapping<integral_t> : simd_mapping_crtp<simd_mapping<integral_t>, integral_t>
79{
80 static inline constexpr auto mm512_maskz_mov_epi = simde_mm512_maskz_mov_epi8;
81 static inline constexpr auto mm512_set1_epi = simde_mm512_set1_epi8;
82 static inline constexpr auto mm512_add_epi = simde_mm512_add_epi8;
83 static inline constexpr auto mm512_sub_epi = simde_mm512_sub_epi8;
84};
85
86// SIMD instructions for int16_t and uint16_t.
87template <std::integral integral_t>
88 requires (sizeof(integral_t) == 2)
89struct simd_mapping<integral_t> : simd_mapping_crtp<simd_mapping<integral_t>, integral_t>
90{
91 static inline constexpr auto mm512_maskz_mov_epi = simde_mm512_maskz_mov_epi16;
92 static inline constexpr auto mm512_set1_epi = simde_mm512_set1_epi16;
93 static inline constexpr auto mm512_add_epi = simde_mm512_add_epi16;
94 static inline constexpr auto mm512_sub_epi = simde_mm512_sub_epi16;
95};
96
97// SIMD instructions for int32_t and uint32_t.
98template <std::integral integral_t>
99 requires (sizeof(integral_t) == 4)
100struct simd_mapping<integral_t> : simd_mapping_crtp<simd_mapping<integral_t>, integral_t>
101{
102 static inline constexpr auto mm512_maskz_mov_epi = simde_mm512_maskz_mov_epi32;
103 static inline constexpr auto mm512_set1_epi = simde_mm512_set1_epi32;
104 static inline constexpr auto mm512_add_epi = simde_mm512_add_epi32;
105 static inline constexpr auto mm512_sub_epi = simde_mm512_sub_epi32;
106};
107
108// SIMD instructions for int64_t and uint64_t.
109template <std::integral integral_t>
110 requires (sizeof(integral_t) == 8)
111struct simd_mapping<integral_t> : simd_mapping_crtp<simd_mapping<integral_t>, integral_t>
112{
113 static inline constexpr auto mm512_maskz_mov_epi = simde_mm512_maskz_mov_epi64;
114 static inline constexpr auto mm512_set1_epi = simde_mm512_set1_epi64;
115 static inline constexpr auto mm512_add_epi = simde_mm512_add_epi64;
116 static inline constexpr auto mm512_sub_epi = simde_mm512_sub_epi64;
117};
119#endif
120
144template <std::integral value_t>
145class counting_vector : public std::vector<value_t, seqan::hibf::contrib::aligned_allocator<value_t, 64u>>
146{
147private:
150
151public:
155 counting_vector() = default;
156 counting_vector(counting_vector const &) = default;
160 ~counting_vector() = default;
161
162 using base_t::base_t;
164
173 {
174 impl<operation::add>(bit_vector);
175 return *this;
176 }
177
189 {
190 impl<operation::sub>(bit_vector);
191 return *this;
192 }
193
202 {
203 assert(this->size() >= rhs.size()); // The counting vector may be bigger than what we need.
204
205 std::transform(this->begin(), this->end(), rhs.begin(), this->begin(), std::plus<value_t>());
206
207 return *this;
208 }
209
215 {
216 assert(this->size() >= rhs.size()); // The counting vector may be bigger than what we need.
217
218 std::transform(this->begin(), this->end(), rhs.begin(), this->begin(), std::minus<value_t>());
219
220 return *this;
221 }
222
223private:
225 enum class operation : uint8_t
226 {
227 add,
228 sub
229 };
230
232 template <operation op>
233 inline void impl(bit_vector const & bit_vector)
234 {
235 assert(this->size() >= bit_vector.size()); // The counting vector may be bigger than what we need.
236#if HIBF_HAS_AVX512
237 // AVX512BW: mm512_maskz_mov_epi, mm512_add_epi
238 // AVX512F: mm512_set1_epi, _mm512_load_si512, _mm512_store_si512
239 using simd = simd_mapping<value_t>;
240 using bits_type = typename simd::bits_type;
241
242 bits_type const * bit_vector_ptr = reinterpret_cast<bits_type const *>(bit_vector.data());
243 value_t * counting_vector_ptr = base_t::data();
244
245 size_t const bits = next_multiple_of_64(bit_vector.size());
246 assert(bits <= this->capacity()); // Not enough memory reserved for AVX512 chunk access.
247 size_t const iterations = bits / simd::bits_per_iterations;
248
249 for (size_t iteration = 0; iteration < iterations;
250 ++iteration, ++bit_vector_ptr, counting_vector_ptr += simd::bits_per_iterations)
251 {
252 simde__m512i load = simde_mm512_load_si512(counting_vector_ptr);
253 if constexpr (op == operation::add)
254 {
255 load = simd::mm512_add_epi(load, simd::expand_bits(bit_vector_ptr));
256 }
257 else
258 {
259 load = simd::mm512_sub_epi(load, simd::expand_bits(bit_vector_ptr));
260 }
261 simde_mm512_store_si512(counting_vector_ptr, load);
262 }
263#else
264 size_t const words = divide_and_ceil(bit_vector.size(), 64u);
265 uint64_t const * const bit_vector_ptr = bit_vector.data();
266
267 // Jump to the next 1 and return the number of jumped bits in value.
268 auto jump_to_next_1bit = [](uint64_t & value)
269 {
270 auto const zeros = std::countr_zero(value);
271 value >>= zeros; // skip number of zeros
272 return zeros;
273 };
274
275 // Each iteration can handle 64 bits, i.e., one word.
276 for (size_t iteration = 0; iteration < words; ++iteration)
277 {
278 uint64_t current_word = bit_vector_ptr[iteration];
279
280 // For each set bit in the current word, add/subtract 1 to the corresponding bin.
281 for (size_t bin = iteration * 64u; current_word != 0u; ++bin, current_word >>= 1)
282 {
283 // Jump to the next 1
284 bin += jump_to_next_1bit(current_word);
285
286 if constexpr (op == operation::add)
287 {
288 ++(*this)[bin];
289 }
290 else
291 {
292 --(*this)[bin];
293 }
294 }
295 }
296#endif
297 }
298};
299
300} // namespace seqan::hibf
Provides seqan::hibf::bit_vector.
A bit vector.
Definition bit_vector.hpp:68
A data structure that behaves like a std::vector and can be used to consolidate the results of multip...
Definition counting_vector.hpp:146
counting_vector & operator+=(counting_vector const &rhs)
Bin-wise addition of two seqan::hibf::counting_vectors.
Definition counting_vector.hpp:201
counting_vector & operator+=(bit_vector const &bit_vector)
Bin-wise adds the bits of a seqan::hibf::bit_vector.
Definition counting_vector.hpp:172
counting_vector(counting_vector const &)=default
Defaulted.
counting_vector & operator-=(bit_vector const &bit_vector)
Bin-wise subtracts the bits of a seqan::hibf::bit_vector.
Definition counting_vector.hpp:188
counting_vector & operator-=(counting_vector const &rhs)
Bin-wise subtraction of two seqan::hibf::counting_vectors.
Definition counting_vector.hpp:214
~counting_vector()=default
Defaulted.
counting_vector & operator=(counting_vector &&)=default
Defaulted.
counting_vector(counting_vector &&)=default
Defaulted.
counting_vector & operator=(counting_vector const &)=default
Defaulted.
counting_vector()=default
Defaulted.
T countr_zero(T... args)
constexpr size_t next_multiple_of_64(size_t const value) noexcept
Returns the smallest integer that is greater than or equal to value and a multiple of 64....
Definition next_multiple_of_64.hpp:18
constexpr size_t divide_and_ceil(t1 const dividend, t2 const divisor) noexcept
Returns, for unsigned integral operands, dividend / divisor ceiled to the next integer value.
Definition divide_and_ceil.hpp:21
T is_base_of_v
Provides platform and dependency checks.
T transform(T... args)