- """Miscellaneous goodies for psycopg2
- This module is a generic place used to hold little helper functions
- and classes until a better place in the distribution is found.
- """
- # psycopg/extras.py - miscellaneous extra goodies for psycopg
- #
- # Copyright (C) 2003-2019 Federico Di Gregorio <fog@debian.org>
- # Copyright (C) 2020-2021 The Psycopg Team
- #
- # psycopg2 is free software: you can redistribute it and/or modify it
- # under the terms of the GNU Lesser General Public License as published
- # by the Free Software Foundation, either version 3 of the License, or
- # (at your option) any later version.
- #
- # In addition, as a special exception, the copyright holders give
- # permission to link this program with the OpenSSL library (or with
- # modified versions of OpenSSL that use the same license as OpenSSL),
- # and distribute linked combinations including the two.
- #
- # You must obey the GNU Lesser General Public License in all respects for
- # all of the code used other than OpenSSL.
- #
- # psycopg2 is distributed in the hope that it will be useful, but WITHOUT
- # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
- # License for more details.
- import os as _os
- import time as _time
- import re as _re
- from collections import namedtuple, OrderedDict
- import logging as _logging
- import psycopg2
- from psycopg2 import extensions as _ext
- from .extensions import cursor as _cursor
- from .extensions import connection as _connection
- from .extensions import adapt as _A, quote_ident
- from functools import lru_cache
- from psycopg2._psycopg import ( # noqa
- REPLICATION_PHYSICAL, REPLICATION_LOGICAL,
- ReplicationConnection as _replicationConnection,
- ReplicationCursor as _replicationCursor,
- ReplicationMessage)
- # expose the json adaptation stuff into the module
- from psycopg2._json import ( # noqa
- json, Json, register_json, register_default_json, register_default_jsonb)
- # Expose range-related objects
- from psycopg2._range import ( # noqa
- Range, NumericRange, DateRange, DateTimeRange, DateTimeTZRange,
- register_range, RangeAdapter, RangeCaster)
- # Expose ipaddress-related objects
- from psycopg2._ipaddress import register_ipaddress # noqa
- class DictCursorBase(_cursor):
- """Base class for all dict-like cursors."""
- def __init__(self, *args, **kwargs):
- if 'row_factory' in kwargs:
- row_factory = kwargs['row_factory']
- del kwargs['row_factory']
- else:
- raise NotImplementedError(
- "DictCursorBase can't be instantiated without a row factory.")
- super().__init__(*args, **kwargs)
- self._query_executed = False
- self._prefetch = False
- self.row_factory = row_factory
- def fetchone(self):
- if self._prefetch:
- res = super().fetchone()
- if self._query_executed:
- self._build_index()
- if not self._prefetch:
- res = super().fetchone()
- return res
- def fetchmany(self, size=None):
- if self._prefetch:
- res = super().fetchmany(size)
- if self._query_executed:
- self._build_index()
- if not self._prefetch:
- res = super().fetchmany(size)
- return res
- def fetchall(self):
- if self._prefetch:
- res = super().fetchall()
- if self._query_executed:
- self._build_index()
- if not self._prefetch:
- res = super().fetchall()
- return res
- def __iter__(self):
- try:
- if self._prefetch:
- res = super().__iter__()
- first = next(res)
- if self._query_executed:
- self._build_index()
- if not self._prefetch:
- res = super().__iter__()
- first = next(res)
- yield first
- while True:
- yield next(res)
- except StopIteration:
- return
- class DictConnection(_connection):
- """A connection that uses `DictCursor` automatically."""
- def cursor(self, *args, **kwargs):
- kwargs.setdefault('cursor_factory', self.cursor_factory or DictCursor)
- return super().cursor(*args, **kwargs)
- class DictCursor(DictCursorBase):
- """A cursor that keeps a list of column name -> index mappings__.
- .. __: https://docs.python.org/glossary.html#term-mapping
- """
- def __init__(self, *args, **kwargs):
- kwargs['row_factory'] = DictRow
- super().__init__(*args, **kwargs)
- self._prefetch = True
- def execute(self, query, vars=None):
- self.index = OrderedDict()
- self._query_executed = True
- return super().execute(query, vars)
- def callproc(self, procname, vars=None):
- self.index = OrderedDict()
- self._query_executed = True
- return super().callproc(procname, vars)
- def _build_index(self):
- if self._query_executed and self.description:
- for i in range(len(self.description)):
- self.index[self.description[i][0]] = i
- self._query_executed = False
- class DictRow(list):
- """A row object that allow by-column-name access to data."""
- __slots__ = ('_index',)
- def __init__(self, cursor):
- self._index = cursor.index
- self[:] = [None] * len(cursor.description)
- def __getitem__(self, x):
- if not isinstance(x, (int, slice)):
- x = self._index[x]
- return super().__getitem__(x)
- def __setitem__(self, x, v):
- if not isinstance(x, (int, slice)):
- x = self._index[x]
- super().__setitem__(x, v)
- def items(self):
- g = super().__getitem__
- return ((n, g(self._index[n])) for n in self._index)
- def keys(self):
- return iter(self._index)
- def values(self):
- g = super().__getitem__
- return (g(self._index[n]) for n in self._index)
- def get(self, x, default=None):
- try:
- return self[x]
- except Exception:
- return default
- def copy(self):
- return OrderedDict(self.items())
- def __contains__(self, x):
- return x in self._index
- def __reduce__(self):
- # this is apparently useless, but it fixes #1073
- return super().__reduce__()
- def __getstate__(self):
- return self[:], self._index.copy()
- def __setstate__(self, data):
- self[:] = data[0]
- self._index = data[1]
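- # Editor's sketch, not part of the original module: minimal `DictCursor`
- # usage. The DSN "dbname=test" and the helper name are illustrative
- # assumptions only.
- def _example_dict_cursor():
-     import psycopg2
-     import psycopg2.extras
-     conn = psycopg2.connect("dbname=test")
-     try:
-         with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
-             cur.execute("SELECT 1 AS id, 'foo' AS data")
-             row = cur.fetchone()
-             # rows are accessible both by column name and by position
-             assert row["data"] == row[1] == 'foo'
-     finally:
-         conn.close()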
- class RealDictConnection(_connection):
- """A connection that uses `RealDictCursor` automatically."""
- def cursor(self, *args, **kwargs):
- kwargs.setdefault('cursor_factory', self.cursor_factory or RealDictCursor)
- return super().cursor(*args, **kwargs)
- class RealDictCursor(DictCursorBase):
- """A cursor that uses a real dict as the base type for rows.
- Note that this cursor is extremely specialized and does not allow
- the normal access (using integer indices) to fetched data. If you need
- to access database rows both as a dictionary and a list, then use
- the generic `DictCursor` instead of `!RealDictCursor`.
- """
- def __init__(self, *args, **kwargs):
- kwargs['row_factory'] = RealDictRow
- super().__init__(*args, **kwargs)
- def execute(self, query, vars=None):
- self.column_mapping = []
- self._query_executed = True
- return super().execute(query, vars)
- def callproc(self, procname, vars=None):
- self.column_mapping = []
- self._query_executed = True
- return super().callproc(procname, vars)
- def _build_index(self):
- if self._query_executed and self.description:
- self.column_mapping = [d[0] for d in self.description]
- self._query_executed = False
- class RealDictRow(OrderedDict):
- """A `!dict` subclass representing a data record."""
- def __init__(self, *args, **kwargs):
- if args and isinstance(args[0], _cursor):
- cursor = args[0]
- args = args[1:]
- else:
- cursor = None
- super().__init__(*args, **kwargs)
- if cursor is not None:
- # Required for named cursors
- if cursor.description and not cursor.column_mapping:
- cursor._build_index()
- # Store the cols mapping in the dict itself until the row is fully
- # populated, so we don't need to add attributes to the class
- # (hence keeping its maintenance, special pickle support, etc.)
- self[RealDictRow] = cursor.column_mapping
- def __setitem__(self, key, value):
- if RealDictRow in self:
- # We are in the row building phase
- mapping = self[RealDictRow]
- super().__setitem__(mapping[key], value)
- if key == len(mapping) - 1:
- # Row building finished
- del self[RealDictRow]
- return
- super().__setitem__(key, value)
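- # Editor's sketch (assumed DSN "dbname=test"): `RealDictCursor` rows are
- # real mappings keyed only by column name; integer indexing is unavailable.
- def _example_real_dict_cursor():
-     import psycopg2
-     from psycopg2.extras import RealDictCursor
-     conn = psycopg2.connect("dbname=test")
-     try:
-         with conn.cursor(cursor_factory=RealDictCursor) as cur:
-             cur.execute("SELECT 1 AS id, 'foo' AS data")
-             row = cur.fetchone()
-             assert row == {'id': 1, 'data': 'foo'}
-             # row[0] would raise KeyError: there is no positional access
-     finally:
-         conn.close()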
- class NamedTupleConnection(_connection):
- """A connection that uses `NamedTupleCursor` automatically."""
- def cursor(self, *args, **kwargs):
- kwargs.setdefault('cursor_factory', self.cursor_factory or NamedTupleCursor)
- return super().cursor(*args, **kwargs)
- class NamedTupleCursor(_cursor):
- """A cursor that generates results as `~collections.namedtuple`.
- `!fetch*()` methods will return named tuples instead of regular tuples, so
- their elements can be accessed both as regular numeric items as well as
- attributes.
- >>> nt_cur = conn.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
- >>> rec = nt_cur.fetchone()
- >>> rec
- Record(id=1, num=100, data="abc'def")
- >>> rec[1]
- 100
- >>> rec.data
- "abc'def"
- """
- Record = None
- MAX_CACHE = 1024
- def execute(self, query, vars=None):
- self.Record = None
- return super().execute(query, vars)
- def executemany(self, query, vars):
- self.Record = None
- return super().executemany(query, vars)
- def callproc(self, procname, vars=None):
- self.Record = None
- return super().callproc(procname, vars)
- def fetchone(self):
- t = super().fetchone()
- if t is not None:
- nt = self.Record
- if nt is None:
- nt = self.Record = self._make_nt()
- return nt._make(t)
- def fetchmany(self, size=None):
- ts = super().fetchmany(size)
- nt = self.Record
- if nt is None:
- nt = self.Record = self._make_nt()
- return list(map(nt._make, ts))
- def fetchall(self):
- ts = super().fetchall()
- nt = self.Record
- if nt is None:
- nt = self.Record = self._make_nt()
- return list(map(nt._make, ts))
- def __iter__(self):
- try:
- it = super().__iter__()
- t = next(it)
- nt = self.Record
- if nt is None:
- nt = self.Record = self._make_nt()
- yield nt._make(t)
- while True:
- yield nt._make(next(it))
- except StopIteration:
- return
- # ascii except alnum and underscore
- _re_clean = _re.compile(
- '[' + _re.escape(' !"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~') + ']')
- def _make_nt(self):
- key = tuple(d[0] for d in self.description) if self.description else ()
- return self._cached_make_nt(key)
- @classmethod
- def _do_make_nt(cls, key):
- fields = []
- for s in key:
- s = cls._re_clean.sub('_', s)
- # A Python identifier cannot start with a digit, and namedtuple field
- # names cannot start with an underscore. So...
- if s[0] == '_' or '0' <= s[0] <= '9':
- s = 'f' + s
- fields.append(s)
- nt = namedtuple("Record", fields)
- return nt
- @lru_cache(512)
- def _cached_make_nt(cls, key):
- return cls._do_make_nt(key)
- # Exposed for testability, and if someone wants to monkeypatch to tweak
- # the cache size.
- NamedTupleCursor._cached_make_nt = classmethod(_cached_make_nt)
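- # Editor's illustration of the field-name cleanup above (no database
- # needed): names that are not valid identifiers are sanitized, and names
- # starting with a digit or underscore get an "f" prefix. The helper name
- # is hypothetical.
- def _example_namedtuple_fields():
-     Record = NamedTupleCursor._do_make_nt(("id", "user name", "2fast", "_x"))
-     assert Record._fields == ('id', 'user_name', 'f2fast', 'f_x')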
- class LoggingConnection(_connection):
- """A connection that logs all queries to a file or logger__ object.
- .. __: https://docs.python.org/library/logging.html
- """
- def initialize(self, logobj):
- """Initialize the connection to log to `!logobj`.
- The `!logobj` parameter can be an open file object or a Logger/LoggerAdapter
- instance from the standard logging module.
- """
- self._logobj = logobj
- if _logging and isinstance(
- logobj, (_logging.Logger, _logging.LoggerAdapter)):
- self.log = self._logtologger
- else:
- self.log = self._logtofile
- def filter(self, msg, curs):
- """Filter the query before logging it.
- This is the method to override to filter unwanted queries out of the
- log or to add some extra data to the output. The default implementation
- just does nothing.
- """
- return msg
- def _logtofile(self, msg, curs):
- msg = self.filter(msg, curs)
- if msg:
- if isinstance(msg, bytes):
- msg = msg.decode(_ext.encodings[self.encoding], 'replace')
- self._logobj.write(msg + _os.linesep)
- def _logtologger(self, msg, curs):
- msg = self.filter(msg, curs)
- if msg:
- self._logobj.debug(msg)
- def _check(self):
- if not hasattr(self, '_logobj'):
- raise self.ProgrammingError(
- "LoggingConnection object has not been initialize()d")
- def cursor(self, *args, **kwargs):
- self._check()
- kwargs.setdefault('cursor_factory', self.cursor_factory or LoggingCursor)
- return super().cursor(*args, **kwargs)
- class LoggingCursor(_cursor):
- """A cursor that logs queries using its connection logging facilities."""
- def execute(self, query, vars=None):
- try:
- return super().execute(query, vars)
- finally:
- self.connection.log(self.query, self)
- def callproc(self, procname, vars=None):
- try:
- return super().callproc(procname, vars)
- finally:
- self.connection.log(self.query, self)
- class MinTimeLoggingConnection(LoggingConnection):
- """A connection that logs queries based on execution time.
- This is just an example of how to sub-class `LoggingConnection` to
- provide some extra filtering for the logged queries. Both the
- `initialize()` and `filter()` methods are overridden to make sure
- that only queries executing for more than ``mintime`` ms are logged.
- Note that this connection uses the specialized cursor
- `MinTimeLoggingCursor`.
- """
- def initialize(self, logobj, mintime=0):
- LoggingConnection.initialize(self, logobj)
- self._mintime = mintime
- def filter(self, msg, curs):
- t = (_time.time() - curs.timestamp) * 1000
- if t > self._mintime:
- if isinstance(msg, bytes):
- msg = msg.decode(_ext.encodings[self.encoding], 'replace')
- return f"{msg}{_os.linesep} (execution time: {t} ms)"
- def cursor(self, *args, **kwargs):
- kwargs.setdefault('cursor_factory',
- self.cursor_factory or MinTimeLoggingCursor)
- return LoggingConnection.cursor(self, *args, **kwargs)
- class MinTimeLoggingCursor(LoggingCursor):
- """The cursor sub-class companion to `MinTimeLoggingConnection`."""
- def execute(self, query, vars=None):
- self.timestamp = _time.time()
- return LoggingCursor.execute(self, query, vars)
- def callproc(self, procname, vars=None):
- self.timestamp = _time.time()
- return LoggingCursor.callproc(self, procname, vars)
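- # Editor's sketch (assumed DSN "dbname=test"): a `LoggingConnection` must be
- # initialize()d with a file or `logging.Logger` before creating cursors;
- # `MinTimeLoggingConnection.initialize()` additionally takes a threshold in ms.
- def _example_logging_connection():
-     import logging
-     import psycopg2
-     from psycopg2.extras import LoggingConnection
-     logging.basicConfig(level=logging.DEBUG)
-     conn = psycopg2.connect(
-         "dbname=test", connection_factory=LoggingConnection)
-     try:
-         conn.initialize(logging.getLogger(__name__))
-         with conn.cursor() as cur:
-             cur.execute("SELECT 1")  # the query text is logged at DEBUG level
-     finally:
-         conn.close()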
- class LogicalReplicationConnection(_replicationConnection):
- def __init__(self, *args, **kwargs):
- kwargs['replication_type'] = REPLICATION_LOGICAL
- super().__init__(*args, **kwargs)
- class PhysicalReplicationConnection(_replicationConnection):
- def __init__(self, *args, **kwargs):
- kwargs['replication_type'] = REPLICATION_PHYSICAL
- super().__init__(*args, **kwargs)
- class StopReplication(Exception):
- """
- Exception used to break out of the endless loop in
- `~ReplicationCursor.consume_stream()`.
- Subclass of `~exceptions.Exception`. Intentionally *not* inherited from
- `~psycopg2.Error` as occurrence of this exception does not indicate an
- error.
- """
- pass
- class ReplicationCursor(_replicationCursor):
- """A cursor used for communication on replication connections."""
- def create_replication_slot(self, slot_name, slot_type=None, output_plugin=None):
- """Create streaming replication slot."""
- command = f"CREATE_REPLICATION_SLOT {quote_ident(slot_name, self)} "
- if slot_type is None:
- slot_type = self.connection.replication_type
- if slot_type == REPLICATION_LOGICAL:
- if output_plugin is None:
- raise psycopg2.ProgrammingError(
- "output plugin name is required to create "
- "logical replication slot")
- command += f"LOGICAL {quote_ident(output_plugin, self)}"
- elif slot_type == REPLICATION_PHYSICAL:
- if output_plugin is not None:
- raise psycopg2.ProgrammingError(
- "cannot specify output plugin name when creating "
- "physical replication slot")
- command += "PHYSICAL"
- else:
- raise psycopg2.ProgrammingError(
- f"unrecognized replication type: {repr(slot_type)}")
- self.execute(command)
- def drop_replication_slot(self, slot_name):
- """Drop streaming replication slot."""
- command = f"DROP_REPLICATION_SLOT {quote_ident(slot_name, self)}"
- self.execute(command)
- def start_replication(
- self, slot_name=None, slot_type=None, start_lsn=0,
- timeline=0, options=None, decode=False, status_interval=10):
- """Start replication stream."""
- command = "START_REPLICATION "
- if slot_type is None:
- slot_type = self.connection.replication_type
- if slot_type == REPLICATION_LOGICAL:
- if slot_name:
- command += f"SLOT {quote_ident(slot_name, self)} "
- else:
- raise psycopg2.ProgrammingError(
- "slot name is required for logical replication")
- command += "LOGICAL "
- elif slot_type == REPLICATION_PHYSICAL:
- if slot_name:
- command += f"SLOT {quote_ident(slot_name, self)} "
- # don't add "PHYSICAL", before 9.4 it was just START_REPLICATION XXX/XXX
- else:
- raise psycopg2.ProgrammingError(
- f"unrecognized replication type: {repr(slot_type)}")
- if type(start_lsn) is str:
- lsn = start_lsn.split('/')
- lsn = f"{int(lsn[0], 16):X}/{int(lsn[1], 16):08X}"
- else:
- lsn = f"{start_lsn >> 32 & 0xFFFFFFFF:X}/{start_lsn & 0xFFFFFFFF:08X}"
- command += lsn
- if timeline != 0:
- if slot_type == REPLICATION_LOGICAL:
- raise psycopg2.ProgrammingError(
- "cannot specify timeline for logical replication")
- command += f" TIMELINE {timeline}"
- if options:
- if slot_type == REPLICATION_PHYSICAL:
- raise psycopg2.ProgrammingError(
- "cannot specify output plugin options for physical replication")
- command += " ("
- for k, v in options.items():
- if not command.endswith('('):
- command += ", "
- command += f"{quote_ident(k, self)} {_A(str(v))}"
- command += ")"
- self.start_replication_expert(
- command, decode=decode, status_interval=status_interval)
- # allows replication cursors to be used in select.select() directly
- def fileno(self):
- return self.connection.fileno()
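- # Editor's sketch: consuming a logical replication stream with the standard
- # "test_decoding" plugin. The DSN and slot name are assumptions; the server
- # must run with wal_level=logical and the role needs REPLICATION privilege.
- def _example_logical_replication():
-     import psycopg2
-     from psycopg2.extras import LogicalReplicationConnection, StopReplication
-     conn = psycopg2.connect(
-         "dbname=test", connection_factory=LogicalReplicationConnection)
-     cur = conn.cursor()
-     cur.create_replication_slot("demo_slot", output_plugin="test_decoding")
-     cur.start_replication(slot_name="demo_slot", decode=True)
-     def consume(msg):
-         print(msg.payload)
-         msg.cursor.send_feedback(flush_lsn=msg.data_start)
-         raise StopReplication()  # break out of consume_stream() for the demo
-     try:
-         cur.consume_stream(consume)
-     except StopReplication:
-         pass
-     finally:
-         conn.close()  # the slot stays behind; drop it separately when done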
- # a dbtype and adapter for Python UUID type
- class UUID_adapter:
- """Adapt Python's uuid.UUID__ type to PostgreSQL's uuid__.
- .. __: https://docs.python.org/library/uuid.html
- .. __: https://www.postgresql.org/docs/current/static/datatype-uuid.html
- """
- def __init__(self, uuid):
- self._uuid = uuid
- def __conform__(self, proto):
- if proto is _ext.ISQLQuote:
- return self
- def getquoted(self):
- return (f"'{self._uuid}'::uuid").encode('utf8')
- def __str__(self):
- return f"'{self._uuid}'::uuid"
- def register_uuid(oids=None, conn_or_curs=None):
- """Create the UUID type and an uuid.UUID adapter.
- :param oids: oid for the PostgreSQL :sql:`uuid` type, or 2-item sequence
- with oids of the type and the array. If not specified, use PostgreSQL
- standard oids.
- :param conn_or_curs: where to register the typecaster. If not specified,
- register it globally.
- """
- import uuid
- if not oids:
- oid1 = 2950
- oid2 = 2951
- elif isinstance(oids, (list, tuple)):
- oid1, oid2 = oids
- else:
- oid1 = oids
- oid2 = 2951
- _ext.UUID = _ext.new_type((oid1, ), "UUID",
- lambda data, cursor: data and uuid.UUID(data) or None)
- _ext.UUIDARRAY = _ext.new_array_type((oid2,), "UUID[]", _ext.UUID)
- _ext.register_type(_ext.UUID, conn_or_curs)
- _ext.register_type(_ext.UUIDARRAY, conn_or_curs)
- _ext.register_adapter(uuid.UUID, UUID_adapter)
- return _ext.UUID
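- # Editor's sketch (assumed DSN "dbname=test"): after register_uuid(),
- # uuid.UUID values round-trip transparently.
- def _example_register_uuid():
-     import uuid
-     import psycopg2
-     import psycopg2.extras
-     psycopg2.extras.register_uuid()  # register adapter and typecaster globally
-     conn = psycopg2.connect("dbname=test")
-     try:
-         with conn.cursor() as cur:
-             u = uuid.uuid4()
-             cur.execute("SELECT %s", (u,))
-             assert cur.fetchone()[0] == u  # comes back as uuid.UUID
-     finally:
-         conn.close()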
- # a type, dbtype and adapter for PostgreSQL inet type
- class Inet:
- """Wrap a string to allow for correct SQL-quoting of inet values.
- Note that this adapter does NOT check the passed value to make
- sure it really is an inet-compatible address but DOES call adapt()
- on it to make sure it is impossible to execute an SQL-injection
- by passing an evil value to the initializer.
- """
- def __init__(self, addr):
- self.addr = addr
- def __repr__(self):
- return f"{self.__class__.__name__}({self.addr!r})"
- def prepare(self, conn):
- self._conn = conn
- def getquoted(self):
- obj = _A(self.addr)
- if hasattr(obj, 'prepare'):
- obj.prepare(self._conn)
- return obj.getquoted() + b"::inet"
- def __conform__(self, proto):
- if proto is _ext.ISQLQuote:
- return self
- def __str__(self):
- return str(self.addr)
- def register_inet(oid=None, conn_or_curs=None):
- """Create the INET type and an Inet adapter.
- :param oid: oid for the PostgreSQL :sql:`inet` type, or 2-item sequence
- with oids of the type and the array. If not specified, use PostgreSQL
- standard oids.
- :param conn_or_curs: where to register the typecaster. If not specified,
- register it globally.
- """
- import warnings
- warnings.warn(
- "the inet adapter is deprecated, it's not very useful",
- DeprecationWarning)
- if not oid:
- oid1 = 869
- oid2 = 1041
- elif isinstance(oid, (list, tuple)):
- oid1, oid2 = oid
- else:
- oid1 = oid
- oid2 = 1041
- _ext.INET = _ext.new_type((oid1, ), "INET",
- lambda data, cursor: data and Inet(data) or None)
- _ext.INETARRAY = _ext.new_array_type((oid2, ), "INETARRAY", _ext.INET)
- _ext.register_type(_ext.INET, conn_or_curs)
- _ext.register_type(_ext.INETARRAY, conn_or_curs)
- return _ext.INET
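- # Editor's note: register_inet() above is deprecated; the ipaddress-based
- # adapter imported at the top of this module is the maintained alternative.
- # A minimal sketch (assumed DSN "dbname=test"):
- def _example_register_ipaddress():
-     import psycopg2
-     import psycopg2.extras
-     psycopg2.extras.register_ipaddress()  # inet/cidr <-> ipaddress objects
-     conn = psycopg2.connect("dbname=test")
-     try:
-         with conn.cursor() as cur:
-             cur.execute("SELECT '192.168.0.1/24'::inet")
-             net = cur.fetchone()[0]  # an ipaddress.IPv4Interface instance
-             assert str(net) == '192.168.0.1/24'
-     finally:
-         conn.close()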
- def wait_select(conn):
- """Wait until a connection or cursor has data available.
- The function is an example of a wait callback to be registered with
- `~psycopg2.extensions.set_wait_callback()`. This function uses
- :py:func:`~select.select()` to wait for data to become available, and
- therefore is able to handle/receive SIGINT/KeyboardInterrupt.
- """
- import select
- from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE
- while True:
- try:
- state = conn.poll()
- if state == POLL_OK:
- break
- elif state == POLL_READ:
- select.select([conn.fileno()], [], [])
- elif state == POLL_WRITE:
- select.select([], [conn.fileno()], [])
- else:
- raise conn.OperationalError(f"bad state from poll: {state}")
- except KeyboardInterrupt:
- conn.cancel()
- # the loop will be broken by a server error
- continue
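- # Editor's sketch: wait_select() is meant to be installed as the global wait
- # callback; blocking calls then become interruptible with Ctrl-C. The DSN
- # below is an assumption.
- def _example_wait_select():
-     import psycopg2
-     import psycopg2.extensions
-     import psycopg2.extras
-     psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
-     conn = psycopg2.connect("dbname=test")
-     try:
-         with conn.cursor() as cur:
-             cur.execute("SELECT pg_sleep(5)")  # can now be interrupted cleanly
-     finally:
-         conn.close()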
- def _solve_conn_curs(conn_or_curs):
- """Return the connection and a DBAPI cursor from a connection or cursor."""
- if conn_or_curs is None:
- raise psycopg2.ProgrammingError("no connection or cursor provided")
- if hasattr(conn_or_curs, 'execute'):
- conn = conn_or_curs.connection
- curs = conn.cursor(cursor_factory=_cursor)
- else:
- conn = conn_or_curs
- curs = conn.cursor(cursor_factory=_cursor)
- return conn, curs
- class HstoreAdapter:
- """Adapt a Python dict to the hstore syntax."""
- def __init__(self, wrapped):
- self.wrapped = wrapped
- def prepare(self, conn):
- self.conn = conn
- # use an old-style getquoted implementation if required
- if conn.info.server_version < 90000:
- self.getquoted = self._getquoted_8
- def _getquoted_8(self):
- """Use the operators available in PG pre-9.0."""
- if not self.wrapped:
- return b"''::hstore"
- adapt = _ext.adapt
- rv = []
- for k, v in self.wrapped.items():
- k = adapt(k)
- k.prepare(self.conn)
- k = k.getquoted()
- if v is not None:
- v = adapt(v)
- v.prepare(self.conn)
- v = v.getquoted()
- else:
- v = b'NULL'
- # XXX this b'ing is painfully inefficient!
- rv.append(b"(" + k + b" => " + v + b")")
- return b"(" + b'||'.join(rv) + b")"
- def _getquoted_9(self):
- """Use the hstore(text[], text[]) function."""
- if not self.wrapped:
- return b"''::hstore"
- k = _ext.adapt(list(self.wrapped.keys()))
- k.prepare(self.conn)
- v = _ext.adapt(list(self.wrapped.values()))
- v.prepare(self.conn)
- return b"hstore(" + k.getquoted() + b", " + v.getquoted() + b")"
- getquoted = _getquoted_9
- _re_hstore = _re.compile(r"""
- # hstore key:
- # a string of normal or escaped chars
- "((?: [^"\\] | \\. )*)"
- \s*=>\s* # hstore value
- (?:
- NULL # the value can be null - not caught
- # or a quoted string like the key
- | "((?: [^"\\] | \\. )*)"
- )
- (?:\s*,\s*|$) # pairs separated by comma or end of string.
- """, _re.VERBOSE)
- @classmethod
- def parse(self, s, cur, _bsdec=_re.compile(r"\\(.)")):
- """Parse an hstore representation in a Python string.
- The hstore is represented as something like::
- "a"=>"1", "b"=>"2"
- with backslash-escaped strings.
- """
- if s is None:
- return None
- rv = {}
- start = 0
- for m in self._re_hstore.finditer(s):
- if m is None or m.start() != start:
- raise psycopg2.InterfaceError(
- f"error parsing hstore pair at char {start}")
- k = _bsdec.sub(r'\1', m.group(1))
- v = m.group(2)
- if v is not None:
- v = _bsdec.sub(r'\1', v)
- rv[k] = v
- start = m.end()
- if start < len(s):
- raise psycopg2.InterfaceError(
- f"error parsing hstore: unparsed data after char {start}")
- return rv
- @classmethod
- def parse_unicode(self, s, cur):
- """Parse an hstore returning unicode keys and values."""
- if s is None:
- return None
- s = s.decode(_ext.encodings[cur.connection.encoding])
- return self.parse(s, cur)
- @classmethod
- def get_oids(self, conn_or_curs):
- """Return the lists of OID of the hstore and hstore[] types.
- """
- conn, curs = _solve_conn_curs(conn_or_curs)
- # Store the transaction status of the connection to revert it after use
- conn_status = conn.status
- # column typarray not available before PG 8.3
- typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
- rv0, rv1 = [], []
- # get the oid for the hstore
- curs.execute(f"""SELECT t.oid, {typarray}
- FROM pg_type t JOIN pg_namespace ns
- ON typnamespace = ns.oid
- WHERE typname = 'hstore';
- """)
- for oids in curs:
- rv0.append(oids[0])
- rv1.append(oids[1])
- # revert the status of the connection as before the command
- if (conn_status != _ext.STATUS_IN_TRANSACTION
- and not conn.autocommit):
- conn.rollback()
- return tuple(rv0), tuple(rv1)
- def register_hstore(conn_or_curs, globally=False, unicode=False,
- oid=None, array_oid=None):
- r"""Register adapter and typecaster for `!dict`\-\ |hstore| conversions.
- :param conn_or_curs: a connection or cursor: the typecaster will be
- registered only on this object unless *globally* is set to `!True`
- :param globally: register the adapter globally, not only on *conn_or_curs*
- :param unicode: if `!True`, keys and values returned from the database
- will be `!unicode` instead of `!str`. The option is not available on
- Python 3
- :param oid: the OID of the |hstore| type if known. If not, it will be
- queried on *conn_or_curs*.
- :param array_oid: the OID of the |hstore| array type if known. If not, it
- will be queried on *conn_or_curs*.
- The connection or cursor passed to the function will be used to query the
- database and look for the OID of the |hstore| type (which may be different
- across databases). If querying is not desirable (e.g. with
- :ref:`asynchronous connections <async-support>`) you may specify it in the
- *oid* parameter, which can be found using a query such as :sql:`SELECT
- 'hstore'::regtype::oid`. Analogously you can obtain a value for *array_oid*
- using a query such as :sql:`SELECT 'hstore[]'::regtype::oid`.
- Note that, when passing a dictionary from Python to the database, both
- strings and unicode keys and values are supported. Dictionaries returned
- from the database have keys/values according to the *unicode* parameter.
- The |hstore| contrib module must be already installed in the database
- (executing the ``hstore.sql`` script in your ``contrib`` directory).
- Raise `~psycopg2.ProgrammingError` if the type is not found.
- """
- if oid is None:
- oid = HstoreAdapter.get_oids(conn_or_curs)
- if oid is None or not oid[0]:
- raise psycopg2.ProgrammingError(
- "hstore type not found in the database. "
- "please install it from your 'contrib/hstore.sql' file")
- else:
- array_oid = oid[1]
- oid = oid[0]
- if isinstance(oid, int):
- oid = (oid,)
- if array_oid is not None:
- if isinstance(array_oid, int):
- array_oid = (array_oid,)
- else:
- array_oid = tuple([x for x in array_oid if x])
- # create and register the typecaster
- HSTORE = _ext.new_type(oid, "HSTORE", HstoreAdapter.parse)
- _ext.register_type(HSTORE, not globally and conn_or_curs or None)
- _ext.register_adapter(dict, HstoreAdapter)
- if array_oid:
- HSTOREARRAY = _ext.new_array_type(array_oid, "HSTOREARRAY", HSTORE)
- _ext.register_type(HSTOREARRAY, not globally and conn_or_curs or None)
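- # Editor's sketch (assumed DSN "dbname=test", hstore extension installed):
- # after register_hstore() Python dicts round-trip through hstore values.
- def _example_register_hstore():
-     import psycopg2
-     import psycopg2.extras
-     conn = psycopg2.connect("dbname=test")
-     try:
-         psycopg2.extras.register_hstore(conn)
-         with conn.cursor() as cur:
-             cur.execute("SELECT %s", ({'a': '1', 'b': None},))
-             assert cur.fetchone()[0] == {'a': '1', 'b': None}
-     finally:
-         conn.close()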
- class CompositeCaster:
- """Helps conversion of a PostgreSQL composite type into a Python object.
- The class is usually created by the `register_composite()` function.
- You may want to create and register manually instances of the class if
- querying the database at registration time is not desirable (such as when
- using an :ref:`asynchronous connections <async-support>`).
- """
- def __init__(self, name, oid, attrs, array_oid=None, schema=None):
- self.name = name
- self.schema = schema
- self.oid = oid
- self.array_oid = array_oid
- self.attnames = [a[0] for a in attrs]
- self.atttypes = [a[1] for a in attrs]
- self._create_type(name, self.attnames)
- self.typecaster = _ext.new_type((oid,), name, self.parse)
- if array_oid:
- self.array_typecaster = _ext.new_array_type(
- (array_oid,), f"{name}ARRAY", self.typecaster)
- else:
- self.array_typecaster = None
- def parse(self, s, curs):
- if s is None:
- return None
- tokens = self.tokenize(s)
- if len(tokens) != len(self.atttypes):
- raise psycopg2.DataError(
- "expecting %d components for the type %s, %d found instead" %
- (len(self.atttypes), self.name, len(tokens)))
- values = [curs.cast(oid, token)
- for oid, token in zip(self.atttypes, tokens)]
- return self.make(values)
- def make(self, values):
- """Return a new Python object representing the data being casted.
- *values* is the list of attributes, already casted into their Python
- representation.
- You can subclass this method to :ref:`customize the composite cast
- <custom-composite>`.
- """
- return self._ctor(values)
- _re_tokenize = _re.compile(r"""
- \(? ([,)]) # an empty token, representing NULL
- | \(? " ((?: [^"] | "")*) " [,)] # or a quoted string
- | \(? ([^",)]+) [,)] # or an unquoted string
- """, _re.VERBOSE)
- _re_undouble = _re.compile(r'(["\\])\1')
- @classmethod
- def tokenize(self, s):
- rv = []
- for m in self._re_tokenize.finditer(s):
- if m is None:
- raise psycopg2.InterfaceError(f"can't parse type: {s!r}")
- if m.group(1) is not None:
- rv.append(None)
- elif m.group(2) is not None:
- rv.append(self._re_undouble.sub(r"\1", m.group(2)))
- else:
- rv.append(m.group(3))
- return rv
- def _create_type(self, name, attnames):
- self.type = namedtuple(name, attnames)
- self._ctor = self.type._make
- @classmethod
- def _from_db(self, name, conn_or_curs):
- """Return a `CompositeCaster` instance for the type *name*.
- Raise `ProgrammingError` if the type is not found.
- """
- conn, curs = _solve_conn_curs(conn_or_curs)
- # Store the transaction status of the connection to revert it after use
- conn_status = conn.status
- # Use the correct schema
- if '.' in name:
- schema, tname = name.split('.', 1)
- else:
- tname = name
- schema = 'public'
- # column typarray not available before PG 8.3
- typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
- # get the type oid and attributes
- curs.execute("""\
- SELECT t.oid, %s, attname, atttypid
- FROM pg_type t
- JOIN pg_namespace ns ON typnamespace = ns.oid
- JOIN pg_attribute a ON attrelid = typrelid
- WHERE typname = %%s AND nspname = %%s
- AND attnum > 0 AND NOT attisdropped
- ORDER BY attnum;
- """ % typarray, (tname, schema))
- recs = curs.fetchall()
- # revert the status of the connection as before the command
- if (conn_status != _ext.STATUS_IN_TRANSACTION
- and not conn.autocommit):
- conn.rollback()
- if not recs:
- raise psycopg2.ProgrammingError(
- f"PostgreSQL type '{name}' not found")
- type_oid = recs[0][0]
- array_oid = recs[0][1]
- type_attrs = [(r[2], r[3]) for r in recs]
- return self(tname, type_oid, type_attrs,
- array_oid=array_oid, schema=schema)
- def register_composite(name, conn_or_curs, globally=False, factory=None):
- """Register a typecaster to convert a composite type into a tuple.
- :param name: the name of a PostgreSQL composite type, e.g. created using
- the |CREATE TYPE|_ command
- :param conn_or_curs: a connection or cursor used to find the type oid and
- components; the typecaster is registered in a scope limited to this
- object, unless *globally* is set to `!True`
- :param globally: if `!False` (default) register the typecaster only on
- *conn_or_curs*, otherwise register it globally
- :param factory: if specified it should be a `CompositeCaster` subclass: use
- it to :ref:`customize how to cast composite types <custom-composite>`
- :return: the registered `CompositeCaster` or *factory* instance
- responsible for the conversion
- """
- if factory is None:
- factory = CompositeCaster
- caster = factory._from_db(name, conn_or_curs)
- _ext.register_type(caster.typecaster, not globally and conn_or_curs or None)
- if caster.array_typecaster is not None:
- _ext.register_type(
- caster.array_typecaster, not globally and conn_or_curs or None)
- return caster
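- # Editor's sketch (assumed DSN "dbname=test"): a composite type created in
- # the current transaction is returned as a namedtuple after
- # register_composite().
- def _example_register_composite():
-     import psycopg2
-     import psycopg2.extras
-     conn = psycopg2.connect("dbname=test")
-     try:
-         with conn.cursor() as cur:
-             cur.execute("CREATE TYPE card AS (value int, suit text)")
-             psycopg2.extras.register_composite('card', cur)
-             cur.execute("SELECT (8, 'hearts')::card")
-             rec = cur.fetchone()[0]
-             assert rec.value == 8 and rec.suit == 'hearts'
-         conn.rollback()  # discard the throwaway type
-     finally:
-         conn.close()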
- def _paginate(seq, page_size):
- """Consume an iterable and return it in chunks.
- Every chunk contains at most `page_size` items. Never return an empty chunk.
- """
- page = []
- it = iter(seq)
- while True:
- try:
- for i in range(page_size):
- page.append(next(it))
- yield page
- page = []
- except StopIteration:
- if page:
- yield page
- return
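- # Editor's illustration of the private helper above (no database needed);
- # the function name is hypothetical.
- def _example_paginate():
-     assert list(_paginate(range(5), page_size=2)) == [[0, 1], [2, 3], [4]]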
- def execute_batch(cur, sql, argslist, page_size=100):
- r"""Execute groups of statements in fewer server roundtrips.
- Execute *sql* several times, against all parameters set (sequences or
- mappings) found in *argslist*.
- The function is semantically similar to
- .. parsed-literal::
- *cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ )
- but has a different implementation: Psycopg will join the statements into
- fewer multi-statement commands, each one containing at most *page_size*
- statements, resulting in a reduced number of server roundtrips.
- After the execution of the function the `cursor.rowcount` property will
- **not** contain a total result.
- """
- for page in _paginate(argslist, page_size=page_size):
- sqls = [cur.mogrify(sql, args) for args in page]
- cur.execute(b";".join(sqls))
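- # Editor's sketch (assumed DSN "dbname=test"): execute_batch() sends the
- # INSERTs as multi-statement pages instead of one round-trip per row.
- def _example_execute_batch():
-     import psycopg2
-     import psycopg2.extras
-     conn = psycopg2.connect("dbname=test")
-     try:
-         with conn.cursor() as cur:
-             cur.execute("CREATE TEMP TABLE t (id int PRIMARY KEY, val text)")
-             psycopg2.extras.execute_batch(
-                 cur,
-                 "INSERT INTO t (id, val) VALUES (%s, %s)",
-                 [(i, f"row {i}") for i in range(1000)],
-                 page_size=200)
-         conn.commit()
-     finally:
-         conn.close()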
- def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False):
- '''Execute a statement using :sql:`VALUES` with a sequence of parameters.
- :param cur: the cursor to use to execute the query.
- :param sql: the query to execute. It must contain a single ``%s``
- placeholder, which will be replaced by a `VALUES list`__.
- Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``.
- :param argslist: sequence of sequences or dictionaries with the arguments
- to send to the query. The type and content must be consistent with
- *template*.
- :param template: the snippet to merge to every item in *argslist* to
- compose the query.
- - If the *argslist* items are sequences it should contain positional
- placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)"`` if there
- are constant values...).
- - If the *argslist* items are mappings it should contain named
- placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``).
- If not specified, assume the arguments are sequences and use a simple
- positional template (i.e. ``(%s, %s, ...)``), with the number of
- placeholders sniffed by the first element in *argslist*.
- :param page_size: maximum number of *argslist* items to include in every
- statement. If there are more items the function will execute more than
- one statement.
- :param fetch: if `!True` return the query results into a list (like in a
- `~cursor.fetchall()`). Useful for queries with :sql:`RETURNING`
- clause.
- .. __: https://www.postgresql.org/docs/current/static/queries-values.html
- After the execution of the function the `cursor.rowcount` property will
- **not** contain a total result.
- While :sql:`INSERT` is an obvious candidate for this function it is
- possible to use it with other statements, for example::
- >>> cur.execute(
- ... "create table test (id int primary key, v1 int, v2 int)")
- >>> execute_values(cur,
- ... "INSERT INTO test (id, v1, v2) VALUES %s",
- ... [(1, 2, 3), (4, 5, 6), (7, 8, 9)])
- >>> execute_values(cur,
- ... """UPDATE test SET v1 = data.v1 FROM (VALUES %s) AS data (id, v1)
- ... WHERE test.id = data.id""",
- ... [(1, 20), (4, 50)])
- >>> cur.execute("select * from test order by id")
- >>> cur.fetchall()
- [(1, 20, 3), (4, 50, 6), (7, 8, 9)]
- '''
- from psycopg2.sql import Composable
- if isinstance(sql, Composable):
- sql = sql.as_string(cur)
- # we can't just use sql % vals because vals is bytes: if sql is bytes
- # there will be some decoding error because of stupid codec used, and Py3
- # doesn't implement % on bytes.
- if not isinstance(sql, bytes):
- sql = sql.encode(_ext.encodings[cur.connection.encoding])
- pre, post = _split_sql(sql)
- result = [] if fetch else None
- for page in _paginate(argslist, page_size=page_size):
- if template is None:
- template = b'(' + b','.join([b'%s'] * len(page[0])) + b')'
- parts = pre[:]
- for args in page:
- parts.append(cur.mogrify(template, args))
- parts.append(b',')
- parts[-1:] = post
- cur.execute(b''.join(parts))
- if fetch:
- result.extend(cur.fetchall())
- return result
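- # Editor's sketch (assumed DSN "dbname=test"): dict rows with a named
- # placeholder template, plus fetch=True to collect RETURNING values.
- def _example_execute_values():
-     import psycopg2
-     import psycopg2.extras
-     conn = psycopg2.connect("dbname=test")
-     try:
-         with conn.cursor() as cur:
-             cur.execute("CREATE TEMP TABLE t (id int PRIMARY KEY, val text)")
-             ids = psycopg2.extras.execute_values(
-                 cur,
-                 "INSERT INTO t (id, val) VALUES %s RETURNING id",
-                 [{'id': 1, 'val': 'a'}, {'id': 2, 'val': 'b'}],
-                 template="(%(id)s, %(val)s)",
-                 fetch=True)
-             assert sorted(ids) == [(1,), (2,)]
-         conn.commit()
-     finally:
-         conn.close()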
- def _split_sql(sql):
- """Split *sql* on a single ``%s`` placeholder.
- Split on the %s, perform %% replacement and return pre, post lists of
- snippets.
- """
- curr = pre = []
- post = []
- tokens = _re.split(br'(%.)', sql)
- for token in tokens:
- if len(token) != 2 or token[:1] != b'%':
- curr.append(token)
- continue
- if token[1:] == b's':
- if curr is pre:
- curr = post
- else:
- raise ValueError(
- "the query contains more than one '%s' placeholder")
- elif token[1:] == b'%':
- curr.append(b'%')
- else:
- raise ValueError("unsupported format character: '%s'"
- % token[1:].decode('ascii', 'replace'))
- if curr is pre:
- raise ValueError("the query doesn't contain any '%s' placeholder")
- return pre, post