第一版是最简单的:直接回溯(backtrack)。一列一列地尝试,每一列都会尝试所有 n 种可能性(n 行)。
// Returns true if a queen placed at (col, row) is not attacked by any queen
// already on the board, where chosen[c] is the row of the queen in column c.
// Only row and diagonal clashes are checked (column clashes are not); the
// caller in this file passes col == chosen.size().
// Runs in O(chosen.size()) time.
static bool IsValid(int8_t col, int8_t row, const std::vector<int8_t>& chosen) {
    for (size_t c = 0; c < chosen.size(); ++c) {
        const int row_gap = std::abs(chosen[c] - row);
        const int col_gap = std::abs(static_cast<int>(c) - col);
        // A zero row gap means same row; equal gaps mean same diagonal.
        if (row_gap == 0 || row_gap == col_gap) {
            return false;
        }
    }
    return true;
}
// Version 1: plain iterative backtracking.
// Counts placements of n mutually non-attacking queens on an n x n board,
// filling one column at a time; placed[c] is the row used in column c and
// `candidate` is the next row to try in column placed.size().
// Returns 0 when n <= 0 (the search loop never runs).
size_t CountQueenSolutions(int8_t n){
    size_t solutions = 0;
    std::vector<int8_t> placed;
    int8_t candidate = 0;
    for (;;) {
        if (candidate >= n) {
            if (placed.empty()) {
                break;  // every row of the first column has been explored
            }
            // Current column exhausted: resume the previous column one row lower.
            candidate = placed.back() + 1;
            placed.pop_back();
            continue;
        }
        if (IsValid(placed.size(), candidate, placed)) {
            placed.push_back(candidate);
            if (placed.size() != n) {
                candidate = 0;  // descend into the next column
                continue;
            }
            // Board full: record it and undo the last queen to keep searching.
            ++solutions;
            placed.pop_back();
        }
        ++candidate;
    }
    return solutions;
}
上面的代码每次尝试往一个格子里放棋子时,都要遍历所有已经放置的棋子,检查它们是否与这个格子同行或位于同一条对角线上,所以每次检测的代价是 O(n)。
于是我写了第二版,用位运算来加速检测:每次尝试只需要检测 3 个 bit(所在行、两条对角线各一个),而不必逐个扫描已放置的棋子。另外,如果是刷 LeetCode 之类的题,可以把下面的 `std::array<uint64_t,3> bit_masks` 直接换成一个 `uint64_t bitmasks`:三个掩码一共需要 n + 2(2n−1) = 5n−2 个 bit,所以只要 n ≤ 13,一个 64 位整数就够了。online judge 的测试用例里 n 一般都很小。(按下面这种三个掩码各占一个 uint64_t 的写法,n 最大为 32。)
class NQueenState {
private:
std::array<uint64_t,3> bit_masks ={0,0,0};
std::vector<int8_t> chosen;
std::vector<std::array<uint64_t,3>> masks_history;
int8_t n;
public:
NQueenState(int8_t n1):n(n1){
chosen.reserve(n);
masks_history.reserve(n);
}
bool IsFinished() const{
return chosen.size() == n;
}
bool CanRollback() const{
return !chosen.empty();
}
const std::vector<int8_t>& GetResult() const {
return chosen;
}
bool Set(int8_t row){
size_t col = chosen.size();
size_t o1 = row + col;
size_t o2 = row - col + n - 1;
uint64_t m1 = static_cast<uint64_t>(1ULL) << row;
uint64_t m2 = static_cast<uint64_t>(1ULL) << o1;
uint64_t m3 = static_cast<uint64_t>(1ULL) << o2;
if ((bit_masks[0] & m1) !=0 || (bit_masks[1] & m2) !=0 || (bit_masks[2] & m3) !=0) return false;
masks_history.push_back(bit_masks);
bit_masks[0] |= m1;
bit_masks[1] |= m2;
bit_masks[2] |= m3;
chosen.push_back(row);
return true;
}
int8_t Rollback(){
int8_t row = chosen.back();
chosen.pop_back();
bit_masks = masks_history.back();
masks_history.pop_back();
++row;
return row;
}
};
template <typename Func>
static void SolveNQueenImpl(int8_t n, Func&& callback) {
int8_t row = 0;
NQueenState state(n);
while (true) {
while (row >= n && state.CanRollback()) {
row = state.Rollback();
continue;
}
if(row >=n)
break;
if (!state.Set(row)) {
++row;
continue;
}
if (!state.IsFinished()) {
row = 0;
continue;
} else {
callback(state.GetResult());
row = n;
}
}
}
// Version 2: counts N-Queens solutions with the bitmask solver by tallying
// how many complete placements SolveNQueenImpl reports.
size_t CountQueenSolutions2(int8_t n) {
    size_t solutions = 0;
    auto tally = [&solutions](const std::vector<int8_t>& /*placement*/) {
        ++solutions;
    };
    SolveNQueenImpl(n, tally);
    return solutions;
}
然后有了第三版,它采用了截然不同的思路:每放置一个棋子之后,直接用位运算算出下一列所有可以放置的行(候选集合),之后只在这个集合里挑选,而不再逐行试探。下面的代码需要较多的位运算技巧,我还没有完全优化完。
#include <iostream>
#include <array>
#include <vector>
#include <assert.h>
#include "nqueens.h"
// Returns the "universe" set {0, 1, ..., n-1} as a bitmask (the n lowest bits set).
// Requires n < 32 (asserted) so the shift stays inside uint32_t.
static inline uint32_t CreateSetU(int8_t n){
    assert(n<sizeof(uint32_t)*8);
    const uint32_t lowest_bit = 1;
    const uint32_t first_excluded = lowest_bit << n;
    return first_excluded - lowest_bit;
}
// Capacity of a uint32_t-backed set: elements 0..31.
static constexpr int8_t MAX_SET_SIZE = sizeof(uint32_t) * 8;

// Returns the smallest element of the set `input` (the index of its lowest
// set bit), or MAX_SET_SIZE if the set is empty.
static inline int8_t GetMinElement(uint32_t input){
    if (input == 0) {
        return MAX_SET_SIZE;
    }
    int8_t index = 0;
    // Terminates before index reaches 32 because input has at least one set bit.
    while (((input >> index) & 1u) == 0) {
        ++index;
    }
    return index;
}
// Version 3: candidate-set backtracking. After each queen is placed, the set of
// rows still available in the next column is computed in one bitwise expression,
// and only those rows are tried.
// Returns the number of solutions; 0 when n <= 0. n must be < 32 (CreateSetU asserts this).
size_t CountQueenSolutions3(int8_t n){
    // Non-positive sizes have no placements; this also keeps reserve() away from
    // a negative n, which would convert to a huge size_t and throw std::length_error.
    if (n <= 0) {
        return 0;
    }
    const uint32_t universe = CreateSetU(n);  // rows 0..n-1; hoisted out of the loop
    const size_t board_size = static_cast<size_t>(n);

    // candidates[c]: rows not yet tried (and not attacked) in column c.
    std::vector<uint32_t> candidates;
    // chosen[c]: row of the queen currently placed in column c.
    std::vector<int8_t> chosen;
    // aux_sets_history[c]: aux_sets as they were before the queen in column c was placed.
    std::vector<std::array<uint32_t,3>> aux_sets_history;
    candidates.reserve(board_size);
    chosen.reserve(board_size);
    aux_sets_history.reserve(board_size);

    // Attack masks for the next column to fill:
    //   [0] rows already occupied,
    //   [1] rows attacked along diagonals whose row decreases by one per column,
    //   [2] rows attacked along diagonals whose row increases by one per column.
    std::array<uint32_t,3> aux_sets = {0,0,0};
    candidates.push_back(universe);
    size_t count = 0;

    while (true) {
        const uint32_t S = candidates.back();
        if (S == 0) {
            // Column exhausted: backtrack, or stop once the first column is exhausted.
            if (chosen.empty()) {
                break;
            }
            const int8_t last_chosen = chosen.back();
            chosen.pop_back();
            aux_sets = aux_sets_history.back();
            aux_sets_history.pop_back();
            candidates.pop_back();
            // The row just explored in the previous column is done; remove it.
            candidates.back() &= ~(static_cast<uint32_t>(1) << last_chosen);
            continue;
        }

        const int8_t current_x = GetMinElement(S);
        const uint32_t bit = static_cast<uint32_t>(1) << current_x;

        if (chosen.size() + 1 == board_size) {
            // Any candidate in the last column completes a solution.
            ++count;
            candidates.back() &= ~bit;
            continue;
        }

        aux_sets_history.push_back(aux_sets);
        chosen.push_back(current_x);
        aux_sets[0] |= bit;
        // Advance both diagonal masks by one column. The right shift drops rows
        // below 0 and the left shift pushes rows >= n outside `universe`, so no
        // offset is needed and no shift amount ever exceeds 31 (the previous
        // offset encoding shifted by up to 2n-2, which is undefined for n >= 17).
        aux_sets[1] = (aux_sets[1] | bit) >> 1;
        aux_sets[2] = (aux_sets[2] | bit) << 1;
        candidates.push_back(universe & ~(aux_sets[0] | aux_sets[1] | aux_sets[2]));
    }
    return count;
}
--
FROM 107.139.34.*