#!/usr/bin/env python3

import sys
try:
    import pymysql
except:
    print('pymysql needs to be installed')
    sys.exit(1)

try:
    import threading
except:
    print('threading needs to be installed')
    sys.exit(1)

def threaded_test(cur, res, worker):
    """Try to delete, in a single statement, every row listed in ``res``.

    Any exception raised while collecting the ids or executing the statement
    is printed (prefixed with the worker label) instead of being propagated.

    Args:
        cur: cursor-like object providing ``execute(query, args)``.
        res: iterable of rows; the first element of each row is used as the id.
        worker: label used in the printed progress and result lines.
    """
    print(f'try to delete all rows at once locked by {worker}')
    try:
        # Collected inside the try so that a malformed row is reported the
        # same way as a failing statement.
        locked_ids = [row[0] for row in res]
        cur.execute("delete from jobs where id in %(ids)s", {'ids': locked_ids})
    except Exception as e:
        print(f"{worker}: {e}")
    else:
        print(f'{worker}: OK')

# --- Experiment configuration -------------------------------------------
# Name of the session variable that holds the isolation level; it differs
# between server flavours, so switch the active line to match the server.
tx_iso = 'tx_isolation' # mariadb 10.6
#tx_iso = 'transaction_isolation' # mysql 8.0
# Isolation level applied to the worker sessions.
level = 'READ-COMMITTED'
# NOTE(review): this value is only reported in the summary below. The worker
# sessions are configured further down with a literal timeout of 20, so the
# printed value may not match what is actually in effect -- confirm which of
# the two is intended.
lock_w_timeout = 2
# When true, the script commits the worker sessions and truncates the table
# once all tests have run.
cleanup = True

# Echo the configuration so it shows up next to the test output.
print('Config:')
print(f"transaction isolation: {tx_iso}")
print(f"level: {level}")
print(f"innodb_lock_wait_timeout: {lock_w_timeout}")
print(f"cleanup after finished: {cleanup}")

# Three independent connections to the same database: one "master" session
# that prepares the data and two worker sessions that compete for row locks.
# Alternative to the unix socket (TCP + option file), same for all three:
#   pymysql.connect(host='localhost', read_default_file='/home/rex/.my.cnf', database='test')
master_c, worker1_c, worker2_c = (
    pymysql.connect(unix_socket='/tmp/mariadbd.sock', database='test')
    for _ in range(3)
)

# One cursor per connection; the rest of the script issues all SQL through these.
m, w1, w2 = (conn.cursor() for conn in (master_c, worker1_c, worker2_c))

# Apply the isolation level and a lock wait timeout of 20 to both worker
# sessions. No settings are changed on the master session; the equivalent
# statement for it is kept here for reference:
# m.execute(f"set {tx_iso} = '{level}', innodb_lock_wait_timeout = %s", (2, ))
session_setup = f"set {tx_iso} = '{level}', innodb_lock_wait_timeout = %s"
for worker_cur in (w1, w2):
    worker_cur.execute(session_setup, (20, ))

# Read the effective settings back from every session, master first.
session_probe = f"select @@{tx_iso}, @@autocommit, @@innodb_lock_wait_timeout"
for any_cur in (m, w1, w2):
    any_cur.execute(session_probe)

# Report one line per session under a header naming the three columns.
print('#########################################')
print()
print(f"{tx_iso}, autocommit, innodb_lock_wait_timeout")
for label, any_cur in (('master', m), ('worker', w1), ('worker', w2)):
    print(f"{label} {any_cur.fetchall()}")
print('#########################################')

print()
print('create table if not already there')

# The table is dropped first, so the CREATE below starts from an empty table
# on every run.
drop_jobs = """
drop table if exists jobs
"""
create_jobs = """
CREATE TABLE IF NOT EXISTS jobs (
  id int(11) NOT NULL AUTO_INCREMENT,
  state varchar(25) DEFAULT NULL,
  created timestamp NOT NULL DEFAULT current_timestamp(),
  PRIMARY KEY (id)
)
"""
for ddl in (drop_jobs, create_jobs):
    m.execute(ddl)
master_c.commit()

# First batch: five pending jobs, committed so the workers can see them.
for _ in range(5):
    m.execute("insert into jobs (state) values('pending')")
master_c.commit()

# Enable sql_safe_updates on both worker sessions.
for worker_cur in (w1, w2):
    worker_cur.execute("set sql_safe_updates = 1")

lock_free_rows = "select id from jobs for update skip locked"

# Worker 1 locks every row currently in the table (expected: ids 1-5).
w1.execute(lock_free_rows)
w1_res = w1.fetchall()

# Second batch of five pending jobs, inserted while worker 1 holds its locks.
for _ in range(5):
    m.execute("insert into jobs (state) values('pending')")
master_c.commit()

# Worker 2 skips the rows worker 1 has locked and takes the rest
# (expected: ids 6-10).
w2.execute(lock_free_rows)
w2_res = w2.fetchall()

print()
print('rows locked by worker')
print(f"worker1 (id: 1-5)  : {w1_res}")
print(f"worker2 (id: 6-10) : {w2_res}")

# NOTE: disabled experiment. The statements below are wrapped in a bare
# module-level string literal, so none of them is executed. They record two
# bulk deletes on worker 1 (all of its locked ids, and ids 1-4) that were
# expected to fail with "lock wait timeout exceeded".
"""
print()
print('try to delete all rows at once locked by worker 1')
print('expecting this to fail with lock wait timeout exceeded')
try:
    print(f"delete from jobs where id in (", [i[0] for i in w1_res], ")")
    w1.execute("delete from jobs where id in %(ids)s", {'ids': [i[0] for i in w1_res]})
    print('OK')
except pymysql.err.InternalError as e:
    print(e)

print()
print('try to delete rows 1,2,3,4 at once locked by worker 1')
print('expecting this to fail with lock wait timeout exceeded')
try:
    print("delete from jobs where id in %(ids)s", {'ids': [1,2,3,4]})
    w1.execute("delete from jobs where id in %(ids)s", {'ids': [1,2,3,4]})
    print('OK')
except pymysql.err.InternalError as e:
    print(e)
"""

# Error classes to report (instead of crashing on) when a delete fails.
# Which class PyMySQL raises for a lock wait timeout or deadlock depends on
# its version: older releases used InternalError for server errors without an
# explicit mapping, current ones use OperationalError. Both are caught so the
# failure that the last test below *expects* does not abort the script.
delete_errors = (pymysql.err.InternalError, pymysql.err.OperationalError)

# Test 1: bulk delete of a subset (ids 1-3) of the rows worker 1 has locked.
print()
print('try to delete rows 1,2,3 at once locked by worker 1')
print('expecting this to succeed without failing')
try:
    print("delete from jobs where id in %(ids)s", {'ids': [1,2,3]})
    w1.execute("delete from jobs where id in %(ids)s", {'ids': [1,2,3]})
    print('OK')
except delete_errors as e:
    print(e)

# Test 2: delete each of worker 1's locked rows with its own statement.
# Ids already removed by test 1 simply match no row.
print()
print('expecting this to succeed without failing')
print('delete all rows locked by worker 1 one by one')
for row in w1_res:
    try:
        print(f"delete row {row[0]}")
        # Parameters are passed as a one-element tuple, the conventional
        # DB-API form for a single placeholder.
        w1.execute("delete from jobs where id = %s", (row[0],))
        print('OK')
    except delete_errors as e:
        print(e)

# Test 3: bulk delete of every row worker 2 has locked, while worker 1 still
# holds its own locks in an open transaction.
print()
print('try to delete all rows at once locked by worker 2')
print('expecting this to fail with lock wait timeout exceeded')
try:
    print("delete from jobs where id in %(ids)s", {'ids': [i[0] for i in w2_res]})
    w2.execute("delete from jobs where id in %(ids)s", {'ids': [i[0] for i in w2_res]})
    # Alternative formulation kept for reference:
    # w2.execute(f"delete from jobs where exists (select id from jobs where id in %(ids)s )", {'ids': [i[0] for i in w2_res]})
    print('OK')
except delete_errors as e:
    print(e)

print()
print('#########################################')
print('doing some parallel tests')

relock_free_rows = "select id from jobs for update skip locked"

# Worker 1: roll back its transaction, then lock whatever rows are free
# again (expected: ids 1-5).
print()
print('rolling back worker 1')
worker1_c.rollback()

print()
print('resetting worker 1')
w1.execute(relock_free_rows)
w1_res = w1.fetchall()

# Worker 2: same procedure, run after worker 1 has taken its rows again
# (expected: ids 6-10).
print()
print('rolling back worker 2')
worker2_c.rollback()

print()
print('resetting worker 2')
w2.execute(relock_free_rows)
w2_res = w2.fetchall()

print()
print('rows locked by worker')
print(f"worker1 (id: 1-5)  : {w1_res}")
print(f"worker2 (id: 6-10) : {w2_res}")

# One thread per worker; each issues a bulk delete of its own locked rows
# through its own cursor.
threads = [
    threading.Thread(target=threaded_test, args=job)
    for job in ((w1, w1_res, 'Worker 1'), (w2, w2_res, 'Worker 2'))
]
t_worker1, t_worker2 = threads

print()
print('start threads.. worker 2 will fail with deadlock detected')
for thread in threads:
    thread.start()

# Wait for both deletes to finish (or fail) before moving on.
for thread in threads:
    thread.join()

if cleanup:
    print()
    print('cleaning up')
    # End both worker transactions first, then empty the table through the
    # master session.
    for worker_conn in (worker1_c, worker2_c):
        worker_conn.commit()

    m.execute("truncate jobs")
    master_c.commit()

# Close every connection, master first.
for conn in (master_c, worker1_c, worker2_c):
    conn.close()
