#!/usr/bin/python3

import re
import sys
import time

import mariadb

# Connection settings consumed by do_it() when it opens the database.
# NOTE(review): credentials are hard-coded in source, including the root
# password; consider loading them from the environment or an option file.
db_user, db_pass, db_sock = (
    'root',                         # login user
    'rick40',                       # login password
    '/var/lib/mysql/mysql.sock',    # passed to the driver as unix_socket
)

# =========================================================
#  mariadb.Connection wrapper adding `with` support that closes cursors and connection (memory leak workaround)
# =========================================================
class Connection( mariadb.Connection ):
    """
    mariadb.Connection subclass usable as a ``with`` statement context manager.

    Every cursor created through cursor() is remembered. When the ``with``
    block is left, all of those cursors are closed, then the connection itself.
    Cleanup is best-effort: close errors are ignored so that they never mask
    an exception raised inside the ``with`` body, and exceptions from the
    body are never suppressed.
    """
    def __init__( self, **kwargs ):
        # every cursor handed out by cursor(); closed again in __exit__
        self._cursors = []
        # most recently created cursor (kept for backward compatibility)
        self._cursor = None
        super().__init__( **kwargs )

    def __enter__( self ):
        return self

    def __exit__( self, *exc ):
        # Close each cursor in its own try block: a failure on one cursor must
        # neither skip the remaining cursors nor leave the connection open.
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit pass.
        for cur in self._cursors:
            try:
                cur.close()
            except Exception:
                pass
        self._cursors = []
        self._cursor = None

        # close the connection even if closing a cursor failed above
        try:
            self.close()
        except Exception:
            pass

        # do not suppress an exception raised inside the with block
        return False

    def cursor( self, **kwargs ):
        """ Create a cursor and remember it so __exit__ can close it. """
        cur = super().cursor( **kwargs )
        self._cursors.append( cur )
        self._cursor = cur
        return cur

# -----------------
# rest of test code
# -----------------
def log_sql( sql_statement, err, e=None ):
    """
    Log a failed SQL statement and its error to stderr.

    Double quotes in the statement are converted to single quotes. The first
    quoted literal that follows a PASSWORD keyword (matched case-insensitively,
    as SQL keywords are) is masked as ``*****``. Every occurrence of that text
    in the logged statement is masked, so the secret does not reach the log.

    :param sql_statement: SQL text that was executed
    :param err: short description of the failure
    :param e: optional exception; its text is appended after ': '
    """
    sql = sql_statement.replace('"', "'")

    # PASSWORD, then anything up to the next quote, then capture the literal
    # up to the closing quote (or the end of an unterminated literal)
    match = re.search(r"PASSWORD[^']*'([^']*)", sql, re.IGNORECASE)
    if match:
        secret = match.group(1)
        # never replace an empty secret: str.replace('', x) would insert x
        # between every character of the statement
        if secret:
            sql = sql.replace( secret, '*****' )

    sys.stderr.write("SQL [%s] %s%s\n" % (sql, err, ((': ' + str(e)) if e else '')))

def last_error( cur, deferr=0 ):
    """
    Get the error code reported for the last failed SQL statement on *cur*.

    Runs ``SHOW ERRORS`` and returns the ``Code`` of the first reported row
    as a negative value. Returns *deferr* when no error is reported, or
    when SHOW ERRORS itself fails (that failure is logged).

    Works with dictionary cursors (``row['Code']``) as well as tuple /
    named-tuple cursors (SHOW ERRORS columns are Level, Code, Message).
    """
    try:
        cur.execute('SHOW ERRORS')
        rc = cur.rowcount
    except Exception as e:
        log_sql( 'SHOW ERRORS', 'FAILED', e )
        return deferr

    # if no error returned, return default error code
    if rc == 0:
        return deferr

    # rowcount may be -1 (unknown), so also guard against an empty result
    error = cur.fetchone()
    if error is None:
        return deferr

    code = error['Code'] if isinstance( error, dict ) else error[1]
    return -int(code)

def execute_sql( cur, sql, seconds=0 ):
    """
    Execute an SQL statement with an optional server-side timeout.

    :param cur: cursor to execute the statement on
    :param sql: SQL statement text
    :param seconds: when > 0, the statement is prefixed with
        ``SET STATEMENT max_statement_time=<seconds> FOR`` so the server
        limits its execution time
    :returns:
      <0 = [failure] negated mariadb error code (-1 if it cannot be determined)
       0 = [success] sql returned a 0 result
      >0 = [success] number of rows returned in query
    """
    # error 1617 (WARN_NO_MASTER_INFO) should not be reported as an error
    no_master_info = 1617

    # run the statement as given unless a timeout was requested
    sql_statement = sql
    if seconds > 0:
        sql_statement = 'SET STATEMENT max_statement_time=%f FOR ' % seconds + sql

    try:
        cur.execute( sql_statement )
        nrows = cur.rowcount
    except mariadb.Error as e:
        # mariadb.Error is the driver's base class, so this also covers
        # OperationalError (typical for mariadb) and ProgrammingError.
        # The error code is kept by the server, so read it via SHOW ERRORS.
        rc = last_error( cur, -1 )
        if rc != -no_master_info:
            log_sql( sql_statement, 'Error Code %d' % abs(rc), e )
        return rc
    except Exception as e:
        # other exception
        log_sql( sql_statement, 'Unhandled Exception', e )
        return -1

    # return number of rows returned from SQL query
    return nrows

def do_it():
    """
    Run one check pass against the local database server.

    Connects, runs SHOW DATABASES and SHOW SLAVE STATUS, and reports the
    outcome as a ``(code, message)`` tuple:
      (0, None)  every step succeeded
      (1, ...)   SHOW DATABASES failed or returned no rows
      (2, ...)   SHOW SLAVE STATUS failed
      (3, ...)   a mariadb exception escaped
      (4, ...)   any other exception escaped
    """
    outcome = (0, None)
    try:
        conn_args = dict( host='localhost', user=db_user, password=db_pass,
                          unix_socket=db_sock, connect_timeout=5 )
        # Connection wrapper closes its cursors and itself on exit
        with Connection( **conn_args ) as conn:
            cursor = conn.cursor( dictionary=True, buffered=True )

            if execute_sql( cursor, "SHOW DATABASES;", 5 ) <= 0:
                return (1, 'SHOW DATABASES FAILED')

            slave_rc = execute_sql( cursor, "SHOW SLAVE STATUS;", 10 )
            if slave_rc < 0:
                return (2, 'SHOw SLAVE STATus FAILED')
            if slave_rc > 0:
                # replication status row is fetched but not inspected yet
                cursor.fetchone()
    except mariadb.Error as err:
        # mariadb.OperationalError is a subclass of mariadb.Error
        outcome = (3, str(err))
    except Exception as err:
        outcome = (4, str(err))
    return outcome

def main():
    """
    Call do_it() forever, sleeping two seconds between passes.

    Each pass that reports a non-zero code is written to stderr as
    ``RC=<code>: <message>``.
    """
    pause = 2.0
    while True:
        code, message = do_it()
        if code != 0:
            sys.stderr.write('RC=%d: %s\n' % (code, message))
        time.sleep(pause)

# start polling only when executed as a script, not when imported
if __name__ == '__main__':
    main()

