29 namespace seqan3::detail
//!\brief Policy holding the state that is needed when the computation is limited by a maximal number of errors.
//!\details The member functions of this policy reach the algorithm object by casting `this` to `derived_t`.
//! NOTE(review): this listing omits brace-only lines and whatever follows the trailing `:` of the class head
//!               (presumably the base-class list) — confirm against the full file.
34 template <
typename derived_t,
typename edit_traits>
35 class edit_distance_unbanded_max_errors_policy :
// Compile-time guard: the policy may only be instantiated if edit_traits::use_max_errors is true.
41 static_assert(edit_traits::use_max_errors,
"This policy assumes that edit_traits::use_max_errors is true.");
// Special member functions, all declared noexcept. The tails of the copy/move declarations are not visible in
// this listing.
49 edit_distance_unbanded_max_errors_policy() noexcept = default;
50 edit_distance_unbanded_max_errors_policy(edit_distance_unbanded_max_errors_policy const &) noexcept
52 edit_distance_unbanded_max_errors_policy(edit_distance_unbanded_max_errors_policy &&) noexcept
54 edit_distance_unbanded_max_errors_policy & operator=(edit_distance_unbanded_max_errors_policy const &) noexcept
56 edit_distance_unbanded_max_errors_policy & operator=(edit_distance_unbanded_max_errors_policy &&) noexcept
58 ~edit_distance_unbanded_max_errors_policy() noexcept = default;
// Types taken over from the traits class.
61 using typename edit_traits::word_type;
62 using typename edit_traits::score_type;
//!\brief Error threshold; defaults to 255 and is overwritten by max_errors_init().
68 score_type max_errors{255};
//!\brief Block index used to address the derived object's vp/vn words for the last active cell.
71 size_t last_block{0u};
//!\brief Copy of the derived object's score_mask taken in max_errors_init(); compared against the current mask
//!       in is_last_active_cell_within_last_row().
73 word_type last_score_mask{};
//!\brief Initialises max_errors, last_block and last_score_mask and adjusts the derived object's score_mask and
//!       _score to the error threshold.
//!\param block_count Number of blocks; `block_count - 1` is the starting value of last_block for a non-empty query.
//! NOTE(review): the embedded numbering skips lines inside the empty-query branch (e.g. 91 and 94-96), so that
//!               branch presumably contains further statements such as an early return — confirm.
80 void max_errors_init(
size_t block_count) noexcept
83 derived_t *
self =
static_cast<derived_t *
>(
this);
// The threshold is the negated `score` of the align_cfg::min_score element stored in the configuration.
85 max_errors = -get<align_cfg::min_score>(self->config).score;
87 assert(max_errors >= score_type{0});
// Empty query: clear the score mask and remember that cleared mask as last_score_mask.
89 if (std::ranges::empty(self->query))
92 self->score_mask = 0u;
93 last_score_mask =
self->score_mask;
// Non-empty query: remember the mask as it was set before this call, together with the index of the final block.
97 last_block = block_count - 1u;
98 last_score_mask =
self->score_mask;
// local_max_errors = min(max_errors, |query| - 1). It selects the bit (modulo word_size) and the block
// (division by word_size) of the initial last active cell and seeds the score with local_max_errors + 1.
103 size_t const local_max_errors = std::min<size_t>(max_errors,
std::ranges::size(self->query) - 1u);
104 self->score_mask = word_type{1u} << (local_max_errors %
self->word_size);
105 last_block =
std::min(local_max_errors / self->word_size, last_block);
106 self->_score = local_max_errors + 1u;
110 bool is_last_active_cell_within_last_row() const noexcept
112 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
113 return (self->score_mask == this->last_score_mask) && (this->last_block ==
self->vp.size() - 1u);
//!\brief Moves the last active cell one position back by shifting the derived object's score_mask right by one.
//!\returns A bool; the return statements themselves are not visible in this listing.
//! NOTE(review): gaps in the embedded numbering (122, 127-131, 133+) hide the returns and presumably the
//!               adjustment of last_block — confirm against the full file.
117 bool prev_last_active_cell() noexcept
119 derived_t *
self =
static_cast<derived_t *
>(
this);
120 self->score_mask >>= 1u;
// A non-zero mask means the bit is still inside the current word.
121 if (self->score_mask != 0u)
// Only compiled for global alignments: special handling when already at block 0.
124 if constexpr (edit_traits::is_global)
126 if (last_block == 0u)
// The bit was shifted out of the word: restart at the highest bit of a word.
132 self->score_mask = word_type{1u} << (edit_traits::word_size - 1u);
//!\brief Moves the last active cell one position forward by shifting the derived object's score_mask left by one.
//! NOTE(review): lines 142-143 and 145+ are not visible; presumably they return early while the mask is non-zero
//!               and advance last_block after the wrap-around — confirm.
137 void next_last_active_cell() noexcept
139 derived_t *
self =
static_cast<derived_t *
>(
this);
140 self->score_mask <<= 1u;
// A non-zero mask means the bit is still inside the current word.
141 if (self->score_mask)
// The bit was shifted out of the word: restart at the lowest bit.
144 self->score_mask = 1u;
//!\brief Re-adjusts the last active cell after a column was computed and reports whether the computation can stop.
//!\returns In the visible paths: `!edit_traits::compute_matrix` if no previous active cell exists, or the result
//!         of `self->on_hit()` if the last active cell lies within the last row.
//! NOTE(review): gaps in the embedded numbering (176-178, 181+) suggest an `else` and a final return that are
//!               not shown here — confirm against the full file.
151 bool update_last_active_cell() noexcept
153 derived_t *
self =
static_cast<derived_t *
>(
this);
// While the score exceeds max_errors, adjust the score and step the last active cell back. Note that vn/vp are
// passed in the opposite order compared to the forward step at the end of this function.
155 while (!(self->_score <= max_errors))
157 self->advance_score(self->vn[last_block], self->vp[last_block], self->score_mask);
158 if (!prev_last_active_cell())
// Per the assertion, running out of previous cells is only expected for global alignments.
161 assert(edit_traits::is_global);
164 return !edit_traits::compute_matrix;
// Last active cell is in the last row: the score is within the threshold, so record and report it.
168 if (is_last_active_cell_within_last_row())
170 assert(self->_score <= max_errors);
172 if constexpr(edit_traits::is_semi_global)
173 self->update_best_score();
175 return self->on_hit();
// Step the last active cell forward and adjust the score for that step.
179 next_last_active_cell();
180 self->advance_score(self->vp[last_block], self->vn[last_block], self->score_mask);
//!\brief Forwards to `score_matrix_type::max_rows` of the traits' score matrix type.
//!\param score_mask First argument passed on to `score_matrix_type::max_rows`.
//!\param last_block Block index of the last active cell.
//!\param score      Current score.
//!\param max_errors Error threshold.
//! NOTE(review): `last_block` is declared `unsigned` here, while the policy member of the same name is `size_t`.
//! NOTE(review): the remaining call arguments are cut off in this listing.
188 static size_t max_rows(word_type
const score_mask,
unsigned const last_block,
189 score_type
const score, score_type
const max_errors) noexcept
191 using score_matrix_type =
typename edit_traits::score_matrix_type;
192 return score_matrix_type::max_rows(score_mask,
//!\brief Policy holding the best score and the result accessors for global and semi-global computations.
//!\details The member functions of this policy reach the algorithm object by casting `this` to `derived_t`.
//! NOTE(review): this listing omits brace-only lines and whatever follows the trailing `:` of the class head.
202 template <
typename derived_t,
typename edit_traits>
203 class edit_distance_unbanded_global_policy :
// Compile-time guard: usable for global as well as semi-global traits.
209 static_assert(edit_traits::is_global || edit_traits::is_semi_global,
210 "This policy assumes that edit_traits::is_global or edit_traits::is_semi_global is true.");
// Special member functions, all declared noexcept. The tails of the copy/move declarations are not visible in
// this listing.
218 edit_distance_unbanded_global_policy() noexcept = default;
219 edit_distance_unbanded_global_policy(edit_distance_unbanded_global_policy const &) noexcept
221 edit_distance_unbanded_global_policy(edit_distance_unbanded_global_policy &&) noexcept
223 edit_distance_unbanded_global_policy & operator=(edit_distance_unbanded_global_policy const &) noexcept
225 edit_distance_unbanded_global_policy & operator=(edit_distance_unbanded_global_policy &&) noexcept
227 ~edit_distance_unbanded_global_policy() noexcept = default;
231 using typename edit_traits::score_type;
//!\brief Best score recorded so far; written by score_init() and update_best_score().
241 score_type _best_score{};
//!\brief Initialises the best score with the derived object's current `_score`.
//! NOTE(review): the embedded numbering skips two lines after the header (249-250); only one of them can be the
//!               opening brace — confirm nothing else is hidden there.
248 void score_init() noexcept
251 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
252 _best_score =
self->_score;
//!\brief Tells whether the computed result is valid.
//!\returns With edit_traits::use_max_errors: `true` iff `_best_score` does not exceed the derived object's
//!         `max_errors`.
//! NOTE(review): the return for the case without max errors is not visible in this listing (lines after 264).
256 bool is_valid() const noexcept
// `self` is only read inside the `if constexpr` branch, hence [[maybe_unused]].
258 [[maybe_unused]] derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
263 if constexpr(edit_traits::use_max_errors)
264 return _best_score <=
self->max_errors;
//!\brief Produces the alignment_coordinate that end_positions() and begin_positions() return for invalid results.
//! NOTE(review): only the cast to the derived type is visible here; the statement that builds the returned
//!               coordinate is missing from this listing.
272 alignment_coordinate invalid_coordinate() const noexcept
274 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
279 void update_best_score() noexcept
281 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
282 _best_score =
self->_score;
//!\brief Supplies the value that end_positions() uses as the first (column) component of the end coordinate.
//! NOTE(review): only the cast to the derived type is visible here; the return statement is missing from this
//!               listing.
286 size_t end_positions_first() const noexcept
288 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
// Body fragment of the score accessor (the static_assert message names it `score()`).
// NOTE(review): the function header and the return statements are not visible in this listing.
302 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
// Only available if the traits request the score.
303 static_assert(edit_traits::compute_score,
"score() can only be computed if you specify the result type within "
304 "your alignment config.");
// Invalid results are handled separately (handling not visible here).
305 if (!self->is_valid())
//!\brief Returns the end coordinate of the alignment.
//!\returns `self->invalid_coordinate()` if the result is not valid, otherwise the coordinate `{first, second}`,
//!         where `first` comes from `self->end_positions_first()`.
//! NOTE(review): the definition of `second` (line 322) is not visible in this listing.
313 alignment_coordinate end_positions() const noexcept
315 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
// Only available if the traits request end positions.
316 static_assert(edit_traits::compute_end_positions,
"end_positions() can only be computed if you specify the "
317 "result type within your alignment config.");
318 if (!self->is_valid())
319 return self->invalid_coordinate();
// Calling through `self` dispatches to the derived object's end_positions_first().
321 column_index_type
const first{
self->end_positions_first()};
323 return {first, second};
//!\brief Policy for semi-global computations; extends edit_distance_unbanded_global_policy by remembering the
//!       database position at which the best score was observed.
//! NOTE(review): this listing omits brace-only lines; the tails of the copy/move declarations are not visible.
329 template <
typename derived_t,
typename edit_traits>
330 class edit_distance_unbanded_semi_global_policy :
331 public edit_distance_unbanded_global_policy<derived_t, edit_traits>
// Compile-time guard: the policy may only be instantiated if edit_traits::is_semi_global is true.
334 static_assert(edit_traits::is_semi_global,
"This policy assumes that edit_traits::is_semi_global is true.");
// Special member functions, all declared noexcept.
342 edit_distance_unbanded_semi_global_policy() noexcept = default;
343 edit_distance_unbanded_semi_global_policy(edit_distance_unbanded_semi_global_policy const &) noexcept
345 edit_distance_unbanded_semi_global_policy(edit_distance_unbanded_semi_global_policy &&) noexcept
347 edit_distance_unbanded_semi_global_policy & operator=(edit_distance_unbanded_semi_global_policy const &) noexcept
349 edit_distance_unbanded_semi_global_policy & operator=(edit_distance_unbanded_semi_global_policy &&) noexcept
351 ~edit_distance_unbanded_semi_global_policy() noexcept = default;
// Shorthand for the base policy and the names taken over from it / from the traits.
355 using base_t = edit_distance_unbanded_global_policy<derived_t, edit_traits>;
357 using database_iterator = typename edit_traits::database_iterator;
358 using base_t::_best_score;
//!\brief Database iterator belonging to the best score; written by score_init() and update_best_score().
372 database_iterator _best_score_col{};
//!\brief Initialises the best score via the base policy and sets the best column to the derived object's
//!       `database_it_end`.
//! NOTE(review): the embedded numbering skips two lines after the header (380-381); only one of them can be the
//!               opening brace — confirm nothing else is hidden there.
379 void score_init() noexcept
382 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
383 base_t::score_init();
384 _best_score_col =
self->database_it_end;
//!\brief Records the derived object's current score and database position if the score is not worse than the
//!       best score seen so far.
388 void update_best_score() noexcept
390 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
// With max errors, this function expects to be called only for an empty query or while the last active cell is
// within the last row.
392 if constexpr(edit_traits::use_max_errors)
394 assert(std::ranges::empty(self->query) || self->is_last_active_cell_within_last_row());
// Both updates use `<=`, so on equal scores the later database position replaces the earlier one. The column
// must be updated first: the second statement overwrites _best_score, which the first comparison still needs.
397 _best_score_col = (
self->_score <= _best_score) ? self->database_it : _best_score_col;
398 _best_score = (self->_score <= _best_score) ?
self->_score : _best_score;
//!\brief Computes the first (column) component of the end coordinate from the stored best column.
//!\returns Distance from the begin of the derived object's database to `_best_score_col`, plus one unless the
//!         database is empty.
402 size_t end_positions_first() const noexcept
404 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
// The offset is 0 for an empty database and 1 otherwise.
406 size_t offset = std::ranges::empty(self->database) ? 0u : 1u;
407 return std::ranges::distance(std::ranges::begin(self->database), _best_score_col) + offset;
//!\brief Policy that stores the score matrix when the traits request it.
//! NOTE(review): this listing omits brace-only lines and whatever follows the trailing `:` of the class head;
//!               the tails of the copy/move declarations are not visible.
415 template <
typename derived_t,
typename edit_traits>
416 class edit_distance_unbanded_score_matrix_policy :
// Compile-time guard: the policy may only be instantiated if edit_traits::compute_score_matrix is true.
422 static_assert(edit_traits::compute_score_matrix,
423 "This policy assumes that edit_traits::compute_score_matrix is true.");
// Special member functions, all declared noexcept.
431 edit_distance_unbanded_score_matrix_policy() noexcept = default;
432 edit_distance_unbanded_score_matrix_policy(edit_distance_unbanded_score_matrix_policy const &) noexcept
434 edit_distance_unbanded_score_matrix_policy(edit_distance_unbanded_score_matrix_policy &&) noexcept
436 edit_distance_unbanded_score_matrix_policy & operator=(edit_distance_unbanded_score_matrix_policy const &) noexcept
438 edit_distance_unbanded_score_matrix_policy & operator=(edit_distance_unbanded_score_matrix_policy &&) noexcept
440 ~edit_distance_unbanded_score_matrix_policy() noexcept = default;
443 using typename edit_traits::score_matrix_type;
//!\brief The stored score matrix; exposed read-only through score_matrix().
449 score_matrix_type _score_matrix{};
//!\brief Initialisation hook for the score matrix; called from the algorithm's constructor when
//!       compute_score_matrix is enabled.
//! NOTE(review): only the cast to the derived type is visible here; the statements that set up `_score_matrix`
//!               are missing from this listing. Unlike most members of this file it is not declared noexcept.
465 void score_matrix_init()
467 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
482 score_matrix_type
const & score_matrix() const noexcept
484 static_assert(edit_traits::compute_score_matrix,
"score_matrix() can only be computed if you specify the "
485 "result type within your alignment config.");
486 return _score_matrix;
//!\brief Policy that stores the trace matrix and offers the accessors that depend on it (trace_matrix(),
//!       begin_positions(), alignment()).
//! NOTE(review): this listing omits brace-only lines and whatever follows the trailing `:` of the class head;
//!               the tails of the copy/move declarations are not visible. A member `db` is used by
//!               trace_matrix_init() but its declaration is not shown here.
494 template <
typename derived_t,
typename edit_traits>
495 class edit_distance_unbanded_trace_matrix_policy :
// Compile-time guard: the policy may only be instantiated if edit_traits::compute_trace_matrix is true.
501 static_assert(edit_traits::compute_trace_matrix,
502 "This policy assumes that edit_traits::compute_trace_matrix is true.");
// Special member functions, all declared noexcept.
510 edit_distance_unbanded_trace_matrix_policy() noexcept = default;
511 edit_distance_unbanded_trace_matrix_policy(edit_distance_unbanded_trace_matrix_policy const &) noexcept
513 edit_distance_unbanded_trace_matrix_policy(edit_distance_unbanded_trace_matrix_policy &&) noexcept
515 edit_distance_unbanded_trace_matrix_policy & operator=(edit_distance_unbanded_trace_matrix_policy const &) noexcept
517 edit_distance_unbanded_trace_matrix_policy & operator=(edit_distance_unbanded_trace_matrix_policy &&) noexcept
519 ~edit_distance_unbanded_trace_matrix_policy() noexcept = default;
// Types taken over from the traits class.
522 using typename edit_traits::word_type;
523 using typename edit_traits::trace_matrix_type;
524 using typename edit_traits::alignment_result_type;
//!\brief One word per block; sized in trace_matrix_init() and handed to `_trace_matrix.add_column` by the
//!       algorithm.
530 std::vector<word_type> hp{};
//!\brief The stored trace matrix; exposed read-only through trace_matrix().
536 trace_matrix_type _trace_matrix{};
//!\brief Prepares the trace-matrix related state for the given number of blocks.
//!\param block_count Number of words that `hp` and `db` are resized to.
//! NOTE(review): lines 554-557 are not visible; they presumably set up `_trace_matrix` using `self` — confirm.
551 void trace_matrix_init(
size_t block_count)
553 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
// New elements are zero-initialised.
558 hp.resize(block_count, 0u);
559 db.resize(block_count, 0u);
//!\brief Read-only access to the trace matrix stored by this policy.
//!\returns A const reference to `_trace_matrix`.
568 trace_matrix_type
const & trace_matrix() const noexcept
// Call-site specific diagnostic in case the traits do not request the trace matrix.
571 static_assert(edit_traits::compute_trace_matrix,
"trace_matrix() can only be computed if you specify the "
572 "result type within your alignment config.");
573 return _trace_matrix;
//!\brief Returns the begin coordinate of the alignment.
//!\returns `self->invalid_coordinate()` if the result is not valid, otherwise the result of
//!         `alignment_begin_positions` applied to the trace matrix and the end coordinate.
578 alignment_coordinate begin_positions() const noexcept
580 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
// Only available if the traits request begin positions.
581 static_assert(edit_traits::compute_begin_positions,
"begin_positions() can only be computed if you specify the "
582 "result type within your alignment config.");
583 if (!self->is_valid())
584 return self->invalid_coordinate();
// The search for the begin starts at the end coordinate of the alignment.
586 alignment_coordinate
const back =
self->end_positions();
587 return alignment_begin_positions(trace_matrix(),
back);
//!\brief Returns the alignment obtained by tracing back through the trace matrix.
//!\returns A value-initialised `alignment_t` if the result is not valid, otherwise the result of
//!         `alignment_trace<alignment_t>`.
//! NOTE(review): the definition of `alignment_t` and some arguments of the alignment_trace call are not visible
//!               in this listing.
//! NOTE(review): declared noexcept although it builds an alignment — confirm that alignment_trace cannot throw.
592 auto alignment() const noexcept
596 derived_t
const *
self =
static_cast<derived_t
const *
>(
this);
// Only available if the traits request the sequence alignment.
597 static_assert(edit_traits::compute_sequence_alignment,
"alignment() can only be computed if you specify the "
598 "result type within your alignment config.");
600 if (!self->is_valid())
601 return alignment_t{};
603 return alignment_trace<alignment_t>(self->database,
606 self->end_positions(),
//!\brief Small wrapper around a pointer to a `value_t` that can be default-constructed, re-seated by copy/move
//!       assignment and written through by assigning a value.
//! NOTE(review): brace-only lines and the return statements of the member functions are not visible in this
//!               listing.
615 template <
typename value_t>
616 class proxy_reference
// Defaulted special members. Copying or moving copies the stored pointer, i.e. the proxy then refers to the
// same object as its source.
622 proxy_reference() noexcept = default;
623 proxy_reference(proxy_reference const &) noexcept = default;
624 proxy_reference(proxy_reference &&) noexcept = default;
625 proxy_reference & operator=(proxy_reference const &) noexcept = default;
626 proxy_reference & operator=(proxy_reference &&) noexcept = default;
627 ~proxy_reference() = default;
//!\brief Binds the proxy to the lvalue `t` by storing its address.
630 proxy_reference(value_t & t) noexcept
631 : ptr(
std::addressof(t))
// Binding to temporaries is disabled because the stored pointer would dangle.
634 proxy_reference(value_t &&) =
delete;
//!\brief Assigns `u` to the referenced object (write-through); does not re-seat the proxy.
// NOTE(review): because of the conversion operator below, a non-const lvalue proxy_reference also satisfies
// the constraint, so this template may be chosen over the copy assignment for such arguments — confirm that
// this is intended.
637 template <
typename other_value_t>
639 requires std::convertible_to<other_value_t, value_t>
641 proxy_reference & operator=(other_value_t && u) noexcept
643 get() = std::forward<other_value_t>(u);
//!\brief Access to the referenced object; asserts that the proxy is bound.
649 value_t &
get() const noexcept
651 assert(ptr !=
nullptr);
//!\brief Implicit conversion to a reference of the wrapped type.
656 operator value_t & ()
const noexcept
//!\brief Address of the referenced object; `nullptr` while the proxy is unbound.
663 value_t * ptr{
nullptr};
//!\brief Edit distance algorithm class; its optional features are added through policy base classes that are
//!       selected by boolean flags of `edit_traits`.
//!\tparam database_t     Type of the database sequence; must model std::ranges::viewable_range.
//!\tparam query_t        Type of the query sequence; must model std::ranges::viewable_range.
//!\tparam align_config_t Type of the alignment configuration.
//!\tparam edit_traits    Traits class providing the types and compile-time flags used below.
//! NOTE(review): brace-only lines, access specifiers and several member declarations (e.g. `database`, `query`,
//!               `_score`, `vp`, `vn`, `bit_masks`) are not visible in this listing.
677 template <std::ranges::viewable_range database_t,
678 std::ranges::viewable_range query_t,
679 typename align_config_t,
680 typename edit_traits>
681 class edit_distance_unbanded :
// Each base below pairs a traits flag with a policy template and this class as the derived type. One template
// argument per base (lines 687, 692, 697, 702, 707) is not visible in this listing.
684 public edit_distance_base<
685 edit_traits::use_max_errors,
686 edit_distance_unbanded_max_errors_policy,
688 edit_distance_unbanded<database_t, query_t, align_config_t, edit_traits>>,
689 public edit_distance_base<
690 edit_traits::is_global,
691 edit_distance_unbanded_global_policy,
693 edit_distance_unbanded<database_t, query_t, align_config_t, edit_traits>>,
694 public edit_distance_base<
695 edit_traits::is_semi_global,
696 edit_distance_unbanded_semi_global_policy,
698 edit_distance_unbanded<database_t, query_t, align_config_t, edit_traits>>,
699 public edit_distance_base<
700 edit_traits::compute_score_matrix,
701 edit_distance_unbanded_score_matrix_policy,
703 edit_distance_unbanded<database_t, query_t, align_config_t, edit_traits>>,
704 public edit_distance_base<
705 edit_traits::compute_trace_matrix,
706 edit_distance_unbanded_trace_matrix_policy,
708 edit_distance_unbanded<database_t, query_t, align_config_t, edit_traits>>
// Types and constants taken over from the traits class.
712 using typename edit_traits::word_type;
713 using typename edit_traits::score_type;
714 using typename edit_traits::database_type;
715 using typename edit_traits::query_type;
716 using typename edit_traits::align_config_type;
717 using edit_traits::word_size;
// Befriend all specialisations of the policy templates; the policies access members of this class through
// their `derived_t` pointer.
721 template <
typename other_derived_t,
typename other_edit_traits>
722 friend class edit_distance_unbanded_max_errors_policy;
724 template <
typename other_derived_t,
typename other_edit_traits>
725 friend class edit_distance_unbanded_global_policy;
727 template <
typename other_derived_t,
typename other_edit_traits>
728 friend class edit_distance_unbanded_semi_global_policy;
730 template <
typename other_derived_t,
typename other_edit_traits>
731 friend class edit_distance_unbanded_score_matrix_policy;
733 template <
typename other_derived_t,
typename other_edit_traits>
734 friend class edit_distance_unbanded_trace_matrix_policy;
// Further types and compile-time flags taken over from the traits class.
736 using typename edit_traits::database_iterator;
737 using typename edit_traits::query_alphabet_type;
738 using typename edit_traits::alignment_result_type;
739 using edit_traits::use_max_errors;
740 using edit_traits::is_semi_global;
741 using edit_traits::is_global;
742 using edit_traits::compute_score;
743 using edit_traits::compute_end_positions;
744 using edit_traits::compute_begin_positions;
745 using edit_traits::compute_sequence_alignment;
746 using edit_traits::compute_score_matrix;
747 using edit_traits::compute_trace_matrix;
748 using edit_traits::compute_matrix;
749 using typename edit_traits::score_matrix_type;
750 using typename edit_traits::trace_matrix_type;
//!\brief The alignment configuration handed to the constructor.
757 align_config_t config;
//!\brief Initial value of compute_state::carry_hp: 1 for global computations, 0 otherwise.
760 static constexpr word_type hp0 = is_global ? 1u : 0u;
//!\brief Constant 0. NOTE(review): not referenced in the visible part of this listing.
762 static constexpr word_type hn0 = 0u;
//!\brief All bits set; initial value of every `vp` word.
764 static constexpr word_type vp0 = ~word_type{0u};
//!\brief Constant 0; initial value of every `vn` word.
766 static constexpr word_type vn0 = 0u;
//!\brief Single-bit mask that selects the cell whose score is tracked within a block word.
776 word_type score_mask{0u};
//!\brief Current position in the database; set to the begin of the database by the constructor.
790 database_iterator database_it{};
//!\brief Position at which the computation stops; set to the end of the database by the constructor.
792 database_iterator database_it_end{};
//!\brief Additional per-column state that is only part of compute_state when the trace matrix is computed.
795 struct compute_state_trace_matrix
//!\brief Proxy to the current block's `db` word; bound and written in compute_kernel().
798 proxy_reference<word_type> db{};
//!\brief Working state for the computation of one column; the base adds compute_state_trace_matrix depending
//!       on compute_trace_matrix.
//! NOTE(review): further members used by compute_step() (`b`, `d0`, `hp`, `hn`) are not visible in this listing.
802 struct compute_state : enable_state_t<compute_trace_matrix, compute_state_trace_matrix>
//!\brief Proxy to the current block's `vp` word.
816 proxy_reference<word_type> vp{};
//!\brief Proxy to the current block's `vn` word.
818 proxy_reference<word_type> vn{};
//!\brief Carry of the addition in compute_step(), handed from one block to the next.
820 word_type carry_d0{};
//!\brief Top bit of the previous block's `hp`; starts with hp0.
822 word_type carry_hp{hp0};
//!\brief Top bit of the previous block's `hn`.
824 word_type carry_hn{};
// Body fragment that appends the current column to the score and/or trace matrix.
// NOTE(review): the header of the enclosing function is not visible in this listing.
// Without max errors: append the complete column.
830 if constexpr(!use_max_errors && compute_score_matrix)
831 this->_score_matrix.add_column(vp, vn);
833 if constexpr(!use_max_errors && compute_trace_matrix)
834 this->_trace_matrix.add_column(this->hp, this->db, vp);
// With max errors: additionally pass the row limit obtained from the max-errors policy.
836 if constexpr(use_max_errors && compute_matrix)
838 size_t max_rows = this->max_rows(score_mask, this->last_block, _score, this->max_errors);
839 if constexpr(compute_score_matrix)
840 this->_score_matrix.add_column(vp, vn, max_rows);
842 if constexpr(compute_trace_matrix)
843 this->_trace_matrix.add_column(this->hp, this->db, vp, max_rows);
// Special member functions: no default construction; copy/move construction and assignment as well as the
// destructor are defaulted.
851 edit_distance_unbanded() =
delete;
853 edit_distance_unbanded(edit_distance_unbanded
const &) =
default;
854 edit_distance_unbanded(edit_distance_unbanded &&) =
default;
855 edit_distance_unbanded & operator=(edit_distance_unbanded
const &) =
default;
856 edit_distance_unbanded & operator=(edit_distance_unbanded &&) =
default;
857 ~edit_distance_unbanded() =
default;
//!\brief Stores the sequences and the configuration and sets up the bit vectors for the computation.
//!\param _database The database sequence.
//!\param _config   The alignment configuration.
//!\param _traits   Only named for the documentation (wrapped in SEQAN3_DOXYGEN_ONLY).
//! NOTE(review): the declaration of the `_query` parameter (line 866), one member initialiser (line 872) and the
//!               header of the loop that fills `bit_masks` are not visible in this listing.
865 edit_distance_unbanded(database_t _database,
867 align_config_t _config,
868 edit_traits
const & SEQAN3_DOXYGEN_ONLY(_traits)) :
869 database{std::forward<database_t>(_database)},
870 query{std::forward<query_t>(_query)},
871 config{std::forward<align_config_t>(_config)},
873 database_it{ranges::begin(database)},
874 database_it_end{ranges::end(database)}
876 static constexpr
size_t alphabet_size_ = alphabet_size<query_alphabet_type>;
// Number of words needed for the query: |query| / word_size, rounded up.
878 size_t const block_count = (
std::ranges::size(query) - 1u + word_size) / word_size;
// For a non-empty query this selects the bit of the last query position within its word.
879 score_mask = word_type{1u} << ((
std::ranges::size(query) - 1u + word_size) % word_size);
// Let the enabled policies initialise their state.
882 if constexpr(use_max_errors)
883 this->max_errors_init(block_count);
885 if constexpr(compute_score_matrix)
886 this->score_matrix_init();
888 if constexpr(compute_trace_matrix)
889 this->trace_matrix_init(block_count);
// One vp/vn word per block, and block_count mask words for each of (alphabet_size_ + 1) ranks.
891 vp.resize(block_count, vp0);
892 vn.resize(block_count, vn0);
893 bit_masks.resize((alphabet_size_ + 1u) * block_count, 0u);
// For query position j: set bit (j % word_size) in the mask word of the rank of query[j] and block j / word_size.
898 size_t const i = block_count *
seqan3::to_rank(query[j]) + j / word_size;
899 bit_masks[i] |= word_type{1u} << (j % word_size);
//!\brief Computes one bit-parallel step for a single block word and updates `state` in place.
//!\tparam with_carry If `true`, the carries for the next block are computed and stored in `state`.
//!\param state The working state; `state.vp` and `state.vn` are proxies, so the referenced words are overwritten.
//! NOTE(review): the declarations of the temporaries `x` and `t` (line 911) are not visible in this listing.
//! NOTE(review): looks like a Myers-style bit-vector recurrence — confirm against the cited literature.
908 template <
bool with_carry>
909 static void compute_step(compute_state & state) noexcept
// Carries are single bits.
912 assert(state.carry_d0 <= 1u);
913 assert(state.carry_hp <= 1u);
914 assert(state.carry_hn <= 1u);
916 x = state.b | state.vn;
917 t = state.vp + (x & state.vp) + state.carry_d0;
// d0, hn and hp are derived from the old vp/vn, so they must be computed before vp/vn are overwritten below.
919 state.d0 = (t ^ state.vp) | x;
920 state.hn = state.vp & state.d0;
921 state.hp = state.vn | ~(state.vp | state.d0);
// Carry-out of the addition above: with a carry-in the sum wrapped around iff t <= vp, without iff t < vp.
923 if constexpr(with_carry)
924 state.carry_d0 = (state.carry_d0 != 0u) ? t <= state.vp : t < state.vp;
// hp/hn are shifted by one bit; the bit shifted in comes from the previous block's carries.
926 x = (state.hp << 1u) | state.carry_hp;
927 state.vn = x & state.d0;
928 state.vp = (state.hn << 1u) | ~(x | state.d0) | state.carry_hn;
// The top bits of hp/hn become the carries for the next block.
930 if constexpr(with_carry)
932 state.carry_hp = state.hp >> (word_size - 1u);
933 state.carry_hn = state.hn >> (word_size - 1u);
//!\brief Binds `state` to the words of one block, loads the block's bit mask and runs compute_step() on it.
//!\tparam with_carry Forwarded to compute_step().
//!\param state         The working state of the current column.
//!\param block_offset  Offset into `bit_masks` for the current database character.
//!\param current_block Index of the block to process.
938 template <
bool with_carry>
939 void compute_kernel(compute_state & state,
size_t const block_offset,
size_t const current_block) noexcept
// Assigning a freshly constructed proxy_reference points the state at this block's words.
941 state.vp = proxy_reference<word_type>{this->vp[current_block]};
942 state.vn = proxy_reference<word_type>{this->vn[current_block]};
943 if constexpr(compute_trace_matrix)
945 state.hp = proxy_reference<word_type>{this->hp[current_block]};
946 state.db = proxy_reference<word_type>{this->db[current_block]};
948 state.b = bit_masks[block_offset + current_block];
950 compute_step<with_carry>(state);
// Assigning a plain word writes through the proxy into this block's `db` word.
951 if constexpr(compute_trace_matrix)
952 state.db = ~(state.b ^ state.d0);
//!\brief Adjusts the score depending on whether `mask` selects a set bit in `P` or, otherwise, in `N`.
//!\param P    Word that is tested first.
//!\param N    Word that is tested only if `P` has no bit in common with `mask`.
//!\param mask Bit mask of the tracked cell.
//! NOTE(review): the statements of both branches are not visible in this listing; presumably they increment and
//!               decrement `_score` respectively — confirm.
956 void advance_score(word_type P, word_type N, word_type mask) noexcept
958 if ((P & mask) != word_type{0u})
960 else if ((N & mask) != word_type{0u})
//!\brief Hook invoked when a score within the error threshold was found; its result is used by the callers as
//!       the "abort computation" flag.
//! NOTE(review): the body is not visible in this listing.
965 bool on_hit() noexcept
//!\brief Computation for a query that fits into a single block; defined outside of the class body.
972 inline bool small_patterns();
//!\brief Computation for a query that spans several blocks; defined outside of the class body.
975 inline bool large_patterns();
//!\brief Handles the special case of an empty query by walking over the database without bit-vector work.
//! NOTE(review): several lines of the loop body are not visible in this listing (e.g. what happens once
//!               `abort_computation` is set).
978 inline void compute_empty_query_sequence()
980 assert(std::ranges::empty(query));
982 bool abort_computation =
false;
984 for (; database_it != database_it_end; ++database_it)
986 if constexpr(is_global)
989 this->update_best_score();
// With max errors, on_hit() decides whether the computation can stop.
992 if constexpr(use_max_errors)
993 abort_computation = on_hit();
996 if (abort_computation)
// Body fragment that selects the computation variant by the number of blocks.
// NOTE(review): the header of the enclosing function, the body of the first `if constexpr` and the remaining
// branches of the if/else chain are not visible in this listing.
1005 if constexpr(use_max_errors && is_global && !compute_matrix)
// No blocks at all means that the query is empty.
1025 if (vp.size() == 0u)
1026 compute_empty_query_sequence();
1027 else if (vp.size() == 1u)
// For global computations the best score is taken once at the end.
1032 if constexpr(is_global)
1033 this->update_best_score();
//!\brief Fills an alignment result with the outputs requested by the configuration and passes it to `callback`.
//!\tparam callback_t Type of the callback.
//!\param idx      Index that is stored as sequence1_id and/or sequence2_id if requested.
//!\param callback Invoked once with the constructed `alignment_result_type`.
//! NOTE(review): lines 1046-1051 are not visible in this listing; presumably the computation itself is started
//!               there — confirm. One argument of the alignment_trace call (line 1085) is not visible either.
1041 template <
typename callback_t>
1042 void operator()([[maybe_unused]]
size_t const idx, callback_t && callback)
1044 using traits_type = alignment_configuration_traits<align_config_t>;
1045 using result_value_type =
typename alignment_result_value_type_accessor<alignment_result_type>::type;
// Both coordinates start as "invalid" and are only replaced if the corresponding output is computed.
1052 auto cached_end_positions = this->invalid_coordinate();
1053 auto cached_begin_positions = this->invalid_coordinate();
1055 if constexpr (compute_end_positions)
1056 cached_end_positions = this->end_positions();
1058 if constexpr (compute_begin_positions)
1060 static_assert(compute_end_positions,
"End positions required to compute the begin positions.");
1061 if (this->is_valid())
1062 cached_begin_positions = alignment_begin_positions(this->trace_matrix(), cached_end_positions);
// Copy the requested outputs into the result value.
1068 result_value_type res_vt{};
1070 if constexpr (traits_type::output_sequence1_id)
1071 res_vt.sequence1_id = idx;
1073 if constexpr (traits_type::output_sequence2_id)
1074 res_vt.sequence2_id = idx;
// If score() yields no value, matrix_inf<score_type> is reported instead.
1076 if constexpr (traits_type::compute_score)
1077 res_vt.score = this->score().value_or(matrix_inf<score_type>);
1079 if constexpr (traits_type::compute_sequence_alignment)
1081 if (this->is_valid())
1083 using alignment_t = decltype(res_vt.alignment);
1084 res_vt.alignment = alignment_trace<alignment_t>(database,
1086 this->trace_matrix(),
1087 cached_end_positions,
1088 cached_begin_positions);
// The cached coordinates are moved last, after alignment_trace has read them.
1092 if constexpr (traits_type::compute_end_positions)
1093 res_vt.end_positions =
std::move(cached_end_positions);
1095 if constexpr (traits_type::compute_begin_positions)
1096 res_vt.begin_positions =
std::move(cached_begin_positions);
1098 callback(alignment_result_type{
std::move(res_vt)});
//!\brief Out-of-class definition of small_patterns(): processes the database with a single block per column.
//!\returns A bool. NOTE(review): the return statements are not visible in this listing.
//! NOTE(review): further lines of the loop body (e.g. advancing `database_it` and the reaction to
//!               `abort_computation`) are not visible either.
1102 template <
typename database_t,
typename query_t,
typename align_config_t,
typename traits_t>
1103 bool edit_distance_unbanded<database_t, query_t, align_config_t, traits_t>::small_patterns()
1105 bool abort_computation =
false;
1108 while (database_it != database_it_end)
// A fresh state per column; the offset into bit_masks is the rank of the current database character converted
// to the query alphabet.
1110 compute_state state{};
1111 size_t const block_offset =
seqan3::to_rank((query_alphabet_type) *database_it);
// Only block 0 exists, so no carries are needed.
1113 compute_kernel<false>(state, block_offset, 0u);
1114 advance_score(state.hp, state.hn, score_mask);
// Semi-global without max errors: every column is a candidate for the best score.
1117 if constexpr(is_semi_global && !use_max_errors)
1118 this->update_best_score();
// With max errors, the max-errors policy updates the last active cell and decides whether to stop.
1121 if constexpr(use_max_errors)
1122 abort_computation = this->update_last_active_cell();
1126 if (abort_computation)
//!\brief Out-of-class definition of large_patterns(): processes the database with several blocks per column.
//!\returns A bool. NOTE(review): the return statements are not visible in this listing.
//! NOTE(review): further lines of the loop body (e.g. advancing `database_it` and the reaction to
//!               `abort_computation`) are not visible either.
1133 template <
typename database_t,
typename query_t,
typename align_config_t,
typename traits_t>
1134 bool edit_distance_unbanded<database_t, query_t, align_config_t, traits_t>::large_patterns()
1136 bool abort_computation =
false;
1138 while (database_it != database_it_end)
// A fresh state per column; bit_masks stores vp.size() words per rank, hence the multiplication.
1140 compute_state state{};
1141 size_t const block_offset = vp.size() *
seqan3::to_rank((query_alphabet_type) *database_it);
// With max errors only the blocks up to and including last_block are computed.
1143 size_t block_count = vp.size();
1144 if constexpr(use_max_errors)
1145 block_count = this->last_block + 1;
// Blocks are processed in ascending order so that the carries propagate from one block to the next.
1148 for (
size_t current_block = 0u; current_block < block_count; current_block++)
1149 compute_kernel<true>(state, block_offset, current_block);
1151 advance_score(state.hp, state.hn, score_mask);
// Semi-global without max errors: every column is a candidate for the best score.
1154 if constexpr(is_semi_global && !use_max_errors)
1155 this->update_best_score();
1157 if constexpr(use_max_errors)
// If the tracked cell is at the top bit of its word and a further block exists, that next block is computed
// as well, starting from the initial vp0/vn0 values.
1160 bool additional_block = score_mask >> (word_size - 1u);
1161 bool reached_last_block = this->last_block + 1u == vp.size();
1163 if (reached_last_block)
1164 additional_block =
false;
1166 if (additional_block)
1168 size_t const current_block = this->last_block + 1u;
1170 vp[current_block] = vp0;
1171 vn[current_block] = vn0;
1172 compute_kernel<false>(state, block_offset, current_block);
// The max-errors policy updates the last active cell and decides whether to stop.
1176 abort_computation = this->update_last_active_cell();
1182 if (abort_computation)
//!\brief Deduction guide: the class template arguments are taken from the types deduced for the constructor
//!       arguments.
// `database` and `query` are taken as `database_t &&` / `query_t &&` with deduced template parameters, so for
// lvalue arguments the deduced types are lvalue reference types.
1194 template <
typename database_t,
typename query_t,
typename config_t,
typename traits_t>
1196 edit_distance_unbanded(database_t && database, query_t && query, config_t config, traits_t)
1197 -> edit_distance_unbanded<database_t, query_t, config_t, traits_t>;