import mysql.connector as mariadb
import _thread as thread
import time as Time
import sys as Sys

def _open_connection(socket_path):
    """Open an autocommit connection to database ``d`` as ``usr`` over *socket_path*."""
    return mariadb.connect(user='usr', password='root', database='d',
                           unix_socket=socket_path, autocommit=True)


# Three connections to node 1 and two to node 2. The first failure stops
# the remaining connection attempts; the error is printed and the script
# carries on with whatever connections were opened.
try:
    node1_conn1 = _open_connection('/tmp/mysql.sock1')
    node1_conn2 = _open_connection('/tmp/mysql.sock1')
    node1_conn3 = _open_connection('/tmp/mysql.sock1')
    node2_conn1 = _open_connection('/tmp/mysql.sock2')
    node2_conn2 = _open_connection('/tmp/mysql.sock2')
except Exception as err:
    print(err.args)

def init():
    """Recreate table ``t1`` on node 1 and seed it with a single row.

    The table has a plain column ``a1`` plus two generated columns:
    ``a2 = a1 + 1`` (VIRTUAL) and ``a3 = a1 + 2`` (PERSISTENT). One row with
    ``a1 = 1`` is inserted. Any ``Exception`` (including a ``NameError`` when
    the module-level connection could not be opened) is printed, not raised.
    """
    print ('Initied')
    try:
        # Use the public cursor API instead of the private, undocumented
        # ``_execute_query`` connection method.
        cursor = node1_conn1.cursor()
        try:
            cursor.execute('drop table if exists t1 ')
            cursor.execute('create table t1 (a1 int , a2 int as (a1 +1 '
                           ') virtual , a3 int as (a1 + 2) persistent )')
            cursor.execute('INSERT INTO t1(a1) VALUES (1)')
        finally:
            cursor.close()
    except Exception as e:
        print (e.args)


# Thread 1 continuously alters the table while thread 2 inserts into it on the same node.
# Thread 3 compares data read from both nodes to detect any disparity between them;
# thread 4 periodically prints the row count seen on node 2.

def alter_table(connection):
    """Keep adding generated columns to ``t1`` through *connection*, forever.

    First sleeps 20 s so the table can accumulate rows. Then, on every pass:
    sleep for the current pause, read ``COUNT(*)`` from ``t1`` (the value is
    not used), and add column ``a<n>`` defined as ``a1 + (n - 1)``, starting
    at ``a4``. Even ``n`` produces a PERSISTENT column and odd ``n`` a
    VIRTUAL one. After each successful ALTER the pause grows by 1% and ``n``
    advances. On any ``Exception`` the error is printed and the same column
    is retried on the next pass.
    """
    Time.sleep(20.00)
    column_no = 4
    pause = 1.00
    cur = connection.cursor()
    while True:
        try:
            Time.sleep(pause)
            cur.execute('SELECT COUNT(*) FROM t1')
            cur.fetchone()[0]  # row count is read but not otherwise used
            kind = 'VIRTUAL' if column_no % 2 else 'PERSISTENT'
            statement = ('ALTER TABLE t1 ADD COLUMN a' + str(column_no)
                         + ' INT as' + '( a1 + ' + str(column_no - 1)
                         + ' ) ' + kind)
            print(statement)
            cur.execute(statement)
            # Lengthen the pause by 1% after every successful ALTER.
            pause += pause / 100
            column_no += 1
        except Exception as exc:
            print('Exception in ALTER')
            print(exc.args)

def insert_table(connection):
    """Insert the row ``(a1 = 1)`` into ``t1`` through *connection*, forever.

    There is no pause between inserts. Any ``Exception`` is printed and the
    loop keeps going.
    """
    cur = connection.cursor()
    statement = 'INSERT INTO t1(a1) VALUES (1)'
    while True:
        try:
            cur.execute(statement)
        except Exception as err:
            print('Exception in INSERT')
            print(err.args)

def cmp_result(conn1, conn2):
    """Continuously compare a row read via *conn1* with one read via *conn2*.

    After a 20 s delay, loops forever. Each pass takes the single row returned
    by ``SELECT * from t1 limit 1`` on *conn1* and the 100th row returned by
    ``SELECT * from t1 limit 100`` on *conn2*, then passes both to
    ``row_cmp``. Any ``Exception`` is printed and the loop continues; this
    includes the ``IndexError`` raised when *conn2* returns fewer than 100
    rows. ``SystemExit`` raised by ``row_cmp`` is not an ``Exception``, so
    it is not caught here and ends the loop.
    """
    Time.sleep(20)
    first_cursor = conn1.cursor(buffered=True)
    second_cursor = conn2.cursor(buffered=True)
    while True:
        try:
            first_cursor.execute('SELECT * from t1 limit 1')
            second_cursor.execute('SELECT * from t1 limit 100')
            node1_row = first_cursor.fetchone()
            node2_row = second_cursor.fetchall()[99]
            row_cmp(node1_row, node2_row)
        except Exception as exc:
            print('Exception in SELECT ')
            print(exc.args)

def _row_cmp_fail(row_node1, row_node2):
    """Print both rows being compared, then raise ``SystemExit("Error found")``.

    The rows are printed explicitly because ``row_cmp`` is called from
    ``cmp_result``, which the script starts with ``_thread.start_new_thread``.
    ``_thread`` silently ignores ``SystemExit``, so the exit message alone
    would never be displayed.
    """
    print('node1 row: ' + repr(row_node1))
    print('node2 row: ' + repr(row_node2))
    Sys.exit("Error found")


def row_cmp(row_node1, row_node2):
    """Check that two rows of ``t1`` hold the expected values 1, 2, 3, ...

    The column at 1-based position ``i`` is expected to equal ``i``: ``init``
    inserts ``a1 = 1``, and every generated column ``a<k>`` is ``a1 + k - 1``.

    *row_node1* may have exactly one more column than *row_node2* (an ALTER
    not yet visible on node 2), provided that extra last column holds the
    expected value. In that case only the shared columns are compared. Any
    other column-count difference, or any unexpected value, prints both rows
    and raises ``SystemExit("Error found")``. Sleeps 1 s after a successful
    comparison.

    NOTE(review): *row_node2* having one extra column (an ALTER landing
    between the two reads in ``cmp_result``) is also treated as an error.
    Confirm this is intended.
    """
    len1 = len(row_node1)
    if len(row_node1) != len(row_node2):
        print("Issue  No of columns NOT   Equal")
        if len(row_node1) == len(row_node2) + 1:
            if row_node1[len1 - 1] == len1:
                Time.sleep(1)
                print('No issues')
            else:
                print(row_node1[len1 - 1])
                _row_cmp_fail(row_node1, row_node2)
        else:
            _row_cmp_fail(row_node1, row_node2)
        # Only the columns both rows share are compared below.
        len1 = len(row_node2)
    print('No of columns = ' + str(len1))
    # Walk from the last shared column down to the first, as before.
    for position in range(len1, 0, -1):
        if row_node1[position - 1] != position or row_node2[position - 1] != position:
            print("Error found value not equal ")
            _row_cmp_fail(row_node1, row_node2)
    Time.sleep(1)


def num_of_rows(connection):
    """Print the number of rows in ``t1`` as seen through *connection*, forever.

    Waits 20 s first, then runs ``count(*)`` in a loop with no pause. Any
    ``Exception`` is printed and the loop continues.
    """
    Time.sleep(20)
    cur = connection.cursor(buffered=True)
    while True:
        try:
            cur.execute('select count(*) from t1')
            row = cur.fetchone()
            print('No of Rows' + str(row[0]))
        except Exception as err:
            print('Exception In Total no of Rows')
            print(err.args)

# Script driver: prepare the table, start the enabled workers, then keep the
# main thread alive while the _thread workers run.
try:
    init()
    # (worker, args) pairs; uncomment entries to enable more workers.
    _workers = [
        (alter_table, (node1_conn1,)),
        # (insert_table, (node1_conn2,)),
        # (cmp_result, (node1_conn3, node2_conn1)),
        # (num_of_rows, (node2_conn2,)),
    ]
    for _target, _args in _workers:
        thread.start_new_thread(_target, _args)
    Time.sleep(1000000)
except Exception as err:
    print('Something Bad in threads happened')
    print(err.args)
