channels: Failing test when calling django orm code wrapped in database_sync_to_async

I’m trying to test my Channels consumer, which calls ORM code wrapped in database_sync_to_async. The consumer looks something like this:

class MyConsumer(AsyncJsonWebsocketConsumer):
    async def connect(self):
        my_obj = await self.get_obj()
        ...other code

    @database_sync_to_async
    def get_obj(self):
        return MyModel.objects.get(filter_condition)

The test uses the @pytest.mark.asyncio and @pytest.mark.django_db decorators, i.e.:

@pytest.mark.asyncio
@pytest.mark.django_db
async def test_heartbeat():
    communicator = WebsocketCommunicator(MyConsumer, '<path>')
    await communicator.connect()
    await communicator.disconnect()

I’m using the following command to run the test:

./manage.py test xxx/tests.py::test_heartbeat

The test itself passes, however at the end of the test run I always get the following error:

=============================================== ERRORS ===============================================
________________________________ ERROR at teardown of test_heartbeat _________________________________

self = <django.db.backends.utils.CursorWrapper object at 0x7fbc7b8d80b8>, sql = 'DROP DATABASE "test"'
params = None
ignored_wrapper_args = (False, {'connection': <django.db.backends.postgresql.base.DatabaseWrapper object at 0x7fbc7b0481d0>, 'cursor': <django.db.backends.utils.CursorWrapper object at 0x7fbc7b8d80b8>})

    def _execute(self, sql, params, *ignored_wrapper_args):
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if params is None:
>               return self.cursor.execute(sql)
E               psycopg2.OperationalError: database "test" is being accessed by other users
E               DETAIL:  There is 1 other session using the database.

I can make the test failure go away by removing all references to database_sync_to_async in the consumer, but my understanding is that it is poor practice to run sync code (such as Django ORM calls) inside an async function.

Strangely, when I get the failing test, two tests run (one passes and one fails), but when I remove the references to database_sync_to_async only one test runs.

Here are the versions of my libraries:

django==2.0.6
daphne==2.2.0
asgiref==2.3.2
channels-redis==2.2.1
pytest==3.6.1
pytest-asyncio==0.8.0
pytest-django==3.1.2

Most upvoted comments

I don’t suggest this as a final fix, but we were able to bypass this on Django 3.1 + Channels 2.4 with the following fixture:

from unittest import mock

import pytest
from django.db import connections


@pytest.fixture
def fix_async_db():
    # Relies on private asgiref/Django internals (Local._get_context_id),
    # so it is tied to the current asgiref/Channels implementation.
    local = connections._connections
    ctx = local._get_context_id()
    for conn in connections.all():
        conn.inc_thread_sharing()
    conn = connections.all()[0]
    old = local._get_context_id
    try:
        with mock.patch.object(conn, 'close'):
            object.__setattr__(local, '_get_context_id', lambda: ctx)
            yield
    finally:
        object.__setattr__(local, '_get_context_id', old)

It is specific to the asgiref/Channels implementation, and it is neither thread-safe nor asyncio-safe.
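
A minimal usage sketch, assuming a Channels 2-style consumer (the consumer class and path below are placeholders, not part of the fixture above):

from channels.testing import WebsocketCommunicator


@pytest.mark.asyncio
@pytest.mark.django_db
async def test_consumer_with_fixture(fix_async_db):
    # The fixture pins database_sync_to_async work to this test's database
    # connection, so no extra sessions are left holding the test database open.
    communicator = WebsocketCommunicator(MyConsumer, "/ws/somepath/")
    connected, _ = await communicator.connect()
    assert connected
    await communicator.disconnect()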

@vijayshan thanks for your feedback. I did try transaction=True in my pytest.mark.django_db, but it didn’t seem to help. However, I did not try running your example code exactly, so it’s possible there’s something specific to my database that’s causing those errors.

One thing I noticed: with just a single test case it passes, but as soon as you add more test cases there’s a bigger chance some connections won’t be closed properly, resulting in the error message from my original post. Also, in case it helps, there’s another issue in the pytest-asyncio GitHub repo that describes exactly my problem: https://github.com/pytest-dev/pytest-asyncio/issues/82. Unfortunately there’s no solution there either.

Anyway, in my case I came up with a really hacky fix: calling raw SQL to close those connections just before the last test ends. I’m using Postgres, so my code looks like:

from django.db import connection


def close_db_connections():
    # Find every other session connected to the test database and terminate
    # it, so the final DROP DATABASE can succeed (Postgres-specific).
    sql = (
        "select pid from pg_stat_activity "
        "where datname = 'test' and pid <> pg_backend_pid()"
    )
    with connection.cursor() as cursor:
        cursor.execute(sql)
        for (pid,) in cursor.fetchall():
            cursor.execute('select pg_terminate_backend({})'.format(pid))

Kind of ugly, I know, but at least now my tests are passing all the time.
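
For example, a hypothetical final test that just runs the cleanup (the name is a placeholder; pytest collects tests in definition order, so putting it last in the last test module makes it run last):

# Placed at the end of the final test module.
@pytest.mark.django_db
def test_zzz_close_leftover_connections():
    close_db_connections()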

@cybergrind Thank you. This garbage is still broken, so this fix is salvation.

I’ve reproduced this problem in a more minimal form with this test repo: https://github.com/adamchainz/channels-bug-connection-closed

The issue only appeared in the client project I was working on when upgrading to asgiref 3.5.1+, which fixed thread-sensitive mode. This fix causes queries from a tested consumer to run on the main thread inside the atomic() from the TestCase.

As a workaround I’ve added this to the project’s test runner class to disable the calls to close_old_connections() in @database_sync_to_async:

        # Channels upstream has a bug where connections get closed during test
        # transactions:
        # https://github.com/django/channels/issues/1091
        # Workaround this by replacing the “maybe close connection” function
        # in channels.db with a no-op
        from channels import db as channels_db

        def no_op():
            pass

        channels_db.close_old_connections = no_op
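
For context, a minimal sketch of where that could live, assuming a custom DiscoverRunner subclass (the class and module names are placeholders of mine, not from this project):

from django.test.runner import DiscoverRunner


class NoCloseConnectionsTestRunner(DiscoverRunner):
    def setup_test_environment(self, **kwargs):
        super().setup_test_environment(**kwargs)

        # https://github.com/django/channels/issues/1091
        # Replace Channels' "maybe close connection" helper with a no-op so
        # that consumer queries don't close the test transaction's connection.
        from channels import db as channels_db

        channels_db.close_old_connections = lambda: None


# settings.py
# TEST_RUNNER = "myproject.testing.NoCloseConnectionsTestRunner"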

This issue actually seems to be a repeat of #462. Historically that was fixed in 9ae27cb835763075a81ba3823d20bd68f3ace46b, which made Channels disable its connection-closing logic during tests. Unfortunately this fix seems to have been lost when moving from the testing Client to the newer testing communicators (Client removed in 66bc2981f1c69f35aed9815e5b90d19d5c66387d).

I think a fix should look something like the previous methodology, removing the calls to close_old_connections() during tests.

Out of curiosity: I see that nobody here has mentioned CONN_MAX_AGE. It occurs to me that if one sets settings.DATABASES['default']['CONN_MAX_AGE']=0, the error in this bug should never happen (because @database_sync_to_async should close its connection each time it’s called). CONN_MAX_AGE=0 may be a viable approach for some people.
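
A minimal sketch of that setting (the engine and database name are placeholders):

# settings.py
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.postgresql',
        'NAME': 'myproject',
        # 0 closes each connection when the request (or database_sync_to_async
        # call) that opened it finishes, so worker threads don't keep sessions
        # open against the test database.
        'CONN_MAX_AGE': 0,
    }
}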

@tomek-rej Have you tried setting the transaction parameter on the pytest.mark.django_db decorator to true?

This is a sample test:

@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_my_consumer_normal():
    user = User.objects.get(username="vijayshan")
    assert user is not None
    user1 = User.objects.get(username="shan")
    assert user1 is not None

    communicator = create_communicator_for_url_and_user("/ws/chat/room1/vijayshan/", user)
    communicator1 = create_communicator_for_url_and_user("/ws/chat/room1/vijayshan1/", user1)
    connected, subprotocol = await communicator.connect()
    connected1, subprotocol = await communicator1.connect()
    assert connected is True
    assert connected1 is True
    room = Room.objects.get(name="room1")
    assert room is not None

    # Test sending text
    message_sent = "hello"
    await communicator.send_to(text_data=json.dumps({"message": message_sent}))
    await confirm_response_for_communicator(communicator, "vijayshan: hello")
    await confirm_response_for_communicator(communicator1, "vijayshan: hello")

    msg = Message.objects.get(message=message_sent)
    assert msg.special_message == False
    assert msg.message == message_sent
    assert msg.room_id == room.id
    assert msg.user_id == user.id

    await communicator.disconnect()
    await communicator1.disconnect()

and this is the consumer:

class ChatConsumer(AsyncWebsocketConsumer):
    """
    ChatConsumer: the chat consumer class is responsible for all chat interactions.
    """

    async def connect(self):
        """
        Handles the initial connection. Captures the initial room name and
        generates a room group name. If the scope has a user name then the
        connection is accepted, else it is rejected. A unique id for the user
        is also generated.
        """
        self.requires_history = True
        self.room_name = self.scope['url_route']['kwargs']['room_name']

        self.room_group_name = 'chat_%s' % self.room_name

        if "user" in self.scope.keys() and self.scope['user'] is not None:
            self.user_name = self.scope['user'].username

            room, created = await self.retrieve_room_name(self.room_name)
            self.room = room
            self.user = self.scope['user']
            self._id = str(uuid.uuid4())
            # Join room group
            await self.channel_layer.group_add(
                self.room_group_name,
                self.channel_name
            )

            await self.accept()
        else:
            await self.close()

    @database_sync_to_async
    def retrieve_room_name(self, room_name):
        """
        Retrieves the room from the db if it exists, else creates one.
        :param room_name: the room name that has been requested
        :return: the room from the db and whether it was newly created
        """
        room, created = Room.objects.get_or_create(name=room_name)
        return room, created

    def add_message(self, message, is_special=False):
        """
        Adds a message to the database.
        :param message: the message text to be added
        :param is_special: whether this is a special message
        :return: the message object that was just added
        """
        mess = Message(message=message, room=self.room, user=self.user, special_message=is_special)
        mess.save()
        return mess

    async def disconnect(self, close_code):
        """
        Leave the room.
        :param close_code:
        :return:
        """
        await self.channel_layer.group_discard(
            self.room_group_name,
            self.channel_name
        )

    # Receive message from WebSocket
    async def receive(self, text_data=None, bytes_data=None):
        text_data_json = json.loads(text_data)
        message = text_data_json["message"]
        if "special_message" in message:
            # Send message to room group
            await self.handle_special_message(message)
        else:
            await self.handle_regular_message(message)

    async def handle_regular_message(self, message):
        message_obj = await database_sync_to_async(self.add_message)(
            message=message
        )
        message = "{}: {}".format(self.user_name, message_obj.message)
        # Send message to room group
        await self.channel_layer.group_send(
            self.room_group_name,
            {
                'type': 'chat_message',
                'message': message,
                'is_special': False
            }
        )

    async def handle_special_message(self, message):
        message_obj = await database_sync_to_async(self.add_message)(
            message=message.replace("special_message ", ""),
            is_special=True
        )
        message = "{}: {}".format(self.user_name, message_obj.message)
        await self.channel_layer.group_send(
            self.room_group_name,
            {
                'type': 'special_message',
                'message': message,
                'id': self._id,
                'is_special': True
            }
        )

    # Receive message from room group
    async def chat_message(self, event):
        message = event['message']
        # Send message to WebSocket
        await self.send(text_data=json.dumps({
            'type': 'chat_message',
            'message': message,
        }))

    async def special_message(self, event):
        message = event['message']
        id = event['id']
        # Send message to socket
        if self._id != id:
            await self.send(text_data=json.dumps({
                'type': 'chat_message',
                'message': message,
            }))

My tests are passing.

I have transitioned our project to a different approach for unit tests: use the main-thread database connection. This makes tests faster since they don’t need to disconnect/reconnect (assuming you’ve configured Django to persist connections – which is a valid assumption, because if you didn’t you wouldn’t be seeing this error, right?).

This means … er … replacing @database_sync_to_async. I have no compunction about doing this: I don’t believe database connections should run on the default thread pool. (This lets you configure the number of database connections your Django app maintains. That’s a good thing.)

In cjworkbench/sync.py, we have:

import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
import contextvars
import functools
from channels.db import DatabaseSyncToAsync
from django.conf import settings


class OurDatabaseSyncToAsync(DatabaseSyncToAsync):
    """
    SyncToAsync on a special, database-only threadpool.
    Each thread has zero (on startup) or one (forever) database connection,
    stored in thread-local `django.db.connections[DEFAULT_DB_ALIAS]`.
    
    There is no way to close the threads' connections.

    This is how Channels' database_sync_to_async _should_ be implemented.
    We don't want ASGI_THREADS-many database connections because they thrash
    the database. Fewer connections means higher throughput. (We don't have any
    long-living SQL transactions; they'd change this calculus.)
    """

    executor = ThreadPoolExecutor(
        max_workers=settings.N_SYNC_DATABASE_CONNECTIONS,
        thread_name_prefix='our-database-sync-to-async-',
    )

    # override
    async def __call__(self, *args, **kwargs):
        # re-implementation of async_to_sync
        loop = asyncio.get_event_loop()
        context = contextvars.copy_context()
        child = functools.partial(self.func, *args, **kwargs)

        future = loop.run_in_executor(
            self.executor,
            functools.partial(
                self.thread_handler,
                loop,
                self.get_current_task(),
                context.run,
                child
            ),
        )
        return await asyncio.wait_for(future, timeout=None)


# The class is TitleCased, but we want to encourage use as a callable/decorator
database_sync_to_async = OurDatabaseSyncToAsync
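
(N_SYNC_DATABASE_CONNECTIONS above is a project-specific setting.) Call sites then import the decorator from this module; a usage sketch, with a placeholder model and function:

from cjworkbench.sync import database_sync_to_async


@database_sync_to_async
def load_workflow(pk):
    # Runs on the dedicated database thread pool rather than the default
    # ASGI thread pool.
    return Workflow.objects.get(pk=pk)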

In our tests/utils.py file, we have:

import asyncio
from concurrent.futures import ThreadPoolExecutor

from django.db import connections
from django.test import SimpleTestCase

from cjworkbench.sync import OurDatabaseSyncToAsync

# Connect to the database, on the main thread, and remember that connection
main_thread_connections = {name: connections[name] for name in connections}


def _inherit_main_thread_connections():
    for name in main_thread_connections:
        connections[name] = main_thread_connections[name]
        connections[name].allow_thread_sharing = True


class DbTestCase(SimpleTestCase):
    allow_database_queries = True

    # run_with_async_db() tasks all share a single database connection. To
    # avoid concurrency issues, run them all in a single thread.
    #
    # Assumes DB connections may be passed between threads. (Only one thread
    # will make DB calls at a time.)
    async_executor = ThreadPoolExecutor(
        max_workers=1,
        thread_name_prefix='run_with_async_db_thread',
        initializer=_inherit_main_thread_connections
    )

    # Don't bother clearing data in tearDown(). The next test that needs the
    # database will be running setUp() anyway, so extra clearing will only cost
    # time.

    def run_with_async_db(self, task):
        """
        Runs async tasks, using the main thread's database connection.
        See
        https://github.com/django/channels/issues/1091#issuecomment-436067763.
        """
        # We'll execute with a 1-worker thread pool, shared between tests. We
        # need to limit to 1 worker, because all workers share the same
        # database connection.
        #
        # This hack is just for unit tests: the test suite will end with a
        # "delete the entire database" call, and we want it to succeed; that
        # means there need to be no other connections using the database.
        old_loop = asyncio.get_event_loop()
        old_executor = OurDatabaseSyncToAsync.executor
        asyncio.set_event_loop(None)
        try:
            OurDatabaseSyncToAsync.executor = self.async_executor
            return asyncio.run(task)
        finally:
            OurDatabaseSyncToAsync.executor = old_executor
            asyncio.set_event_loop(old_loop)

Now:

  • Every place where you use @database_sync_to_async, import it from your own sync module – not from channels.
  • When testing async code, extend DbTestCase and run self.run_with_async_db(async_method(*args, **kwargs)); see the sketch below. The calling convention mimics asyncio.run(task(...)), not async_to_sync(task)(...).
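
A sketch of what such a test might look like (the consumer and URL are placeholders):

from channels.testing import WebsocketCommunicator


class MyConsumerTests(DbTestCase):
    def test_connect(self):
        async def scenario():
            communicator = WebsocketCommunicator(MyConsumer, "/ws/somepath/")
            connected, _ = await communicator.connect()
            assert connected
            await communicator.disconnect()

        # Database work triggered inside scenario() runs on the single shared
        # worker thread, reusing the main thread's connection.
        self.run_with_async_db(scenario())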

In my opinion, the separate-thread-pool approach should be part of Channels proper. I’m less certain about how Channels proper should address unit testing.

Here’s my ugly hack:

    def setUp(self):
        super().setUp()

        # We'll execute with a 1-worker thread pool. That's because Django
        # database methods will spin up new connections and never close them.
        # (@database_sync_to_async -- which execute uses -- only closes _old_
        # connections, not valid ones.)
        #
        # This hack is just for unit tests: we need to close all connections
        # before the test ends, so we can delete the entire database when tests
        # finish. We'll schedule the "close-connection" operation on the same
        # thread as @database_sync_to_async's blocking code ran on. That way,
        # it'll close the connection @database_sync_to_async was using.
        self._old_loop = asyncio.get_event_loop()
        self.loop = asyncio.new_event_loop()
        self.loop.set_default_executor(ThreadPoolExecutor(1))
        asyncio.set_event_loop(self.loop)

    # Be careful, in these tests, not to run database queries in async blocks.
    def tearDown(self):
        def close_thread_connection():
            # Close the connection that was created by @database_sync_to_async.
            # Assumes we're running in the same thread that ran the database
            # stuff.
            django.db.connections.close_all()

        self.loop.run_in_executor(None, close_thread_connection)

        asyncio.set_event_loop(self._old_loop)

        super().tearDown()

… and then instead of async_to_sync(func)(params), I run self.loop.run_until_complete(func(params)).
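
For example, where a synchronous test would otherwise call async_to_sync(do_work)(arg) (do_work here is a placeholder async function that uses @database_sync_to_async internally):

    def test_do_work(self):
        # The 1-worker executor set in setUp() keeps do_work's database access
        # on a single thread, which tearDown() then closes.
        result = self.loop.run_until_complete(do_work(arg))
        self.assertIsNotNone(result)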

This sort of test case should be the exception, not the norm. You can usually structure your code so that each individual test exercises either scheduling or database queries, but not both.