SQLAlchemy: auto-generated IDs written to the wrong ORM objects, causing data corruption with mssql+pyodbc

Describe the bug

It appears that database auto-generated IDs are written back to the wrong in-code ORM objects under some specific conditions. This leads to data corruption when those IDs (on the code side) are used to create new records that reference the earlier ones. In the example code this seems to affect the Result ORM objects. Each one is assigned a database ID, but in a different order from the order in which the records were written to the database, which gives the ORM object the wrong ID. This then leads to incorrect ResultDatum records being written, because the result_id is taken from the Result ORM object.

The issue was first noticed in a production system, and it has taken some effort to reduce it to the example given. The key elements seem to be:
- we are using mssql+pyodbc (I don’t see the issue with SQLite);
- SQLAlchemy 2.0 (it does not occur with 1.4.47);
- a sufficient number of records is written to the database (the example uses 10,000; it does not occur with 100);
- the class definition has two orm.relationships defined using different foreign keys (I have no idea why that’s important, but I couldn’t reproduce the issue without it).

Optional link from https://docs.sqlalchemy.org which documents the behavior that is expected

No response

SQLAlchemy Version in Use

2.0.8

DBAPI (i.e. the database driver)

mssql+pyodbc (ODBC Driver 17 for SQL Server)

Database Vendor and Major Version

MSSQL 15.0.4298.1-1 amd64

Python Version

Python 3.8.10 (python3/focal,now 3.8.2-0ubuntu2 amd64)

Operating system

Linux (Ubuntu 20.04.5 LTS)

To Reproduce

from contextlib import contextmanager
import logging
import datetime
import random
import sqlalchemy as sql
import sqlalchemy.orm


# Module-level logger; it is not referenced anywhere else in the script.
log = logging.getLogger(__name__)

# Declarative base shared by every ORM model below; Base.metadata collects
# their tables so initialise() can create them.
Base = sql.orm.declarative_base()


class Trial(Base):
    """A single trial: the parent of the Datum, Result and ResultDatum rows.

    The script only supplies ``trial_dt``. ``trial_id`` is a single Integer
    primary key, so the database generates it on INSERT (an IDENTITY column
    on SQL Server).
    """

    __tablename__ = "trial"

    trial_id = sql.Column(sql.Integer)
    trial_dt = sql.Column(sql.DateTime(timezone=False))

    __table_args__ = (
        # Named primary key. mssql_clustered makes it a CLUSTERED index on
        # SQL Server; other dialects ignore this option.
        sql.PrimaryKeyConstraint(
            trial_id, name="trial__pk", mssql_clustered=True
        ),
    )


class Datum(Base):
    """One data point of a trial.

    ``datum_num`` is the point's index in the trial's ``run`` dict and
    ``datum_val`` is its random value; run_test sets both. ``datum_id`` is
    left for the database to generate.
    """

    __tablename__ = "datum"

    datum_id = sql.Column(sql.Integer)
    trial_id = sql.Column(sql.ForeignKey(Trial.trial_id))
    datum_num = sql.Column(sql.Integer)
    datum_val = sql.Column(sql.Float)

    # Many-to-one link to the owning Trial; the ORM fills trial_id from it
    # at flush time.
    trial = sql.orm.relationship(Trial)

    __table_args__ = (
        sql.PrimaryKeyConstraint(
            datum_id, name="datum__pk", mssql_clustered=True
        ),
    )


class Result(Base):
    """One ranked pairwise result between two Datum rows of the same trial.

    Fields:
    - ``lft_datum`` / ``rgt_datum`` reference the two data points the value
      was computed from.
    - ``lft_datum_num`` / ``rgt_datum_num`` repeat their indices.
    - ``result_val`` holds the computed value.
    - ``rank_val`` is its position in ascending value order (see run_trial
      and run_test).

    ``result_id`` is left for the database to generate. According to the
    report, it is this generated id that ends up on the wrong Result
    objects.
    """

    __tablename__ = "result"

    result_id = sql.Column(sql.Integer)
    trial_id = sql.Column(sql.ForeignKey(Trial.trial_id))
    # Two separate foreign keys to the same datum table.
    lft_datum_id = sql.Column(sql.ForeignKey(Datum.datum_id))
    rgt_datum_id = sql.Column(sql.ForeignKey(Datum.datum_id))
    lft_datum_num = sql.Column(sql.Integer)
    rgt_datum_num = sql.Column(sql.Integer)
    result_val = sql.Column(sql.Float)
    rank_val = sql.Column(sql.Integer)

    trial = sql.orm.relationship(Trial)
    # There are two foreign-key paths to Datum, so each relationship must name
    # its column via foreign_keys= to pick the join condition. The reporter
    # could not reproduce the id mix-up without these two relationships.
    lft_datum = sql.orm.relationship(Datum, foreign_keys=[lft_datum_id])
    rgt_datum = sql.orm.relationship(Datum, foreign_keys=[rgt_datum_id])

    __table_args__ = (
        sql.PrimaryKeyConstraint(
            result_id, name="result__pk", mssql_clustered=True
        ),
    )


class ResultDatum(Base):
    """Per-Result copy of its datum references, used as a consistency check.

    Each row is created alongside a Result and gets the same ``lft_datum``
    and ``rgt_datum``. Its ``result_id`` is filled from that Result's key
    through the ``result`` relationship. If the Result was given the wrong
    generated id, this row's datum ids no longer match those of the Result
    row it points at, and get_test_qry detects the mismatch.
    """

    __tablename__ = "result_datum"

    result_id = sql.Column(sql.ForeignKey(Result.result_id))
    lft_datum_id = sql.Column(sql.ForeignKey(Datum.datum_id))
    rgt_datum_id = sql.Column(sql.ForeignKey(Datum.datum_id))
    trial_id = sql.Column(sql.ForeignKey(Trial.trial_id))
    result_val = sql.Column(sql.Float)

    trial = sql.orm.relationship(Trial)
    # Two foreign-key paths to Datum, so each relationship names its column.
    lft_datum = sql.orm.relationship(Datum, foreign_keys=[lft_datum_id])
    rgt_datum = sql.orm.relationship(Datum, foreign_keys=[rgt_datum_id])
    result = sql.orm.relationship(Result)

    __table_args__ = (
        # result_id is both the foreign key and the only primary-key column,
        # so there is at most one ResultDatum per Result.
        sql.PrimaryKeyConstraint(
            result_id, name="result_datum__pk", mssql_clustered=True
        ),
    )


def get_engine():
    """Build a new Engine for the test database.

    Every call creates a fresh Engine with its own connection pool. Edit the
    connection details below to match the local SQL Server setup.
    """
    # To try SQLite instead (where the issue was not observed), use the
    # drivername "sqlite" with database "sqla_test_01.db".
    # NOTE(review): passing use_insertmanyvalues=False to create_engine may
    # work around the id mix-up on SQLAlchemy 2.0. Verify before relying on it.
    connection_url = sql.engine.URL.create(
        drivername="mssql+pyodbc",
        username="test-srv",
        password="your_password",
        host="localhost",
        database="sqla_test_01",
        query={"driver": "ODBC Driver 17 for SQL Server"},
    )
    engine = sql.create_engine(connection_url)
    return engine


def initialise(engine=None):
    """Create every table declared on ``Base``, skipping existing ones.

    Args:
        engine: optional Engine to create the schema with. The caller keeps
            ownership and it is not disposed here. When omitted, a temporary
            Engine from get_engine() is used and disposed afterwards, so its
            connection pool does not leak.
    """
    if engine is not None:
        Base.metadata.create_all(engine)
        return

    owned_engine = get_engine()
    try:
        Base.metadata.create_all(owned_engine)
    finally:
        # The original never released this engine, which left its pooled
        # connection open for the rest of the process.
        owned_engine.dispose()


def run_trial(size, rng=random):
    """Generate one trial's random data and its ranked pairwise results.

    Args:
        size: number of data points. The keys of the returned ``run`` dict
            are ``0 .. size - 1``.
        rng: source of randomness exposing ``random()``. Defaults to the
            ``random`` module; pass a seeded ``random.Random`` to get
            reproducible data.

    Returns:
        ``(run, result)``.

        ``run`` maps each data-point index to the value drawn from
        ``rng.random()``.

        ``result`` holds one ``(rank, value, (n, m))`` entry for every
        ordered pair ``(n, m)`` of indices, including ``n == m``, where
        ``value = run[n] + run[n] * run[m]``. Entries are sorted by
        ascending ``value``, and ties keep generation order. ``rank`` is
        each entry's position in that list.
    """
    run = {n: rng.random() for n in range(size)}
    # sorted() is stable, so equal values keep their (n, m) generation order.
    pairs = sorted(
        ((run[n] + run[n] * run[m], (n, m)) for n in run for m in run),
        key=lambda item: item[0],
    )
    result = [(rank, value, pair) for rank, (value, pair) in enumerate(pairs)]
    return run, result


def run_test(num, size, session):
    """Stage ``num`` trials of ``size`` data points each and flush them.

    Each trial builds one Trial, ``size`` Datum objects, ``size * size``
    Result objects and one ResultDatum per Result. All of them are linked
    through relationships rather than explicit ids, so during flush the ORM
    has to copy the database-generated primary keys into the dependent rows.

    Args:
        num: number of trials to create.
        size: number of data points per trial (see run_trial).
        session: Session the objects are added to. It is flushed once per
            trial but never committed here.
    """
    # The loop counter is unused. The original named it ``n`` and also
    # reused ``num`` as a comprehension variable, shadowing the parameter.
    for _ in range(num):

        run, result = run_trial(size)

        trial = Trial(trial_dt=datetime.datetime.now())
        datum_idx = {
            datum_num: Datum(
                trial=trial, datum_num=datum_num, datum_val=run[datum_num]
            )
            for datum_num in run
        }
        result_idx = {
            rank: Result(
                trial=trial,
                lft_datum=datum_idx[lft],
                rgt_datum=datum_idx[rgt],
                lft_datum_num=lft,
                rgt_datum_num=rgt,
                result_val=val,
                rank_val=rank,
            )
            for rank, val, (lft, rgt) in result
        }

        session.add(trial)
        session.add_all(datum_idx.values())
        session.add_all(result_idx.values())
        # Each ResultDatum repeats the datum references of its Result. If a
        # Result receives the wrong generated result_id, the joined rows
        # disagree and get_test_qry() counts them.
        session.add_all(
            ResultDatum(
                trial=trial,
                lft_datum=datum_idx[lft],
                rgt_datum=datum_idx[rgt],
                result=result_idx[rank],
                result_val=val,
            )
            for rank, val, (lft, rgt) in result
        )

        session.flush()


def get_test_qry(session):
    """Return a query over ResultDatum rows that disagree with their Result.

    A row counts as bad when its left or right datum id differs from the one
    stored on the Result it joins to through ``result_id``.
    """
    mismatch = sql.or_(
        Result.lft_datum_id != ResultDatum.lft_datum_id,
        Result.rgt_datum_id != ResultDatum.rgt_datum_id,
    )
    query = session.query(ResultDatum)
    query = query.join(Result)
    return query.filter(mismatch)


def process(num=1, size=100):
    """Create the schema, write ``num`` trials and report corrupted rows.

    Prints how many ResultDatum rows disagree with the Result row their
    ``result_id`` points at. Any non-zero count means generated primary keys
    were attached to the wrong ORM objects.

    Args:
        num: number of trials to write.
        size: data points per trial. Each trial writes ``size * size``
            Result rows.
    """
    initialise()
    engine = get_engine()
    try:
        # The context manager closes the session and returns its connection
        # to the pool, even if the run or the check raises.
        with sql.orm.Session(engine) as session:
            run_test(num, size, session)
            session.commit()

            num_bad = get_test_qry(session).count()
    finally:
        # get_engine() creates a new engine and pool on every call, so
        # release this one explicitly once the run is done.
        engine.dispose()

    print("There were {0} bad records.".format(num_bad))

Error

# Copy the complete stack trace and error message here, including SQL log output if applicable.

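No exceptions occur. The corruption is silent and is only detected by the consistency query run at the end of process().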

Additional context

You will need a test database on a SQL Server instance (the code refers to sqla_test_01) and a user that can connect to it (the code refers to ‘test-srv’). Adjust get_engine() to match your setup. Then run process() in a Python console; the only dependencies are the Python standard library and SQLAlchemy 2.0.8. With its defaults (num=1, size=100), process() writes 100 × 100 = 10,000 Result records and then prints the line “There were {0} bad records.”. If the number printed is anything other than 0, some corruption has occurred.

About this issue

  • Original URL
  • State: closed
  • Created a year ago
  • Comments: 66 (45 by maintainers)

Commits related to this issue

Most upvoted comments

There seems to be no reply to any of the recent posts, so I’m not sure if it’s worth it.