import sys
from os import environ
from time import sleep

from mariadb import ConnectionPool

# Connection settings taken from the environment. Every value except the
# password has a fallback: a missing DB_PASSWORD raises KeyError, and a
# DB_PORT that is not an integer raises ValueError.
host, port, user, password, database = (
    environ.get("DB_HOST", "localhost"),
    int(environ.get("DB_PORT", "3306")),
    environ.get("DB_USER", "root"),
    environ["DB_PASSWORD"],
    environ.get("DB_NAME", ""),
)

POOL_SIZE = 3  # connections in the pool; the check draws exactly this many
# Idle period in seconds that the script outlasts (it sleeps one second
# longer). Presumably meant to match the server's wait_timeout; confirm.
WAIT_TIMEOUT = 10

# Create the pool under test. Every session the pool opens (or reopens) runs
# init_command, which lowers its wait_timeout to WAIT_TIMEOUT. The server then
# closes the idle pooled sessions during the sleep below. Without this, the
# test only exercises the reconnect path if the server's global wait_timeout
# is already <= WAIT_TIMEOUT (MariaDB's default is 28800 seconds).
pool = ConnectionPool(
    host=host,
    port=port,
    user=user,
    password=password,
    database=database,
    pool_name="Test duplicate connection",
    pool_size=POOL_SIZE,
    init_command=f"SET SESSION wait_timeout={WAIT_TIMEOUT}",
)

# Stay idle one second longer than the session wait_timeout, so the pooled
# connections have been dropped by the server when they are requested again.
sleep(WAIT_TIMEOUT + 1)

# Draw POOL_SIZE connections without returning any to the pool. Each one must
# be a distinct object; receiving one that was already handed out means the
# pool gave the same connection away twice. Failures exit with status 1.
connections = set()
for _ in range(POOL_SIZE):
    connection = pool.get_connection()
    # NOTE(review): some Connector/Python releases may return None instead of
    # raising when no pooled connection is available. Report that as its own
    # failure rather than letting a second None show up as a "duplicate".
    if connection is None:
        print("No connection available from pool")
        sys.exit(1)
    if connection in connections:
        print(f"Duplicate connection: {connection}")
        # sys.exit rather than the site-module exit(), which is intended for
        # the interactive shell and is missing when Python runs with -S.
        sys.exit(1)
    connections.add(connection)
