
Multiprocessing, file locking, SQLite and testing

Testing for concurrency problems is harder and takes more time, but you can't do without it.

30 March 2023 Updated 30 March 2023

I was working on a project with SQLAlchemy and PostgreSQL. For a few tables, I wanted to limit the number of rows per user, and did this by adding a PostgreSQL check function and trigger.

In manual testing everything appeared to work fine, but what if a user started multiple processes and added rows at exactly the same time? I added 'pg_advisory_xact_lock', but would this really work? Did I really understand the documentation?

In this post I show a universal TaskRunner class that can be used for testing simultaneous (concurrent) actions. As a test case, we use a SQLite database that we write to from separate processes.

We start all processes from a single process. In this case, we can use multiprocessing.Lock() to control access to SQLite. But I also implemented a file locker that can be used when we have fully independent processes.

As always I am running this on Ubuntu 22.04.

Starting actions at the same time

In our test setup we use multiprocessing.Event() to make all processes wait at the same line in the task code, one line before the 'critical action'. Then, when all processes have reached this point, we 'release' the processes and see what happens.

                         stop & release
                              |
                              v

task1  |--------------------->|-------->
task2      |----------------->|-------->
task3        |--------------->|-------->
                              |
taskN               |-------->|-------->

        --------------------------------------> t

In the TaskRunner class:

class TaskRunner:
    ...
    def run_parallel_tasks(self, parallel_tasks_count):
        ...
        self.mp_event = multiprocessing.Event()
        ...
        for task_no in range(parallel_tasks_count):
            p = multiprocessing.Process(target=self.func_task, args=(self, task_no))
            ...

        # release all waiting processes
        time.sleep(self.release_time)
        self.mp_event.set()
        ...

In our task function:

def task(task_runner, task_no):
    ...
    # all tasks will wait here 
    task_runner.mp_event.wait()

    # critical action
    ...
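
To try the wait-and-release pattern on its own, here is a minimal sketch outside the TaskRunner. The names worker() and run_released_together() are made up for this example, the Queue is only there to observe what happened, and the 'fork' start method is requested explicitly because I assume Linux.

```python
import multiprocessing
import time


def worker(event, queue, task_no):
    # every worker blocks here until the parent sets the event
    event.wait()
    # stand-in for the 'critical action'
    queue.put(task_no)


def run_released_together(task_count, release_time=0.2):
    # assumption: Linux, so the 'fork' start method is available
    ctx = multiprocessing.get_context('fork')
    event = ctx.Event()
    queue = ctx.Queue()
    procs = [
        ctx.Process(target=worker, args=(event, queue, task_no))
        for task_no in range(task_count)
    ]
    for p in procs:
        p.start()

    # give the workers time to reach event.wait()
    time.sleep(release_time)
    # no worker can have passed the event yet
    nothing_ran_early = queue.empty()

    # release all waiting workers at once
    event.set()
    finished = sorted(queue.get(timeout=10) for _ in procs)
    for p in procs:
        p.join()
    return nothing_ran_early, finished
```

Calling run_released_together(4) starts four processes, checks that none of them got past the event before it was set, and returns the task numbers that completed afterwards.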

Incrementing a SQLite table field

In our test, the tasks (processes) simultaneously try to increment a SQLite table field, 'counter', by:

  • reading the field value
  • incrementing it
  • updating the field

If we have 100 tasks, then the result in the table field must be 100. Any other value is wrong.
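
To see how a wrong value can come about, the sketch below replays the three steps by hand, in a single process, with two connections standing in for two tasks. The function name lost_update_demo() and the temporary database file are assumptions for this example only. Both 'tasks' read the counter before either of them writes it back.

```python
import os
import sqlite3
import tempfile


def lost_update_demo():
    with tempfile.TemporaryDirectory() as tmp_dir:
        db_path = os.path.join(tmp_dir, 'demo.db')

        # same table layout as in the test: one row, counter = 0
        setup = sqlite3.connect(db_path)
        setup.execute('CREATE TABLE tasks (counter INTEGER)')
        setup.execute('INSERT INTO tasks (counter) VALUES (0)')
        setup.commit()
        setup.close()

        # two connections play the role of two concurrent tasks
        conn_a = sqlite3.connect(db_path)
        conn_b = sqlite3.connect(db_path)

        # step 1: both tasks read the same value
        counter_a = conn_a.execute('SELECT counter FROM tasks').fetchall()[0][0]
        counter_b = conn_b.execute('SELECT counter FROM tasks').fetchall()[0][0]

        # steps 2 and 3: both increment their own copy and write it back
        conn_a.execute('UPDATE tasks SET counter=?', (counter_a + 1,))
        conn_a.commit()
        conn_b.execute('UPDATE tasks SET counter=?', (counter_b + 1,))
        conn_b.commit()

        final = conn_b.execute('SELECT counter FROM tasks').fetchall()[0][0]
        conn_a.close()
        conn_b.close()
        return final
```

Two increments were performed, but lost_update_demo() returns 1: the second write overwrites the first. This is the interleaving that the lock has to rule out.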

Locking

A task cannot reliably perform the increment operation without gaining exclusive access to SQLite. Here, we use a lock that is external to SQLite.

We can distinguish the following:

  1. The (concurrent) tasks are started by a single process
  2. The (concurrent) tasks are independent

In the first case, we can use multiprocessing.Lock() and share this lock between all our tasks. For testing purposes this is fine.

The second case is a more real-world scenario. We cannot use multiprocessing.Lock() here, but we can use Linux file locking. This is fast and reliable.

Locking - multiprocessing.Lock()

I want to use multiprocessing.Lock() as a context manager. Unfortunately, the built-in 'with lock:' form does not let us specify a timeout. This means we must write the context manager ourselves:

# multiprocessing locker context manager with timeout
class mp_locker:
    def __init__(
        self,
        mp_lock=None,
        timeout=10,
    ):
        self.mp_lock = mp_lock
        self.timeout = timeout

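    def __enter__(self):
        # acquire() returns False on timeout, it does not raise
        if not self.mp_lock.acquire(timeout=self.timeout):
            raise TimeoutError('acquire lock timeout')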

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.mp_lock.release()
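
One detail that is easy to miss: acquire() with a timeout does not raise when the timeout expires, it returns False. The sketch below shows this with a lock that is already held. The helper names acquire_or_raise() and demo_timeout() are made up for this example.

```python
import multiprocessing


def acquire_or_raise(mp_lock, timeout):
    # turn the False return value into an exception
    if not mp_lock.acquire(timeout=timeout):
        raise TimeoutError('acquire lock timeout after {} seconds'.format(timeout))


def demo_timeout():
    lock = multiprocessing.Lock()
    # the lock is free: this succeeds immediately
    acquire_or_raise(lock, timeout=1)
    try:
        # the lock is already held and is not recursive: this times out
        acquire_or_raise(lock, timeout=0.1)
        timed_out = False
    except TimeoutError:
        timed_out = True
    finally:
        # release the first acquisition
        lock.release()
    return timed_out
```

Without the check on the return value, a task would continue after the timeout as if it held the lock, and the later release() would fail.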

Locking - File locking

There are many examples on the internet on how to do this. Again I want to use this as a context manager. Here I only show the '__enter__()' method.

# file locker context manager
    ...
    def __enter__(self):
        pid = os.getpid()
        ts = time.time()
        # open once, outside the loop, so retries do not create new file objects
        self.lock_file_fo = open(self.lock_file, 'a')
        while True:
            if (time.time() - ts) > self.timeout:
                self.lock_file_fo.close()
                raise Exception('pid = {}: acquire lock timeout'.format(pid))
            try:
                fcntl.flock(self.lock_file_fo, fcntl.LOCK_EX | fcntl.LOCK_NB)
                break
            except BlockingIOError:
                # another process locked the file, keep trying
                time.sleep(self.wait_secs)
            # propagate other exceptions
We stay in the 'while-loop' until we acquire the lock or a timeout occurs.
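
For readers who want acquire and release in one place, here is a compact sketch of the same idea written with contextlib instead of a class. The names file_lock() and second_lock_times_out() are assumptions for this example, and it relies on Linux flock() behaviour: two separate open() calls on the same file conflict with each other, even inside one process.

```python
import contextlib
import fcntl
import os
import tempfile
import time


@contextlib.contextmanager
def file_lock(lock_file, timeout=10, wait_secs=0.01):
    deadline = time.monotonic() + timeout
    with open(lock_file, 'a') as lock_file_fo:
        while True:
            try:
                # non-blocking exclusive lock, raises BlockingIOError if taken
                fcntl.flock(lock_file_fo, fcntl.LOCK_EX | fcntl.LOCK_NB)
                break
            except BlockingIOError:
                if time.monotonic() > deadline:
                    raise TimeoutError(
                        'pid = {}: acquire lock timeout'.format(os.getpid()))
                time.sleep(wait_secs)
        try:
            yield
        finally:
            # release the lock; leaving the 'with open' block closes the file
            fcntl.flock(lock_file_fo, fcntl.LOCK_UN)


def second_lock_times_out():
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'lock_file')
        with file_lock(path):
            try:
                # same file, second open(): cannot get the lock
                with file_lock(path, timeout=0.1):
                    return False
            except TimeoutError:
                return True
```

second_lock_times_out() holds the lock and then tries to take it again through a second file object, which ends in a TimeoutError after 0.1 seconds.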

The TaskRunner class

The TaskRunner contains all logic to start multiple tasks (processes).

Functions:

  • before_tasks()
  • task()
  • after_tasks()
  • result_ok()
  • after_result()

Options:

  • Number of concurrent tasks.
  • Number of times to repeat.
  • Waiting tasks release-time (after start).
  • Logging level.
  • multiprocessing.Lock() locking, or file locking
  • Lock timeout.

Important: All your functions are called with the TaskRunner object as the first parameter. This means you have access to TaskRunner attributes and methods like:

  • get_lock()
  • get_logger()
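
As an illustration of such a function, here is a sketch of what a result_ok() for the counter test could look like. It assumes that after_tasks() returns the final counter value so that it arrives here as 'result'. The fake_runner object is only a stand-in so the function can be tried without starting any processes.

```python
from types import SimpleNamespace


def result_ok(task_runner, result):
    # assumption: 'result' is the final counter value returned by after_tasks()
    # every task increments once, so the counter must equal the task count
    return result == task_runner.parallel_tasks_count


# stand-in for a TaskRunner, only the attribute we need
fake_runner = SimpleNamespace(parallel_tasks_count=100)
print(result_ok(fake_runner, 100))
print(result_ok(fake_runner, 99))
```

The run() loop treats a falsy return value as a failure, logs the result as an error and stops repeating.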

The code

The code consists of the following parts:

  • TaskRunner class and support classes
  • Your task functions
  • TaskRunner instantiation with your parameters

When you run the code, the output is something like:

INFO     counter = 100 <- final value
INFO     ready in 2.0454471111297607 seconds

Here is the code in case you want to try yourself:

import fcntl
import logging
import multiprocessing
import os
import sys
import time

import sqlite3


class DummyLogger:
    def __getattr__(self, name):
        return lambda *args, **kwargs: None

# file locker context manager
class f_locker:
    def __init__(
        self,
        lock_file=None,
        timeout=10,
        logger=DummyLogger(),
        wait_secs=.01,
    ):
        self.lock_file = lock_file
        self.timeout = timeout
        self.logger = logger
        self.wait_secs = wait_secs
        # keep lock_file opened
        self.lock_file_fo = None

    def __enter__(self):
        pid = os.getpid()
        ts = time.time()
        # open once, outside the loop, so retries do not create new file objects
        self.lock_file_fo = open(self.lock_file, 'a')
        while True:
            self.logger.debug('pid = {}: trying to acquire lock ...'.format(pid))
            if (time.time() - ts) > self.timeout:
                self.lock_file_fo.close()
                raise Exception('pid = {}: acquire lock timeout'.format(pid))
            # keep trying until lock or timeout
            try:
                fcntl.flock(self.lock_file_fo, fcntl.LOCK_EX | fcntl.LOCK_NB)
                self.logger.debug('pid = {}: lock acquired'.format(pid))
                break
            except BlockingIOError:
                # another process locked the file, keep trying
                self.logger.debug('pid = {}: cannot acquire lock'.format(pid))
                time.sleep(self.wait_secs)
            # propagate other exceptions
        return True

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.logger.debug('exc_type = {}, exc_value = {}, exc_tb = {}'.format(exc_type, exc_value, exc_tb))
        pid = os.getpid()
        self.logger.debug('pid = {}: trying to release lock ...'.format(pid))
        fcntl.flock(self.lock_file_fo, fcntl.LOCK_UN)
        # close the lock file, we open it again on the next acquire
        self.lock_file_fo.close()
        self.logger.debug('pid = {}: lock released ...'.format(pid))


# multiprocessing locker context manager with timeout
class mp_locker:
    def __init__(
        self,
        mp_lock=None,
        timeout=10,
        logger=DummyLogger(),
    ):
        self.mp_lock = mp_lock
        self.timeout = timeout
        self.logger = logger

    def __enter__(self):
        self.pid = os.getpid()
        self.logger.debug('pid = {}: trying to acquire lock ...'.format(self.pid))
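        # acquire() returns False on timeout, it does not raise
        if not self.mp_lock.acquire(timeout=self.timeout):
            raise TimeoutError('pid = {}: acquire lock timeout'.format(self.pid))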
        self.logger.debug('pid = {}: lock acquired'.format(self.pid))

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.logger.debug('exc_type = {}, exc_value = {}, exc_tb = {}'.format(exc_type, exc_value, exc_tb))
        self.logger.debug('pid = {}: trying to release lock ...'.format(self.pid))
        self.mp_lock.release()
        self.logger.debug('pid = {}: lock released ...'.format(self.pid))


class TaskRunner:
    def __init__(
        self,
        loop_count=1,
        parallel_tasks_count=1,
        release_time=1.,
        # functions
        func_before_tasks=None,
        func_task=None,
        func_after_tasks=None,
        func_result_ok=None,
        func_after_result=None,
        # logging
        logger_level=logging.DEBUG,
        # locking
        lock_timeout=10,
        use_file_locking=False,
        lock_file='./lock_file',
        lock_wait_secs=.01,
    ):
        self.loop_count = loop_count
        self.parallel_tasks_count = parallel_tasks_count
        self.release_time = release_time
        # functions
        self.func_before_tasks = func_before_tasks
        self.func_task = func_task
        self.func_after_tasks = func_after_tasks
        self.func_result_ok = func_result_ok
        self.func_after_result = func_after_result
        # logging
        self.logger_level = logger_level
        # locking
        self.lock_timeout = lock_timeout
        self.use_file_locking = use_file_locking
        self.lock_file = lock_file
        self.lock_wait_secs = lock_wait_secs

    def get_logger(self, proc_name, logger_level=None):
        if logger_level is None:
            logger_level = self.logger_level
        logger = logging.getLogger(proc_name)
        logger.setLevel(logging.DEBUG)
        console_handler = logging.StreamHandler()
        console_logger_format = '%(asctime)s %(proc_name)-8.8s %(levelname)-8.8s [%(filename)-20s%(funcName)20s():%(lineno)03s] %(message)s'
        console_handler.setFormatter(logging.Formatter(console_logger_format))
        logger.setLevel(logger_level)
        logger.addHandler(console_handler)
        logger = logging.LoggerAdapter(logger, {'proc_name': proc_name})
        return logger

    def get_lock(self, timeout=None):
        timeout = timeout or self.lock_timeout
        if not self.use_file_locking:
            return mp_locker(self.mp_lock, timeout=timeout, logger=self.logger)
        return f_locker(self.lock_file, timeout=timeout, logger=self.logger, wait_secs=self.lock_wait_secs)

    def run_parallel_tasks(self, parallel_tasks_count):
        # before tasks
        if self.func_before_tasks:
            self.func_before_tasks(self)

        self.mp_lock = multiprocessing.Lock()
        self.mp_event = multiprocessing.Event()
        tasks = []
        for task_no in range(parallel_tasks_count):
            p = multiprocessing.Process(target=self.func_task, args=(self, task_no))
            p.start()
            tasks.append(p)

        # release waiting processes
        time.sleep(self.release_time)
        self.mp_event.set()

        # wait for all tasks to complete
        for p in tasks:
            p.join()

        # after tasks
        if self.func_after_tasks:
            return self.func_after_tasks(self)
        return None

    def run(
        self, 
        loop_count=None,
        parallel_tasks_count=None,
    ):
        self.logger = self.get_logger('main')
        if loop_count is not None:
            self.loop_count = loop_count
        if parallel_tasks_count is not None:
            self.parallel_tasks_count = parallel_tasks_count
    
        start_time = time.time()
        for loop_no in range(self.loop_count):
            self.logger.debug('loop_no = {}'.format(loop_no))

            result = self.run_parallel_tasks(self.parallel_tasks_count)
            if self.func_result_ok:
                if not self.func_result_ok(self, result):
                    self.logger.error('result = {}'.format(result))
                    break
                else:
                    self.logger.info('result ok')
                        
        if self.func_after_result:
            self.func_after_result(self)

        run_secs = time.time() - start_time
        self.logger.info('ready in {} seconds'.format(run_secs))

# ### YOUR CODE BELOW ### #

def before_tasks(task_runner):
    # create a table, insert row with counter = 0
    with sqlite3.connect('./test_tasks.db') as conn:
        cursor = conn.cursor()
        cursor.execute("""DROP TABLE IF EXISTS tasks""")
        cursor.execute("""CREATE TABLE tasks (counter INTEGER)""")
        cursor.execute("""INSERT INTO tasks (counter) VALUES (0)""")
        conn.commit()

def task(task_runner, task_no):
    logger = task_runner.get_logger('task' + str(task_no))
    pid = os.getpid()

    # wait for event
    logger.debug('pid = {} waiting for event at {}'.format(pid, time.time()))
    task_runner.mp_event.wait()

    # wait for lock
    lock = task_runner.get_lock()
    logger.debug('pid = {} waiting for lock at  {}'.format(pid, time.time()))
    with lock:
        # increment counter field
        with sqlite3.connect('./test_tasks.db', timeout=10) as conn:
            cursor = conn.cursor()
            counter = cursor.execute('SELECT counter FROM tasks').fetchone()[0]
            logger.debug('counter = {}'.format(counter))
            counter += 1
            cursor.execute("""UPDATE tasks SET counter=?""", (counter,))
            conn.commit()

def after_tasks(task_runner):
    conn = sqlite3.connect('./test_tasks.db')
    cursor = conn.cursor()
    counter = cursor.execute('SELECT counter FROM tasks').fetchone()[0]
    conn.close()
    task_runner.logger.info('counter = {} <- final value'.format(counter))
    # the return value is passed to result_ok() as 'result'
    return counter

def result_ok(task_runner, result):
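    # placeholder: return True when the result is as expected,
    # a falsy value makes run() log an error and stop looping
    return True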

def after_result(task_runner):
    pass

def main():
    tr = TaskRunner(
        # functions
        func_before_tasks=before_tasks,
        func_task=task,
        func_after_tasks=after_tasks,
        #func_result_ok=result_ok,
        func_after_result=after_result,
        # logging
        logger_level=logging.INFO,
        # locking
        use_file_locking=True,
    )
    tr.run(
        loop_count=1,
        parallel_tasks_count=100,
        #parallel_tasks_count=2,
    )

if __name__ == '__main__':
    main()

Summary

We wanted an easy way to test concurrent operations. In the past I used the Python package 'Locust' to test concurrency, see the post 'Using Locust to load test a FastAPI app with concurrent users'. This time I wanted to keep it small, flexible and extensible.
Besides that, I also wanted a file lock context manager that works across multiple independent processes. We implemented both, and the tests passed. Time to get back to my other projects.

Links / credits

Python - fcntl
https://docs.python.org/3/library/fcntl.html

Python - multiprocessing
https://docs.python.org/3/library/multiprocessing.html

Python - SQLite3
https://docs.python.org/3/library/sqlite3.html

Using Locust to load test a FastAPI app with concurrent users
https://www.peterspython.com/en/blog/using-locust-to-load-test-a-fastapi-app-with-concurrent-users
