package org.mariadb.jdbc;

import static org.junit.Assert.*;

import java.io.UnsupportedEncodingException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

public class MaxAllowedPacketTest {

    private int previousMaxAllowedPacket;

    @Before
    public void setUp() throws Exception {
        previousMaxAllowedPacket = getMaxAllowedPacket();
    }
    
    @After
    public void tearDown() throws SQLException {
        setMaxAllowedPacket(previousMaxAllowedPacket);
    }
    
    private Connection getConnection() throws SQLException {
        return DriverManager.getConnection("jdbc:mariadb://localhost:3306/test?user=root");
    }
    
    private void setMaxAllowedPacket(int maxAllowedPacket) throws SQLException {
        Connection connection = getConnection();
        Statement statement = connection.createStatement();
        statement.execute("set GLOBAL max_allowed_packet = " + maxAllowedPacket);
        statement.close();
        connection.close();
    }
    
    private int getMaxAllowedPacket() throws SQLException {
        int maxAllowedPacket = -1;
        
        Connection connection = getConnection();
        Statement statement = connection.createStatement();
        ResultSet rs = statement.executeQuery("select @@max_allowed_packet");
        if (rs.next()) {
            maxAllowedPacket = rs.getInt(1);
        }
        
        rs.close();
        statement.close();
        connection.close();
        
        return maxAllowedPacket;
    }
    
    @Test
    public void maxAllowedPackedExceptionIsPrettyTest() throws SQLException, UnsupportedEncodingException {
        
        int maxAllowedPacket = 1024 * 1024;
        
        setMaxAllowedPacket(maxAllowedPacket);
        
        String rowData = "(null, 'this is a dummy row values')";
        StringBuilder sb = new StringBuilder();
        
        //Create a SQL packet bigger than maxAllowedPacket
        int rowsToWrite = (maxAllowedPacket / rowData.getBytes("UTF-8").length) + 1;
        for (int row = 1;  row <= rowsToWrite; row++) {
            if (row >= 2) {
                sb.append(", ");
            }

            sb.append(rowData);
        }
        
        String sql = "INSERT INTO dummy_table_that_does_not_exist (id, name) VALUES " + sb.toString();
        
        Connection connection = getConnection();
        Statement statement = connection.createStatement();

        try {
            statement.executeUpdate(sql); //Will throw Exception because of packet size too large
            fail("The previous statement should throw an SQLException");
        } catch (SQLException e) {
            System.out.println(e.getMessage());
            //Make sure the error message i "pretty"
            assertTrue(e.getMessage().contains("max_allowed_packet"));
        }
        catch (Exception e) {
            fail("The previous statement should throw an SQLException not a general Exception");
        }
        finally {
            connection.close();    
        }
    }

}
