summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/provision.py
blob: a87bb932066500ac08b51cdaf94004b9799d21ff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# dialects/postgresql/provision.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

import time

from ... import exc
from ... import inspect
from ... import text
from ...testing import warn_test_suite
from ...testing.provision import create_db
from ...testing.provision import drop_all_schema_objects_post_tables
from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import log
from ...testing.provision import post_configure_engine
from ...testing.provision import prepare_for_drop_tables
from ...testing.provision import set_default_schema_on_connection
from ...testing.provision import temp_table_keyword_args
from ...testing.provision import upsert


@create_db.for_db("postgresql")
def _pg_create_db(cfg, eng, ident):
    """Create database ``ident`` as a copy of a template database.

    The template is taken from ``cfg.options.postgresql_templatedb``; when
    that option is empty, the database the engine is currently connected
    to is used as the template instead.

    ``CREATE DATABASE`` is attempted up to three times.  An
    ``OperationalError`` on the third attempt is re-raised; on earlier
    attempts the statement is retried, sleeping for half a second first
    when the error message indicates the template is "accessed by other
    users".  Any other exception type propagates immediately.

    :param cfg: configuration object exposing ``options``.
    :param eng: engine used to issue the statements in AUTOCOMMIT mode.
    :param ident: name of the database to create.
    """
    max_attempts = 3
    template_db = cfg.options.postgresql_templatedb

    with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn:
        if not template_db:
            template_db = conn.exec_driver_sql(
                "select current_database()"
            ).scalar()

        # The statement does not change between attempts; build it once.
        create_stmt = "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)

        for attempt in range(1, max_attempts + 1):
            try:
                conn.exec_driver_sql(create_stmt)
            except exc.OperationalError as err:
                if attempt >= max_attempts:
                    raise
                if "accessed by other users" in str(err):
                    log.info(
                        "Waiting to create %s, URI %r, "
                        "template DB %s is in use sleeping for .5",
                        ident,
                        eng.url,
                        template_db,
                    )
                    time.sleep(0.5)
            else:
                break


@drop_db.for_db("postgresql")
def _pg_drop_db(cfg, eng, ident):
    """Drop database ``ident`` after terminating other backends using it.

    Both statements run on a single AUTOCOMMIT connection: first every
    other backend of the current user attached to ``ident`` is asked to
    terminate via ``pg_terminate_backend()``, then the database is dropped.

    :param cfg: configuration object (not used here).
    :param eng: engine used to open the connection.
    :param ident: name of the database to drop.
    """
    terminate_backends = text(
        "select pg_terminate_backend(pid) from pg_stat_activity "
        "where usename=current_user and pid != pg_backend_pid() "
        "and datname=:dname"
    )
    autocommit_conn = eng.connect().execution_options(
        isolation_level="AUTOCOMMIT"
    )
    with autocommit_conn as conn, conn.begin():
        conn.execute(terminate_backends, {"dname": ident})
        conn.exec_driver_sql("DROP DATABASE %s" % ident)


@temp_table_keyword_args.for_db("postgresql")
def _postgresql_temp_table_keyword_args(cfg, eng):
    """Return keyword arguments carrying the ``TEMPORARY`` prefix.

    :param cfg: configuration object (not used here).
    :param eng: engine (not used here).
    :return: a new dict ``{"prefixes": ["TEMPORARY"]}``.
    """
    temp_prefixes = ["TEMPORARY"]
    return dict(prefixes=temp_prefixes)


@set_default_schema_on_connection.for_db("postgresql")
def _postgresql_set_default_schema_on_connection(
    cfg, dbapi_connection, schema_name
):
    """Set the session ``search_path`` on a raw DBAPI connection.

    The connection is switched to autocommit while the ``SET SESSION``
    statement runs.  The cursor is closed and the previous autocommit
    setting restored even when the statement fails, so an error here does
    not leave the connection in a modified state.

    :param cfg: configuration object (not used here).
    :param dbapi_connection: DBAPI connection exposing ``autocommit`` and
     ``cursor()``.
    :param schema_name: schema name interpolated into the ``search_path``
     setting.
    """
    existing_autocommit = dbapi_connection.autocommit
    dbapi_connection.autocommit = True
    try:
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("SET SESSION search_path='%s'" % schema_name)
        finally:
            cursor.close()
    finally:
        # Always hand the connection back with its original setting.
        dbapi_connection.autocommit = existing_autocommit


@drop_all_schema_objects_pre_tables.for_db("postgresql")
def drop_all_schema_objects_pre_tables(cfg, eng):
    """Roll back every prepared transaction listed in ``pg_prepared_xacts``.

    Runs on an AUTOCOMMIT connection and issues one ``ROLLBACK PREPARED``
    per transaction id (``gid``) found.

    :param cfg: configuration object (not used here).
    :param eng: engine used to open the connection.
    """
    with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
        # Fetch all ids up front so the result is fully consumed before
        # more statements are issued on the same connection.
        prepared_ids = (
            conn.exec_driver_sql("select gid from pg_prepared_xacts")
            .scalars()
            .all()
        )
        for xid in prepared_ids:
            # Plain SQL strings must go through exec_driver_sql();
            # Connection.execute() only accepts executable constructs.
            conn.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)


@drop_all_schema_objects_post_tables.for_db("postgresql")
def drop_all_schema_objects_post_tables(cfg, eng):
    """Drop every ENUM type reported by the inspector for schema ``"*"``.

    Each enum returned by ``get_enums("*")`` is dropped with a
    ``DropEnumType`` construct inside a single ``eng.begin()`` block.

    :param cfg: configuration object (not used here).
    :param eng: engine to inspect and to run the drops on.
    """
    from sqlalchemy.dialects import postgresql

    pg_inspector = inspect(eng)
    with eng.begin() as conn:
        for enum_info in pg_inspector.get_enums("*"):
            enum_type = postgresql.ENUM(
                name=enum_info["name"], schema=enum_info["schema"]
            )
            conn.execute(postgresql.DropEnumType(enum_type))


@prepare_for_drop_tables.for_db("postgresql")
def prepare_for_drop_tables(config, connection):
    """Warn about sessions that are idle in transaction on this database.

    Queries ``pg_stat_activity`` for other backends of the current user on
    the current database whose state is ``idle in transaction``.  When any
    are found, a test-suite warning listing their queries is emitted; the
    sessions themselves are only reported, not terminated.

    :param config: configuration object (not used here).
    :param connection: connection used to run the query.
    """
    idle_sessions_sql = (
        "select pid, state, wait_event_type, query "
        "from pg_stat_activity where "
        "usename=current_user "
        "and datname=current_database() and state='idle in transaction' "
        "and pid != pg_backend_pid()"
    )
    idle_rows = connection.exec_driver_sql(idle_sessions_sql).all()
    if not idle_rows:
        return

    stuck_queries = "; ".join(row._mapping["query"] for row in idle_rows)
    warn_test_suite(
        "PostgreSQL may not be able to DROP tables due to "
        "idle in transaction: %s" % stuck_queries
    )


@upsert.for_db("postgresql")
def _upsert(
    cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
):
    """Build a PostgreSQL ``INSERT .. ON CONFLICT`` statement for ``table``.

    :param cfg: configuration object (not used here).
    :param table: target of the INSERT.
    :param returning: iterable of expressions passed to ``returning()``.
    :param set_lambda: optional callable that receives ``stmt.excluded``
     and returns the ``set_`` argument for ``ON CONFLICT DO UPDATE``,
     keyed on the table's primary key.  When omitted or falsy, the
     statement uses ``ON CONFLICT DO NOTHING``.
    :param sort_by_parameter_order: forwarded to ``returning()``.
    :return: the statement with its RETURNING clause applied.
    """
    from sqlalchemy.dialects.postgresql import insert

    insert_stmt = insert(table)
    selectable = inspect(table).selectable

    if not set_lambda:
        upsert_stmt = insert_stmt.on_conflict_do_nothing()
    else:
        upsert_stmt = insert_stmt.on_conflict_do_update(
            index_elements=selectable.primary_key,
            set_=set_lambda(insert_stmt.excluded),
        )

    return upsert_stmt.returning(
        *returning, sort_by_parameter_order=sort_by_parameter_order
    )


# (extension name, minimum server version tuple) pairs.  Each extension is
# created by _create_citext_extension() when the connected server's
# ``server_version_info`` is at least the given tuple.
_extensions = [
    ("citext", (13,)),
    ("hstore", (13,)),
]


@post_configure_engine.for_db("postgresql")
def _create_citext_extension(url, engine, follower_ident):
    """Create each extension in ``_extensions`` the server is new enough for.

    For every ``(name, minimum version)`` entry whose minimum version is
    met by the dialect's ``server_version_info``, a
    ``CREATE EXTENSION IF NOT EXISTS`` statement is executed and committed.

    :param url: database URL (not used here).
    :param engine: engine used to open the connection.
    :param follower_ident: follower identifier (not used here).
    """
    with engine.connect() as conn:
        for ext_name, required_version in _extensions:
            supported = conn.dialect.server_version_info >= required_version
            if not supported:
                continue
            create_stmt = text(f"CREATE EXTENSION IF NOT EXISTS {ext_name}")
            conn.execute(create_stmt)
            conn.commit()