19 #include <range/v3/algorithm/copy.hpp>
31 namespace seqan3::detail
36 template <
typename derived_t,
typename edit_traits>
37 class edit_distance_unbanded_max_errors_policy :
43 static_assert(edit_traits::use_max_errors,
"This policy assumes that edit_traits::use_max_errors is true.");
51 edit_distance_unbanded_max_errors_policy() noexcept = default;
52 edit_distance_unbanded_max_errors_policy(edit_distance_unbanded_max_errors_policy const &) noexcept
54 edit_distance_unbanded_max_errors_policy(edit_distance_unbanded_max_errors_policy &&) noexcept
56 edit_distance_unbanded_max_errors_policy & operator=(edit_distance_unbanded_max_errors_policy const &) noexcept
58 edit_distance_unbanded_max_errors_policy & operator=(edit_distance_unbanded_max_errors_policy &&) noexcept
60 ~edit_distance_unbanded_max_errors_policy() noexcept = default;
63 using typename edit_traits::word_type;
64 using typename edit_traits::score_type;
70 score_type max_errors{255};
73 size_t last_block{0u};
75 word_type last_score_mask{};
82 void max_errors_init(
size_t block_count) noexcept
85 derived_t *
self = static_cast<derived_t *>(
this);
87 max_errors = get<align_cfg::max_error>(self->config).value;
88 assert(max_errors >= score_type{0});
90 if (std::ranges::empty(self->query))
93 self->score_mask = 0u;
94 last_score_mask =
self->score_mask;
98 last_block = block_count - 1u;
99 last_score_mask =
self->score_mask;
104 size_t const local_max_errors = std::min<size_t>(max_errors,
std::ranges::size(self->query) - 1u);
105 self->score_mask = word_type{1u} << (local_max_errors %
self->word_size);
106 last_block =
std::min(local_max_errors / self->word_size, last_block);
107 self->_score = local_max_errors + 1u;
111 bool is_last_active_cell_within_last_row() const noexcept
113 derived_t
const *
self = static_cast<derived_t const *>(
this);
114 return (self->score_mask == this->last_score_mask) && (this->last_block ==
self->vp.size() - 1u);
118 bool prev_last_active_cell() noexcept
120 derived_t *
self = static_cast<derived_t *>(
this);
121 self->score_mask >>= 1u;
122 if (self->score_mask != 0u)
125 if constexpr (edit_traits::is_global)
127 if (last_block == 0u)
133 self->score_mask = word_type{1u} << (edit_traits::word_size - 1u);
138 void next_last_active_cell() noexcept
140 derived_t *
self = static_cast<derived_t *>(
this);
141 self->score_mask <<= 1u;
142 if (self->score_mask)
145 self->score_mask = 1u;
152 bool update_last_active_cell() noexcept
154 derived_t *
self = static_cast<derived_t *>(
this);
156 while (!(self->_score <= max_errors))
158 self->advance_score(self->vn[last_block], self->vp[last_block], self->score_mask);
159 if (!prev_last_active_cell())
162 assert(edit_traits::is_global);
165 return !edit_traits::compute_matrix;
169 if (is_last_active_cell_within_last_row())
171 assert(self->_score <= max_errors);
173 if constexpr(edit_traits::is_semi_global)
174 self->update_best_score();
176 return self->on_hit();
180 next_last_active_cell();
181 self->advance_score(self->vp[last_block], self->vn[last_block], self->score_mask);
189 static size_t max_rows(word_type
const score_mask,
unsigned const last_block,
190 score_type
const score, score_type
const max_errors) noexcept
192 using score_matrix_type =
typename edit_traits::score_matrix_type;
193 return score_matrix_type::max_rows(score_mask,
203 template <
typename derived_t,
typename edit_traits>
204 class edit_distance_unbanded_global_policy :
210 static_assert(edit_traits::is_global || edit_traits::is_semi_global,
211 "This policy assumes that edit_traits::is_global or edit_traits::is_semi_global is true.");
219 edit_distance_unbanded_global_policy() noexcept = default;
220 edit_distance_unbanded_global_policy(edit_distance_unbanded_global_policy const &) noexcept
222 edit_distance_unbanded_global_policy(edit_distance_unbanded_global_policy &&) noexcept
224 edit_distance_unbanded_global_policy & operator=(edit_distance_unbanded_global_policy const &) noexcept
226 edit_distance_unbanded_global_policy & operator=(edit_distance_unbanded_global_policy &&) noexcept
228 ~edit_distance_unbanded_global_policy() noexcept = default;
232 using typename edit_traits::score_type;
242 score_type _best_score{};
249 void score_init() noexcept
252 derived_t
const *
self = static_cast<derived_t const *>(
this);
253 _best_score =
self->_score;
257 bool is_valid() const noexcept
259 [[maybe_unused]] derived_t
const *
self = static_cast<derived_t const *>(
this);
264 if constexpr(edit_traits::use_max_errors)
265 return _best_score <=
self->max_errors;
273 alignment_coordinate invalid_coordinate() const noexcept
275 derived_t
const *
self = static_cast<derived_t const *>(
this);
280 void update_best_score() noexcept
282 derived_t
const *
self = static_cast<derived_t const *>(
this);
283 _best_score =
self->_score;
287 size_t back_coordinate_first() const noexcept
289 derived_t
const *
self = static_cast<derived_t const *>(
this);
303 derived_t
const *
self = static_cast<derived_t const *>(
this);
304 static_assert(edit_traits::compute_score,
"score() can only be computed if you specify the result type within "
305 "your alignment config.");
306 if (!self->is_valid())
314 alignment_coordinate back_coordinate() const noexcept
316 derived_t
const *
self = static_cast<derived_t const *>(
this);
317 static_assert(edit_traits::compute_back_coordinate,
"back_coordinate() can only be computed if you specify the"
318 "result type within your alignment config.");
319 if (!self->is_valid())
320 return self->invalid_coordinate();
322 column_index_type
const first{
self->back_coordinate_first()};
324 return {first, second};
330 template <
typename derived_t,
typename edit_traits>
331 class edit_distance_unbanded_semi_global_policy :
332 public edit_distance_unbanded_global_policy<derived_t, edit_traits>
335 static_assert(edit_traits::is_semi_global,
"This policy assumes that edit_traits::is_semi_global is true.");
343 edit_distance_unbanded_semi_global_policy() noexcept = default;
344 edit_distance_unbanded_semi_global_policy(edit_distance_unbanded_semi_global_policy const &) noexcept
346 edit_distance_unbanded_semi_global_policy(edit_distance_unbanded_semi_global_policy &&) noexcept
348 edit_distance_unbanded_semi_global_policy & operator=(edit_distance_unbanded_semi_global_policy const &) noexcept
350 edit_distance_unbanded_semi_global_policy & operator=(edit_distance_unbanded_semi_global_policy &&) noexcept
352 ~edit_distance_unbanded_semi_global_policy() noexcept = default;
356 using base_t = edit_distance_unbanded_global_policy<derived_t, edit_traits>;
358 using database_iterator = typename edit_traits::database_iterator;
359 using base_t::_best_score;
373 database_iterator _best_score_col{};
380 void score_init() noexcept
383 derived_t
const *
self = static_cast<derived_t const *>(
this);
384 base_t::score_init();
385 _best_score_col =
self->database_it_end;
389 void update_best_score() noexcept
391 derived_t
const *
self = static_cast<derived_t const *>(
this);
393 if constexpr(edit_traits::use_max_errors)
395 assert(std::ranges::empty(self->query) || self->is_last_active_cell_within_last_row());
398 _best_score_col = (
self->_score <= _best_score) ? self->database_it : _best_score_col;
399 _best_score = (self->_score <= _best_score) ?
self->_score : _best_score;
403 size_t back_coordinate_first() const noexcept
405 derived_t
const *
self = static_cast<derived_t const *>(
this);
407 size_t offset = std::ranges::empty(self->database) ? 0u : 1u;
408 return std::ranges::distance(std::ranges::begin(self->database), _best_score_col) +
offset;
416 template <
typename derived_t,
typename edit_traits>
417 class edit_distance_unbanded_score_matrix_policy :
423 static_assert(edit_traits::compute_score_matrix,
424 "This policy assumes that edit_traits::compute_score_matrix is true.");
432 edit_distance_unbanded_score_matrix_policy() noexcept = default;
433 edit_distance_unbanded_score_matrix_policy(edit_distance_unbanded_score_matrix_policy const &) noexcept
435 edit_distance_unbanded_score_matrix_policy(edit_distance_unbanded_score_matrix_policy &&) noexcept
437 edit_distance_unbanded_score_matrix_policy & operator=(edit_distance_unbanded_score_matrix_policy const &) noexcept
439 edit_distance_unbanded_score_matrix_policy & operator=(edit_distance_unbanded_score_matrix_policy &&) noexcept
441 ~edit_distance_unbanded_score_matrix_policy() noexcept = default;
444 using typename edit_traits::score_matrix_type;
450 score_matrix_type _score_matrix{};
466 void score_matrix_init()
468 derived_t
const *
self = static_cast<derived_t const *>(
this);
483 score_matrix_type
const & score_matrix() const noexcept
485 static_assert(edit_traits::compute_score_matrix,
"score_matrix() can only be computed if you specify the "
486 "result type within your alignment config.");
487 return _score_matrix;
495 template <
typename derived_t,
typename edit_traits>
496 class edit_distance_unbanded_trace_matrix_policy :
502 static_assert(edit_traits::compute_trace_matrix,
503 "This policy assumes that edit_traits::compute_trace_matrix is true.");
511 edit_distance_unbanded_trace_matrix_policy() noexcept = default;
512 edit_distance_unbanded_trace_matrix_policy(edit_distance_unbanded_trace_matrix_policy const &) noexcept
514 edit_distance_unbanded_trace_matrix_policy(edit_distance_unbanded_trace_matrix_policy &&) noexcept
516 edit_distance_unbanded_trace_matrix_policy & operator=(edit_distance_unbanded_trace_matrix_policy const &) noexcept
518 edit_distance_unbanded_trace_matrix_policy & operator=(edit_distance_unbanded_trace_matrix_policy &&) noexcept
520 ~edit_distance_unbanded_trace_matrix_policy() noexcept = default;
523 using typename edit_traits::word_type;
524 using typename edit_traits::trace_matrix_type;
525 using typename edit_traits::result_value_type;
531 std::vector<word_type> hp{};
537 trace_matrix_type _trace_matrix{};
552 void trace_matrix_init(
size_t block_count)
554 derived_t
const *
self = static_cast<derived_t const *>(
this);
559 hp.resize(block_count, 0u);
560 db.resize(block_count, 0u);
569 trace_matrix_type
const & trace_matrix() const noexcept
572 static_assert(edit_traits::compute_trace_matrix,
"trace_matrix() can only be computed if you specify the "
573 "result type within your alignment config.");
574 return _trace_matrix;
579 alignment_coordinate front_coordinate() const noexcept
581 derived_t
const *
self = static_cast<derived_t const *>(
this);
582 static_assert(edit_traits::compute_front_coordinate,
"front_coordinate() can only be computed if you specify "
583 "the result type within your alignment config.");
584 if (!self->is_valid())
585 return self->invalid_coordinate();
587 alignment_coordinate
const back =
self->back_coordinate();
588 return alignment_front_coordinate(trace_matrix(),
back);
595 using alignment_t = decltype(result_value_type{}.alignment);
597 derived_t
const *
self = static_cast<derived_t const *>(
this);
598 static_assert(edit_traits::compute_sequence_alignment,
"alignment() can only be computed if you specify the "
599 "result type within your alignment config.");
601 if (!self->is_valid())
602 return alignment_t{};
604 return alignment_trace<alignment_t>(self->database,
607 self->back_coordinate(),
616 template <
typename value_t>
617 class proxy_reference
623 proxy_reference() noexcept = default;
624 proxy_reference(proxy_reference const &) noexcept = default;
625 proxy_reference(proxy_reference &&) noexcept = default;
626 proxy_reference & operator=(proxy_reference const &) noexcept = default;
627 proxy_reference & operator=(proxy_reference &&) noexcept = default;
628 ~proxy_reference() = default;
631 proxy_reference(value_t & t) noexcept
632 : ptr(
std::addressof(t))
635 proxy_reference(value_t &&) =
delete;
638 template <
typename other_value_t>
642 proxy_reference & operator=(other_value_t && u) noexcept
644 get() = std::forward<other_value_t>(u);
650 value_t &
get() const noexcept
652 assert(ptr !=
nullptr);
657 operator value_t & ()
const noexcept
664 value_t * ptr{
nullptr};
678 template <std::ranges::viewable_range database_t,
679 std::ranges::viewable_range query_t,
680 typename align_config_t,
681 typename edit_traits>
682 class edit_distance_unbanded :
685 public edit_distance_base<
686 edit_traits::use_max_errors,
687 edit_distance_unbanded_max_errors_policy,
689 edit_distance_unbanded<database_t, query_t, align_config_t, edit_traits>>,
690 public edit_distance_base<
691 edit_traits::is_global,
692 edit_distance_unbanded_global_policy,
694 edit_distance_unbanded<database_t, query_t, align_config_t, edit_traits>>,
695 public edit_distance_base<
696 edit_traits::is_semi_global,
697 edit_distance_unbanded_semi_global_policy,
699 edit_distance_unbanded<database_t, query_t, align_config_t, edit_traits>>,
700 public edit_distance_base<
701 edit_traits::compute_score_matrix,
702 edit_distance_unbanded_score_matrix_policy,
704 edit_distance_unbanded<database_t, query_t, align_config_t, edit_traits>>,
705 public edit_distance_base<
706 edit_traits::compute_trace_matrix,
707 edit_distance_unbanded_trace_matrix_policy,
709 edit_distance_unbanded<database_t, query_t, align_config_t, edit_traits>>
713 using typename edit_traits::word_type;
714 using typename edit_traits::score_type;
715 using typename edit_traits::database_type;
716 using typename edit_traits::query_type;
717 using typename edit_traits::align_config_type;
718 using edit_traits::word_size;
722 template <
typename other_derived_t,
typename other_edit_traits>
723 friend class edit_distance_unbanded_max_errors_policy;
725 template <
typename other_derived_t,
typename other_edit_traits>
726 friend class edit_distance_unbanded_global_policy;
728 template <
typename other_derived_t,
typename other_edit_traits>
729 friend class edit_distance_unbanded_semi_global_policy;
731 template <
typename other_derived_t,
typename other_edit_traits>
732 friend class edit_distance_unbanded_score_matrix_policy;
734 template <
typename other_derived_t,
typename other_edit_traits>
735 friend class edit_distance_unbanded_trace_matrix_policy;
737 using typename edit_traits::database_iterator;
738 using typename edit_traits::query_alphabet_type;
739 using typename edit_traits::result_value_type;
740 using edit_traits::use_max_errors;
741 using edit_traits::is_semi_global;
742 using edit_traits::is_global;
743 using edit_traits::compute_score;
744 using edit_traits::compute_back_coordinate;
745 using edit_traits::compute_front_coordinate;
746 using edit_traits::compute_sequence_alignment;
747 using edit_traits::compute_score_matrix;
748 using edit_traits::compute_trace_matrix;
749 using edit_traits::compute_matrix;
750 using typename edit_traits::score_matrix_type;
751 using typename edit_traits::trace_matrix_type;
758 align_config_t config;
761 static constexpr word_type hp0 = is_global ? 1u : 0u;
763 static constexpr word_type hn0 = 0u;
765 static constexpr word_type vp0 = ~word_type{0u};
767 static constexpr word_type vn0 = 0u;
777 word_type score_mask{0u};
791 database_iterator database_it{};
793 database_iterator database_it_end{};
796 struct compute_state_trace_matrix
799 proxy_reference<word_type> db{};
803 struct compute_state : enable_state_t<compute_trace_matrix, compute_state_trace_matrix>
817 proxy_reference<word_type> vp{};
819 proxy_reference<word_type> vn{};
821 word_type carry_d0{};
823 word_type carry_hp{hp0};
825 word_type carry_hn{};
831 if constexpr(!use_max_errors && compute_score_matrix)
832 this->_score_matrix.add_column(vp, vn);
834 if constexpr(!use_max_errors && compute_trace_matrix)
835 this->_trace_matrix.add_column(this->hp, this->db, vp);
837 if constexpr(use_max_errors && compute_matrix)
839 size_t max_rows = this->max_rows(score_mask, this->last_block, _score, this->max_errors);
840 if constexpr(compute_score_matrix)
841 this->_score_matrix.add_column(vp, vn, max_rows);
843 if constexpr(compute_trace_matrix)
844 this->_trace_matrix.add_column(this->hp, this->db, vp, max_rows);
852 edit_distance_unbanded() =
delete;
854 edit_distance_unbanded(edit_distance_unbanded
const &) =
default;
855 edit_distance_unbanded(edit_distance_unbanded &&) =
default;
856 edit_distance_unbanded & operator=(edit_distance_unbanded
const &) =
default;
857 edit_distance_unbanded & operator=(edit_distance_unbanded &&) =
default;
858 ~edit_distance_unbanded() =
default;
866 edit_distance_unbanded(database_t _database,
868 align_config_t _config,
869 edit_traits
const & SEQAN3_DOXYGEN_ONLY(_traits) = edit_traits{}) :
870 database{std::forward<database_t>(_database)},
871 query{std::forward<query_t>(_query)},
872 config{std::forward<align_config_t>(_config)},
874 database_it{ranges::begin(database)},
875 database_it_end{ranges::end(database)}
877 static constexpr
size_t alphabet_size_ = alphabet_size<query_alphabet_type>;
879 size_t const block_count = (
std::ranges::size(query) - 1u + word_size) / word_size;
880 score_mask = word_type{1u} << ((
std::ranges::size(query) - 1u + word_size) % word_size);
883 if constexpr(use_max_errors)
884 this->max_errors_init(block_count);
886 if constexpr(compute_score_matrix)
887 this->score_matrix_init();
889 if constexpr(compute_trace_matrix)
890 this->trace_matrix_init(block_count);
892 vp.resize(block_count, vp0);
893 vn.resize(block_count, vn0);
894 bit_masks.resize((alphabet_size_ + 1u) * block_count, 0u);
899 size_t const i = block_count *
seqan3::to_rank(query[j]) + j / word_size;
900 bit_masks[i] |= word_type{1u} << (j % word_size);
909 template <
bool with_carry>
910 static void compute_step(compute_state & state) noexcept
913 assert(state.carry_d0 <= 1u);
914 assert(state.carry_hp <= 1u);
915 assert(state.carry_hn <= 1u);
917 x = state.b | state.vn;
918 t = state.vp + (x & state.vp) + state.carry_d0;
920 state.d0 = (t ^ state.vp) | x;
921 state.hn = state.vp & state.d0;
922 state.hp = state.vn | ~(state.vp | state.d0);
924 if constexpr(with_carry)
925 state.carry_d0 = (state.carry_d0 != 0u) ? t <= state.vp : t < state.vp;
927 x = (state.hp << 1u) | state.carry_hp;
928 state.vn = x & state.d0;
929 state.vp = (state.hn << 1u) | ~(x | state.d0) | state.carry_hn;
931 if constexpr(with_carry)
933 state.carry_hp = state.hp >> (word_size - 1u);
934 state.carry_hn = state.hn >> (word_size - 1u);
939 template <
bool with_carry>
940 void compute_kernel(compute_state & state,
size_t const block_offset,
size_t const current_block) noexcept
942 state.vp = proxy_reference<word_type>{this->vp[current_block]};
943 state.vn = proxy_reference<word_type>{this->vn[current_block]};
944 if constexpr(compute_trace_matrix)
946 state.hp = proxy_reference<word_type>{this->hp[current_block]};
947 state.db = proxy_reference<word_type>{this->db[current_block]};
949 state.b = bit_masks[block_offset + current_block];
951 compute_step<with_carry>(state);
952 if constexpr(compute_trace_matrix)
953 state.db = ~(state.b ^ state.d0);
957 void advance_score(word_type P, word_type N, word_type mask) noexcept
959 if ((P & mask) != word_type{0u})
961 else if ((N & mask) != word_type{0u})
966 bool on_hit() noexcept
973 inline bool small_patterns();
976 inline bool large_patterns();
979 inline void compute_empty_query_sequence()
981 assert(std::ranges::empty(query));
983 bool abort_computation =
false;
985 for (; database_it != database_it_end; ++database_it)
987 if constexpr(is_global)
990 this->update_best_score();
993 if constexpr(use_max_errors)
994 abort_computation = on_hit();
997 if (abort_computation)
1006 if constexpr(use_max_errors && is_global && !compute_matrix)
1026 if (vp.size() == 0u)
1027 compute_empty_query_sequence();
1028 else if (vp.size() == 1u)
1033 if constexpr(is_global)
1034 this->update_best_score();
1042 alignment_result<result_value_type> operator()(
size_t const idx)
1045 result_value_type res_vt{};
1047 if constexpr (compute_score)
1049 res_vt.score = this->score().value_or(matrix_inf<score_type>);
1052 if constexpr (compute_back_coordinate)
1054 res_vt.back_coordinate = this->back_coordinate();
1057 if constexpr (compute_front_coordinate)
1059 if (this->is_valid())
1060 res_vt.front_coordinate = alignment_front_coordinate(this->trace_matrix(), res_vt.back_coordinate);
1062 res_vt.front_coordinate = this->invalid_coordinate();
1065 if constexpr (compute_sequence_alignment)
1067 if (this->is_valid())
1069 using alignment_t = decltype(res_vt.alignment);
1070 res_vt.alignment = alignment_trace<alignment_t>(database,
1072 this->trace_matrix(),
1073 res_vt.back_coordinate,
1074 res_vt.front_coordinate);
1077 return alignment_result<result_value_type>{
std::move(res_vt)};
1081 template <
typename database_t,
typename query_t,
typename align_config_t,
typename traits_t>
1082 bool edit_distance_unbanded<database_t, query_t, align_config_t, traits_t>::small_patterns()
1084 bool abort_computation =
false;
1087 while (database_it != database_it_end)
1089 compute_state state{};
1090 size_t const block_offset =
seqan3::to_rank((query_alphabet_type) *database_it);
1092 compute_kernel<false>(state, block_offset, 0u);
1093 advance_score(state.hp, state.hn, score_mask);
1096 if constexpr(is_semi_global && !use_max_errors)
1097 this->update_best_score();
1100 if constexpr(use_max_errors)
1101 abort_computation = this->update_last_active_cell();
1105 if (abort_computation)
1112 template <
typename database_t,
typename query_t,
typename align_config_t,
typename traits_t>
1113 bool edit_distance_unbanded<database_t, query_t, align_config_t, traits_t>::large_patterns()
1115 bool abort_computation =
false;
1117 while (database_it != database_it_end)
1119 compute_state state{};
1120 size_t const block_offset = vp.size() *
seqan3::to_rank((query_alphabet_type) *database_it);
1122 size_t block_count = vp.size();
1123 if constexpr(use_max_errors)
1124 block_count = this->last_block + 1;
1127 for (
size_t current_block = 0u; current_block < block_count; current_block++)
1128 compute_kernel<true>(state, block_offset, current_block);
1130 advance_score(state.hp, state.hn, score_mask);
1133 if constexpr(is_semi_global && !use_max_errors)
1134 this->update_best_score();
1136 if constexpr(use_max_errors)
1139 bool additional_block = score_mask >> (word_size - 1u);
1140 bool reached_last_block = this->last_block + 1u == vp.size();
1142 if (reached_last_block)
1143 additional_block =
false;
1145 if (additional_block)
1147 size_t const current_block = this->last_block + 1u;
1149 vp[current_block] = vp0;
1150 vn[current_block] = vn0;
1151 compute_kernel<false>(state, block_offset, current_block);
1155 abort_computation = this->update_last_active_cell();
1161 if (abort_computation)
1173 template <
typename database_t,
typename query_t,
typename config_t>
1175 edit_distance_unbanded(database_t && database, query_t && query, config_t config)
1176 -> edit_distance_unbanded<database_t, query_t, config_t>;
1179 template <
typename database_t,
typename query_t,
typename config_t,
typename traits_t>
1180 edit_distance_unbanded(database_t && database, query_t && query, config_t config, traits_t)
1181 -> edit_distance_unbanded<database_t, query_t, config_t, traits_t>;