#include <iostream>       // std::cout
#include <atomic>         // std::atomic
#include <thread>         // std::thread
#include <vector>         // std::vector
#include <cassert>

// Simulates thd->wsrep_aborter: the id of the thread that currently owns the
// right to abort the victim. A default-constructed std::thread::id (which
// represents no thread) means the slot is free.
std::atomic<std::thread::id> wsrep_aborter{std::thread::id{}};

// Number of threads currently inside a kill critical section; kill() and
// bf_kill() assert it is exactly 1 while they own wsrep_aborter. Atomic so
// that a broken claim protocol is reported by the assert instead of turning
// into an unsynchronized race on a plain int (undefined behavior, and lost
// increments could hide the overlap).
std::atomic<int> i{0};

// thd::reset_killed
// Releases the aborter slot by storing the "no thread" id. The default
// (sequentially consistent) store also has release semantics, so the owner's
// critical section happens-before the next successful acquire
// compare-exchange on wsrep_aborter.
void reset(void)
{
	// std::thread::id{} is the portable "not a thread" value; a cast from 0
	// only compiles where thread::id exposes a public native-handle ctor.
	wsrep_aborter.store(std::thread::id{});
}

void kill(std::thread::id thread_id)
{
	// sql_parse.cc : kill_one_thread()
	// if (victim_thd->wsrep_aborter &&
	//     victim_thd->wsrep_aborter != bf_thd->thread_id) out;
	//
	// Simulation: try to claim wsrep_aborter for thread_id. The kill goes
	// ahead when the slot was free (it now holds thread_id) or already holds
	// thread_id; otherwise another thread owns it and we bail out.

	// CAS expected value: the portable "no aborter" id. On failure the CAS
	// overwrites it with the value actually stored in wsrep_aborter.
	std::thread::id observed{};

	std::cout << "KILL: Thread trying is : " << thread_id << " wsrep_aborter= " << wsrep_aborter << "\n";

	// Acquire on success pairs with the store in reset(), ordering the
	// previous owner's critical section before ours.
	const bool claimed = wsrep_aborter.compare_exchange_strong(observed, thread_id,
			std::memory_order_acquire,
			std::memory_order_relaxed);

	if (!claimed && observed != thread_id)
	{
		std::cout << "KILL: Bail out as wsrep_aborter= " << wsrep_aborter << "\n";
		return;
	}

	// Critical section: i counts threads inside it and must be exactly 1.
	i++;
	assert(i == 1);
	std::cout << "KILL: Continue kill as wsrep_aborter= " << wsrep_aborter << "\n";
	std::cout << "KILL: Kil done...reset..." << "\n";
	i--;
	reset();
}

void bf_kill(std::thread::id thread_id)
{
	// ha_innodb.cc wsrep_innobase_kill_one_trx()
	//   if (wsrep_thd_set_wsrep_aborter(bf_thd, victim_thd))
	// ==  if (victim_thd->wsrep_aborter &&
	//         victim_thd->wsrep_aborter != bf_thd->thread_id) out;
	//
	// Simulation: try to claim wsrep_aborter for thread_id. The BF kill goes
	// ahead when the slot was free (it now holds thread_id) or already holds
	// thread_id; otherwise another thread owns it and we bail out.

	// CAS expected value: the portable "no aborter" id. On failure the CAS
	// overwrites it with the value actually stored in wsrep_aborter.
	std::thread::id observed{};

	std::cout << "BF_KILL: Thread trying is : " << thread_id << " wsrep_aborter= " << wsrep_aborter << "\n";

	// Acquire on success pairs with the store in reset(), ordering the
	// previous owner's critical section before ours.
	const bool claimed = wsrep_aborter.compare_exchange_strong(observed, thread_id,
			std::memory_order_acquire,
			std::memory_order_relaxed);

	if (!claimed && observed != thread_id)
	{
		std::cout << "BF_KILL: Bail out as wsrep_aborter= " << wsrep_aborter << "\n";
		return;
	}

	// Critical section: i counts threads inside it and must be exactly 1.
	i++;
	assert(i == 1);
	std::cout << "BF_KILL: Continue kill now wsrep_aborter= " << wsrep_aborter << "\n";
	std::cout << "BF_KILL: Killed...reseting..." << "\n";
	i--;
	reset();
}

// Worker body: thread number `index` calls kill() 10000 times with its own
// thread id. Names avoid shadowing the global critical-section counter `i`.
void driver(int index)
{
	const std::thread::id this_id = std::this_thread::get_id();
	std::cout << "Driver for " << index << "THREAD: " << this_id << "\n";
	for (int attempt = 0; attempt < 10000; ++attempt)
		kill(this_id);
}
// Worker body: thread number `index` calls bf_kill() 10000 times with its own
// thread id. Names avoid shadowing the global critical-section counter `i`.
void driver2(int index)
{
	const std::thread::id this_id = std::this_thread::get_id();
	std::cout << "Driver for " << index << "THREAD: " << this_id << "\n";
	for (int attempt = 0; attempt < 10000; ++attempt)
		bf_kill(this_id);
}

// Spawns two batches of worker threads and waits for all of them. In the
// first batch odd indices run driver (kill) and even indices run driver2
// (bf_kill); the second batch swaps the roles.
int main()
{
  constexpr int kThreadsPerBatch = 16;

  std::vector<std::thread> threads;
  threads.reserve(2 * kThreadsPerBatch);
  for (int n = 0; n < kThreadsPerBatch; ++n)
    threads.emplace_back(n % 2 ? driver : driver2, n);
  for (int n = 0; n < kThreadsPerBatch; ++n)
    threads.emplace_back(n % 2 ? driver2 : driver, n);

  for (auto& th : threads)
    th.join();
}
