First Commit

This commit is contained in:
2025-11-02 22:52:08 +01:00
commit 73fbbf1be2
5821 changed files with 977526 additions and 0 deletions

View File

@@ -0,0 +1,301 @@
# Copyright (c) 2009, 2020, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""
MySQL Connector/Python - MySQL driver written in Python
"""
try:
import _mysql_connector # pylint: disable=F0401
from .connection_cext import CMySQLConnection
except ImportError:
HAVE_CEXT = False
else:
HAVE_CEXT = True
try:
import dns.resolver
import dns.exception
except ImportError:
HAVE_DNSPYTHON = False
else:
HAVE_DNSPYTHON = True
import random
import warnings
from . import version
from .connection import MySQLConnection
from .constants import DEFAULT_CONFIGURATION
from .errors import ( # pylint: disable=W0622
Error, Warning, InterfaceError, DatabaseError,
NotSupportedError, DataError, IntegrityError, ProgrammingError,
OperationalError, InternalError, custom_error_exception, PoolError)
from .constants import FieldFlag, FieldType, CharacterSet, \
RefreshOption, ClientFlag
from .dbapi import (
Date, Time, Timestamp, Binary, DateFromTicks,
TimestampFromTicks, TimeFromTicks,
STRING, BINARY, NUMBER, DATETIME, ROWID,
apilevel, threadsafety, paramstyle)
from .optionfiles import read_option_files
_CONNECTION_POOLS = {}
ERROR_NO_CEXT = "MySQL Connector/Python C Extension not available"
def _get_pooled_connection(**kwargs):
"""Return a pooled MySQL connection"""
# If no pool name specified, generate one
from .pooling import (
MySQLConnectionPool, generate_pool_name,
CONNECTION_POOL_LOCK)
try:
pool_name = kwargs['pool_name']
except KeyError:
pool_name = generate_pool_name(**kwargs)
if 'use_pure' in kwargs:
if not kwargs['use_pure'] and not HAVE_CEXT:
raise ImportError(ERROR_NO_CEXT)
# Setup the pool, ensuring only 1 thread can update at a time
with CONNECTION_POOL_LOCK:
if pool_name not in _CONNECTION_POOLS:
_CONNECTION_POOLS[pool_name] = MySQLConnectionPool(**kwargs)
elif isinstance(_CONNECTION_POOLS[pool_name], MySQLConnectionPool):
# pool_size must be the same
check_size = _CONNECTION_POOLS[pool_name].pool_size
if ('pool_size' in kwargs
and kwargs['pool_size'] != check_size):
raise PoolError("Size can not be changed "
"for active pools.")
# Return pooled connection
try:
return _CONNECTION_POOLS[pool_name].get_connection()
except AttributeError:
raise InterfaceError(
"Failed getting connection from pool '{0}'".format(pool_name))
def _get_failover_connection(**kwargs):
"""Return a MySQL connection and try to failover if needed
An InterfaceError is raise when no MySQL is available. ValueError is
raised when the failover server configuration contains an illegal
connection argument. Supported arguments are user, password, host, port,
unix_socket and database. ValueError is also raised when the failover
argument was not provided.
Returns MySQLConnection instance.
"""
config = kwargs.copy()
try:
failover = config['failover']
except KeyError:
raise ValueError('failover argument not provided')
del config['failover']
support_cnx_args = set(
['user', 'password', 'host', 'port', 'unix_socket',
'database', 'pool_name', 'pool_size', 'priority'])
# First check if we can add all use the configuration
priority_count = 0
for server in failover:
diff = set(server.keys()) - support_cnx_args
if diff:
raise ValueError(
"Unsupported connection argument {0} in failover: {1}".format(
's' if len(diff) > 1 else '',
', '.join(diff)))
if hasattr(server, "priority"):
priority_count += 1
server["priority"] = server.get("priority", 100)
if server["priority"] < 0 or server["priority"] > 100:
raise InterfaceError(
"Priority value should be in the range of 0 to 100, "
"got : {}".format(server["priority"]))
if not isinstance(server["priority"], int):
raise InterfaceError(
"Priority value should be an integer in the range of 0 to "
"100, got : {}".format(server["priority"]))
if 0 < priority_count < len(failover):
raise ProgrammingError("You must either assign no priority to any "
"of the routers or give a priority for "
"every router")
failover.sort(key=lambda x: x['priority'], reverse=True)
server_directory = {}
server_priority_list = []
for server in failover:
if server["priority"] not in server_directory:
server_directory[server["priority"]] = [server]
server_priority_list.append(server["priority"])
else:
server_directory[server["priority"]].append(server)
for priority in server_priority_list:
failover_list = server_directory[priority]
for _ in range(len(failover_list)):
last = len(failover_list) - 1
index = random.randint(0, last)
server = failover_list.pop(index)
new_config = config.copy()
new_config.update(server)
new_config.pop('priority', None)
try:
return connect(**new_config)
except Error:
# If we failed to connect, we try the next server
pass
raise InterfaceError("Unable to connect to any of the target hosts")
def connect(*args, **kwargs):
"""Create or get a MySQL connection object
In its simpliest form, Connect() will open a connection to a
MySQL server and return a MySQLConnection object.
When any connection pooling arguments are given, for example pool_name
or pool_size, a pool is created or a previously one is used to return
a PooledMySQLConnection.
Returns MySQLConnection or PooledMySQLConnection.
"""
# DNS SRV
dns_srv = kwargs.pop('dns_srv') if 'dns_srv' in kwargs else False
if not isinstance(dns_srv, bool):
raise InterfaceError("The value of 'dns-srv' must be a boolean")
if dns_srv:
if not HAVE_DNSPYTHON:
raise InterfaceError('MySQL host configuration requested DNS '
'SRV. This requires the Python dnspython '
'module. Please refer to documentation')
if 'unix_socket' in kwargs:
raise InterfaceError('Using Unix domain sockets with DNS SRV '
'lookup is not allowed')
if 'port' in kwargs:
raise InterfaceError('Specifying a port number with DNS SRV '
'lookup is not allowed')
if 'failover' in kwargs:
raise InterfaceError('Specifying multiple hostnames with DNS '
'SRV look up is not allowed')
if 'host' not in kwargs:
kwargs['host'] = DEFAULT_CONFIGURATION['host']
try:
srv_records = dns.resolver.query(kwargs['host'], 'SRV')
except dns.exception.DNSException:
raise InterfaceError("Unable to locate any hosts for '{0}'"
"".format(kwargs['host']))
failover = []
for srv in srv_records:
failover.append({
'host': srv.target.to_text(omit_final_dot=True),
'port': srv.port,
'priority': srv.priority,
'weight': srv.weight
})
failover.sort(key=lambda x: (x['priority'], -x['weight']))
kwargs['failover'] = [{'host': srv['host'],
'port': srv['port']} for srv in failover]
# Option files
if 'read_default_file' in kwargs:
kwargs['option_files'] = kwargs['read_default_file']
kwargs.pop('read_default_file')
if 'option_files' in kwargs:
new_config = read_option_files(**kwargs)
return connect(**new_config)
# Failover
if 'failover' in kwargs:
return _get_failover_connection(**kwargs)
# Pooled connections
try:
from .constants import CNX_POOL_ARGS
if any([key in kwargs for key in CNX_POOL_ARGS]):
return _get_pooled_connection(**kwargs)
except NameError:
# No pooling
pass
# Use C Extension by default
use_pure = kwargs.get('use_pure', False)
if 'use_pure' in kwargs:
del kwargs['use_pure'] # Remove 'use_pure' from kwargs
if not use_pure and not HAVE_CEXT:
raise ImportError(ERROR_NO_CEXT)
if HAVE_CEXT and not use_pure:
return CMySQLConnection(*args, **kwargs)
return MySQLConnection(*args, **kwargs)
Connect = connect # pylint: disable=C0103
__version_info__ = version.VERSION
__version__ = version.VERSION_TEXT
__all__ = [
'MySQLConnection', 'Connect', 'custom_error_exception',
# Some useful constants
'FieldType', 'FieldFlag', 'ClientFlag', 'CharacterSet', 'RefreshOption',
'HAVE_CEXT',
# Error handling
'Error', 'Warning',
'InterfaceError', 'DatabaseError',
'NotSupportedError', 'DataError', 'IntegrityError', 'ProgrammingError',
'OperationalError', 'InternalError',
# DBAPI PEP 249 required exports
'connect', 'apilevel', 'threadsafety', 'paramstyle',
'Date', 'Time', 'Timestamp', 'Binary',
'DateFromTicks', 'DateFromTicks', 'TimestampFromTicks', 'TimeFromTicks',
'STRING', 'BINARY', 'NUMBER',
'DATETIME', 'ROWID',
# C Extension
'CMySQLConnection',
]

View File

@@ -0,0 +1,350 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2013, 2019, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
# This file was auto-generated.
_GENERATED_ON = '2019-04-29'
_MYSQL_VERSION = (8, 0, 17)
"""This module contains the MySQL Server Character Sets"""
MYSQL_CHARACTER_SETS = [
# (character set name, collation, default)
None,
("big5", "big5_chinese_ci", True), # 1
("latin2", "latin2_czech_cs", False), # 2
("dec8", "dec8_swedish_ci", True), # 3
("cp850", "cp850_general_ci", True), # 4
("latin1", "latin1_german1_ci", False), # 5
("hp8", "hp8_english_ci", True), # 6
("koi8r", "koi8r_general_ci", True), # 7
("latin1", "latin1_swedish_ci", True), # 8
("latin2", "latin2_general_ci", True), # 9
("swe7", "swe7_swedish_ci", True), # 10
("ascii", "ascii_general_ci", True), # 11
("ujis", "ujis_japanese_ci", True), # 12
("sjis", "sjis_japanese_ci", True), # 13
("cp1251", "cp1251_bulgarian_ci", False), # 14
("latin1", "latin1_danish_ci", False), # 15
("hebrew", "hebrew_general_ci", True), # 16
None,
("tis620", "tis620_thai_ci", True), # 18
("euckr", "euckr_korean_ci", True), # 19
("latin7", "latin7_estonian_cs", False), # 20
("latin2", "latin2_hungarian_ci", False), # 21
("koi8u", "koi8u_general_ci", True), # 22
("cp1251", "cp1251_ukrainian_ci", False), # 23
("gb2312", "gb2312_chinese_ci", True), # 24
("greek", "greek_general_ci", True), # 25
("cp1250", "cp1250_general_ci", True), # 26
("latin2", "latin2_croatian_ci", False), # 27
("gbk", "gbk_chinese_ci", True), # 28
("cp1257", "cp1257_lithuanian_ci", False), # 29
("latin5", "latin5_turkish_ci", True), # 30
("latin1", "latin1_german2_ci", False), # 31
("armscii8", "armscii8_general_ci", True), # 32
("utf8", "utf8_general_ci", True), # 33
("cp1250", "cp1250_czech_cs", False), # 34
("ucs2", "ucs2_general_ci", True), # 35
("cp866", "cp866_general_ci", True), # 36
("keybcs2", "keybcs2_general_ci", True), # 37
("macce", "macce_general_ci", True), # 38
("macroman", "macroman_general_ci", True), # 39
("cp852", "cp852_general_ci", True), # 40
("latin7", "latin7_general_ci", True), # 41
("latin7", "latin7_general_cs", False), # 42
("macce", "macce_bin", False), # 43
("cp1250", "cp1250_croatian_ci", False), # 44
("utf8mb4", "utf8mb4_general_ci", False), # 45
("utf8mb4", "utf8mb4_bin", False), # 46
("latin1", "latin1_bin", False), # 47
("latin1", "latin1_general_ci", False), # 48
("latin1", "latin1_general_cs", False), # 49
("cp1251", "cp1251_bin", False), # 50
("cp1251", "cp1251_general_ci", True), # 51
("cp1251", "cp1251_general_cs", False), # 52
("macroman", "macroman_bin", False), # 53
("utf16", "utf16_general_ci", True), # 54
("utf16", "utf16_bin", False), # 55
("utf16le", "utf16le_general_ci", True), # 56
("cp1256", "cp1256_general_ci", True), # 57
("cp1257", "cp1257_bin", False), # 58
("cp1257", "cp1257_general_ci", True), # 59
("utf32", "utf32_general_ci", True), # 60
("utf32", "utf32_bin", False), # 61
("utf16le", "utf16le_bin", False), # 62
("binary", "binary", True), # 63
("armscii8", "armscii8_bin", False), # 64
("ascii", "ascii_bin", False), # 65
("cp1250", "cp1250_bin", False), # 66
("cp1256", "cp1256_bin", False), # 67
("cp866", "cp866_bin", False), # 68
("dec8", "dec8_bin", False), # 69
("greek", "greek_bin", False), # 70
("hebrew", "hebrew_bin", False), # 71
("hp8", "hp8_bin", False), # 72
("keybcs2", "keybcs2_bin", False), # 73
("koi8r", "koi8r_bin", False), # 74
("koi8u", "koi8u_bin", False), # 75
("utf8", "utf8_tolower_ci", False), # 76
("latin2", "latin2_bin", False), # 77
("latin5", "latin5_bin", False), # 78
("latin7", "latin7_bin", False), # 79
("cp850", "cp850_bin", False), # 80
("cp852", "cp852_bin", False), # 81
("swe7", "swe7_bin", False), # 82
("utf8", "utf8_bin", False), # 83
("big5", "big5_bin", False), # 84
("euckr", "euckr_bin", False), # 85
("gb2312", "gb2312_bin", False), # 86
("gbk", "gbk_bin", False), # 87
("sjis", "sjis_bin", False), # 88
("tis620", "tis620_bin", False), # 89
("ucs2", "ucs2_bin", False), # 90
("ujis", "ujis_bin", False), # 91
("geostd8", "geostd8_general_ci", True), # 92
("geostd8", "geostd8_bin", False), # 93
("latin1", "latin1_spanish_ci", False), # 94
("cp932", "cp932_japanese_ci", True), # 95
("cp932", "cp932_bin", False), # 96
("eucjpms", "eucjpms_japanese_ci", True), # 97
("eucjpms", "eucjpms_bin", False), # 98
("cp1250", "cp1250_polish_ci", False), # 99
None,
("utf16", "utf16_unicode_ci", False), # 101
("utf16", "utf16_icelandic_ci", False), # 102
("utf16", "utf16_latvian_ci", False), # 103
("utf16", "utf16_romanian_ci", False), # 104
("utf16", "utf16_slovenian_ci", False), # 105
("utf16", "utf16_polish_ci", False), # 106
("utf16", "utf16_estonian_ci", False), # 107
("utf16", "utf16_spanish_ci", False), # 108
("utf16", "utf16_swedish_ci", False), # 109
("utf16", "utf16_turkish_ci", False), # 110
("utf16", "utf16_czech_ci", False), # 111
("utf16", "utf16_danish_ci", False), # 112
("utf16", "utf16_lithuanian_ci", False), # 113
("utf16", "utf16_slovak_ci", False), # 114
("utf16", "utf16_spanish2_ci", False), # 115
("utf16", "utf16_roman_ci", False), # 116
("utf16", "utf16_persian_ci", False), # 117
("utf16", "utf16_esperanto_ci", False), # 118
("utf16", "utf16_hungarian_ci", False), # 119
("utf16", "utf16_sinhala_ci", False), # 120
("utf16", "utf16_german2_ci", False), # 121
("utf16", "utf16_croatian_ci", False), # 122
("utf16", "utf16_unicode_520_ci", False), # 123
("utf16", "utf16_vietnamese_ci", False), # 124
None,
None,
None,
("ucs2", "ucs2_unicode_ci", False), # 128
("ucs2", "ucs2_icelandic_ci", False), # 129
("ucs2", "ucs2_latvian_ci", False), # 130
("ucs2", "ucs2_romanian_ci", False), # 131
("ucs2", "ucs2_slovenian_ci", False), # 132
("ucs2", "ucs2_polish_ci", False), # 133
("ucs2", "ucs2_estonian_ci", False), # 134
("ucs2", "ucs2_spanish_ci", False), # 135
("ucs2", "ucs2_swedish_ci", False), # 136
("ucs2", "ucs2_turkish_ci", False), # 137
("ucs2", "ucs2_czech_ci", False), # 138
("ucs2", "ucs2_danish_ci", False), # 139
("ucs2", "ucs2_lithuanian_ci", False), # 140
("ucs2", "ucs2_slovak_ci", False), # 141
("ucs2", "ucs2_spanish2_ci", False), # 142
("ucs2", "ucs2_roman_ci", False), # 143
("ucs2", "ucs2_persian_ci", False), # 144
("ucs2", "ucs2_esperanto_ci", False), # 145
("ucs2", "ucs2_hungarian_ci", False), # 146
("ucs2", "ucs2_sinhala_ci", False), # 147
("ucs2", "ucs2_german2_ci", False), # 148
("ucs2", "ucs2_croatian_ci", False), # 149
("ucs2", "ucs2_unicode_520_ci", False), # 150
("ucs2", "ucs2_vietnamese_ci", False), # 151
None,
None,
None,
None,
None,
None,
None,
("ucs2", "ucs2_general_mysql500_ci", False), # 159
("utf32", "utf32_unicode_ci", False), # 160
("utf32", "utf32_icelandic_ci", False), # 161
("utf32", "utf32_latvian_ci", False), # 162
("utf32", "utf32_romanian_ci", False), # 163
("utf32", "utf32_slovenian_ci", False), # 164
("utf32", "utf32_polish_ci", False), # 165
("utf32", "utf32_estonian_ci", False), # 166
("utf32", "utf32_spanish_ci", False), # 167
("utf32", "utf32_swedish_ci", False), # 168
("utf32", "utf32_turkish_ci", False), # 169
("utf32", "utf32_czech_ci", False), # 170
("utf32", "utf32_danish_ci", False), # 171
("utf32", "utf32_lithuanian_ci", False), # 172
("utf32", "utf32_slovak_ci", False), # 173
("utf32", "utf32_spanish2_ci", False), # 174
("utf32", "utf32_roman_ci", False), # 175
("utf32", "utf32_persian_ci", False), # 176
("utf32", "utf32_esperanto_ci", False), # 177
("utf32", "utf32_hungarian_ci", False), # 178
("utf32", "utf32_sinhala_ci", False), # 179
("utf32", "utf32_german2_ci", False), # 180
("utf32", "utf32_croatian_ci", False), # 181
("utf32", "utf32_unicode_520_ci", False), # 182
("utf32", "utf32_vietnamese_ci", False), # 183
None,
None,
None,
None,
None,
None,
None,
None,
("utf8", "utf8_unicode_ci", False), # 192
("utf8", "utf8_icelandic_ci", False), # 193
("utf8", "utf8_latvian_ci", False), # 194
("utf8", "utf8_romanian_ci", False), # 195
("utf8", "utf8_slovenian_ci", False), # 196
("utf8", "utf8_polish_ci", False), # 197
("utf8", "utf8_estonian_ci", False), # 198
("utf8", "utf8_spanish_ci", False), # 199
("utf8", "utf8_swedish_ci", False), # 200
("utf8", "utf8_turkish_ci", False), # 201
("utf8", "utf8_czech_ci", False), # 202
("utf8", "utf8_danish_ci", False), # 203
("utf8", "utf8_lithuanian_ci", False), # 204
("utf8", "utf8_slovak_ci", False), # 205
("utf8", "utf8_spanish2_ci", False), # 206
("utf8", "utf8_roman_ci", False), # 207
("utf8", "utf8_persian_ci", False), # 208
("utf8", "utf8_esperanto_ci", False), # 209
("utf8", "utf8_hungarian_ci", False), # 210
("utf8", "utf8_sinhala_ci", False), # 211
("utf8", "utf8_german2_ci", False), # 212
("utf8", "utf8_croatian_ci", False), # 213
("utf8", "utf8_unicode_520_ci", False), # 214
("utf8", "utf8_vietnamese_ci", False), # 215
None,
None,
None,
None,
None,
None,
None,
("utf8", "utf8_general_mysql500_ci", False), # 223
("utf8mb4", "utf8mb4_unicode_ci", False), # 224
("utf8mb4", "utf8mb4_icelandic_ci", False), # 225
("utf8mb4", "utf8mb4_latvian_ci", False), # 226
("utf8mb4", "utf8mb4_romanian_ci", False), # 227
("utf8mb4", "utf8mb4_slovenian_ci", False), # 228
("utf8mb4", "utf8mb4_polish_ci", False), # 229
("utf8mb4", "utf8mb4_estonian_ci", False), # 230
("utf8mb4", "utf8mb4_spanish_ci", False), # 231
("utf8mb4", "utf8mb4_swedish_ci", False), # 232
("utf8mb4", "utf8mb4_turkish_ci", False), # 233
("utf8mb4", "utf8mb4_czech_ci", False), # 234
("utf8mb4", "utf8mb4_danish_ci", False), # 235
("utf8mb4", "utf8mb4_lithuanian_ci", False), # 236
("utf8mb4", "utf8mb4_slovak_ci", False), # 237
("utf8mb4", "utf8mb4_spanish2_ci", False), # 238
("utf8mb4", "utf8mb4_roman_ci", False), # 239
("utf8mb4", "utf8mb4_persian_ci", False), # 240
("utf8mb4", "utf8mb4_esperanto_ci", False), # 241
("utf8mb4", "utf8mb4_hungarian_ci", False), # 242
("utf8mb4", "utf8mb4_sinhala_ci", False), # 243
("utf8mb4", "utf8mb4_german2_ci", False), # 244
("utf8mb4", "utf8mb4_croatian_ci", False), # 245
("utf8mb4", "utf8mb4_unicode_520_ci", False), # 246
("utf8mb4", "utf8mb4_vietnamese_ci", False), # 247
("gb18030", "gb18030_chinese_ci", True), # 248
("gb18030", "gb18030_bin", False), # 249
("gb18030", "gb18030_unicode_520_ci", False), # 250
None,
None,
None,
None,
("utf8mb4", "utf8mb4_0900_ai_ci", True), # 255
("utf8mb4", "utf8mb4_de_pb_0900_ai_ci", False), # 256
("utf8mb4", "utf8mb4_is_0900_ai_ci", False), # 257
("utf8mb4", "utf8mb4_lv_0900_ai_ci", False), # 258
("utf8mb4", "utf8mb4_ro_0900_ai_ci", False), # 259
("utf8mb4", "utf8mb4_sl_0900_ai_ci", False), # 260
("utf8mb4", "utf8mb4_pl_0900_ai_ci", False), # 261
("utf8mb4", "utf8mb4_et_0900_ai_ci", False), # 262
("utf8mb4", "utf8mb4_es_0900_ai_ci", False), # 263
("utf8mb4", "utf8mb4_sv_0900_ai_ci", False), # 264
("utf8mb4", "utf8mb4_tr_0900_ai_ci", False), # 265
("utf8mb4", "utf8mb4_cs_0900_ai_ci", False), # 266
("utf8mb4", "utf8mb4_da_0900_ai_ci", False), # 267
("utf8mb4", "utf8mb4_lt_0900_ai_ci", False), # 268
("utf8mb4", "utf8mb4_sk_0900_ai_ci", False), # 269
("utf8mb4", "utf8mb4_es_trad_0900_ai_ci", False), # 270
("utf8mb4", "utf8mb4_la_0900_ai_ci", False), # 271
None,
("utf8mb4", "utf8mb4_eo_0900_ai_ci", False), # 273
("utf8mb4", "utf8mb4_hu_0900_ai_ci", False), # 274
("utf8mb4", "utf8mb4_hr_0900_ai_ci", False), # 275
None,
("utf8mb4", "utf8mb4_vi_0900_ai_ci", False), # 277
("utf8mb4", "utf8mb4_0900_as_cs", False), # 278
("utf8mb4", "utf8mb4_de_pb_0900_as_cs", False), # 279
("utf8mb4", "utf8mb4_is_0900_as_cs", False), # 280
("utf8mb4", "utf8mb4_lv_0900_as_cs", False), # 281
("utf8mb4", "utf8mb4_ro_0900_as_cs", False), # 282
("utf8mb4", "utf8mb4_sl_0900_as_cs", False), # 283
("utf8mb4", "utf8mb4_pl_0900_as_cs", False), # 284
("utf8mb4", "utf8mb4_et_0900_as_cs", False), # 285
("utf8mb4", "utf8mb4_es_0900_as_cs", False), # 286
("utf8mb4", "utf8mb4_sv_0900_as_cs", False), # 287
("utf8mb4", "utf8mb4_tr_0900_as_cs", False), # 288
("utf8mb4", "utf8mb4_cs_0900_as_cs", False), # 289
("utf8mb4", "utf8mb4_da_0900_as_cs", False), # 290
("utf8mb4", "utf8mb4_lt_0900_as_cs", False), # 291
("utf8mb4", "utf8mb4_sk_0900_as_cs", False), # 292
("utf8mb4", "utf8mb4_es_trad_0900_as_cs", False), # 293
("utf8mb4", "utf8mb4_la_0900_as_cs", False), # 294
None,
("utf8mb4", "utf8mb4_eo_0900_as_cs", False), # 296
("utf8mb4", "utf8mb4_hu_0900_as_cs", False), # 297
("utf8mb4", "utf8mb4_hr_0900_as_cs", False), # 298
None,
("utf8mb4", "utf8mb4_vi_0900_as_cs", False), # 300
None,
None,
("utf8mb4", "utf8mb4_ja_0900_as_cs", False), # 303
("utf8mb4", "utf8mb4_ja_0900_as_cs_ks", False), # 304
("utf8mb4", "utf8mb4_0900_as_ci", False), # 305
("utf8mb4", "utf8mb4_ru_0900_ai_ci", False), # 306
("utf8mb4", "utf8mb4_ru_0900_as_cs", False), # 307
("utf8mb4", "utf8mb4_zh_0900_as_cs", False), # 308
("utf8mb4", "utf8mb4_0900_bin", False), # 309
]

View File

@@ -0,0 +1,840 @@
# Copyright (c) 2014, 2021, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Connection class using the C Extension
"""
# Detection of abstract methods in pylint is not working correctly
#pylint: disable=W0223
import os
import socket
import sysconfig
from . import errors, version
from .constants import (
CharacterSet, FieldFlag, ServerFlag, ShutdownType, ClientFlag
)
from .abstracts import MySQLConnectionAbstract, MySQLCursorAbstract
from .protocol import MySQLProtocol
HAVE_CMYSQL = False
# pylint: disable=F0401,C0413
try:
import _mysql_connector
from .cursor_cext import (
CMySQLCursor, CMySQLCursorRaw,
CMySQLCursorBuffered, CMySQLCursorBufferedRaw, CMySQLCursorPrepared,
CMySQLCursorDict, CMySQLCursorBufferedDict, CMySQLCursorNamedTuple,
CMySQLCursorBufferedNamedTuple)
from _mysql_connector import MySQLInterfaceError # pylint: disable=F0401
except ImportError as exc:
raise ImportError(
"MySQL Connector/Python C Extension not available ({0})".format(
str(exc)
))
else:
HAVE_CMYSQL = True
# pylint: enable=F0401,C0413
class CMySQLConnection(MySQLConnectionAbstract):
"""Class initiating a MySQL Connection using Connector/C"""
def __init__(self, **kwargs):
"""Initialization"""
if not HAVE_CMYSQL:
raise RuntimeError(
"MySQL Connector/Python C Extension not available")
self._cmysql = None
self._columns = []
self._plugin_dir = os.path.join(
os.path.dirname(os.path.abspath(_mysql_connector.__file__)),
"mysql", "vendor", "plugin"
)
self.converter = None
super(CMySQLConnection, self).__init__(**kwargs)
if kwargs:
self.connect(**kwargs)
def _add_default_conn_attrs(self):
"""Add default connection attributes"""
license_chunks = version.LICENSE.split(" ")
if license_chunks[0] == "GPLv2":
client_license = "GPL-2.0"
else:
client_license = "Commercial"
self._conn_attrs.update({
"_connector_name": "mysql-connector-python",
"_connector_license": client_license,
"_connector_version": ".".join(
[str(x) for x in version.VERSION[0:3]]),
"_source_host": socket.gethostname()
})
def _do_handshake(self):
"""Gather information of the MySQL server before authentication"""
self._handshake = {
'protocol': self._cmysql.get_proto_info(),
'server_version_original': self._cmysql.get_server_info(),
'server_threadid': self._cmysql.thread_id(),
'charset': None,
'server_status': None,
'auth_plugin': None,
'auth_data': None,
'capabilities': self._cmysql.st_server_capabilities(),
}
self._server_version = self._check_server_version(
self._handshake['server_version_original']
)
@property
def _server_status(self):
"""Returns the server status attribute of MYSQL structure"""
return self._cmysql.st_server_status()
def set_allow_local_infile_in_path(self, path):
"""set local_infile_in_path
Set allow_local_infile_in_path.
"""
if self._cmysql:
self._cmysql.set_load_data_local_infile_option(path)
def set_unicode(self, value=True):
"""Toggle unicode mode
Set whether we return string fields as unicode or not.
Default is True.
"""
self._use_unicode = value
if self._cmysql:
self._cmysql.use_unicode(value)
if self.converter:
self.converter.set_unicode(value)
@property
def autocommit(self):
"""Get whether autocommit is on or off"""
value = self.info_query("SELECT @@session.autocommit")[0]
return True if value == 1 else False
@autocommit.setter
def autocommit(self, value): # pylint: disable=W0221
"""Toggle autocommit"""
try:
self._cmysql.autocommit(value)
self._autocommit = value
except MySQLInterfaceError as exc:
raise errors.get_mysql_exception(msg=exc.msg, errno=exc.errno,
sqlstate=exc.sqlstate)
@property
def database(self):
"""Get the current database"""
return self.info_query("SELECT DATABASE()")[0]
@database.setter
def database(self, value): # pylint: disable=W0221
"""Set the current database"""
self._cmysql.select_db(value)
@property
def in_transaction(self):
"""MySQL session has started a transaction"""
return self._server_status & ServerFlag.STATUS_IN_TRANS
def _open_connection(self):
charset_name = CharacterSet.get_info(self._charset_id)[0]
self._cmysql = _mysql_connector.MySQL( # pylint: disable=E1101,I1101
buffered=self._buffered,
raw=self._raw,
charset_name=charset_name,
connection_timeout=(self._connection_timeout or 0),
use_unicode=self._use_unicode,
auth_plugin=self._auth_plugin,
plugin_dir=self._plugin_dir)
if not self.isset_client_flag(ClientFlag.CONNECT_ARGS):
self._conn_attrs = {}
cnx_kwargs = {
'host': self._host,
'user': self._user,
'password': self._password,
'password1': self._password1,
'password2': self._password2,
'password3': self._password3,
'database': self._database,
'port': self._port,
'client_flags': self._client_flags,
'unix_socket': self._unix_socket,
'compress': self.isset_client_flag(ClientFlag.COMPRESS),
'ssl_disabled': True,
"conn_attrs": self._conn_attrs,
"local_infile": self._allow_local_infile,
"load_data_local_dir": self._allow_local_infile_in_path,
"oci_config_file": self._oci_config_file,
}
tls_versions = self._ssl.get('tls_versions')
if tls_versions is not None:
tls_versions.sort(reverse=True)
tls_versions = ",".join(tls_versions)
if self._ssl.get('tls_ciphersuites') is not None:
ssl_ciphersuites = self._ssl.get('tls_ciphersuites')[0]
tls_ciphersuites = self._ssl.get('tls_ciphersuites')[1]
else:
ssl_ciphersuites = None
tls_ciphersuites = None
if tls_versions is not None and "TLSv1.3" in tls_versions and \
not tls_ciphersuites:
tls_ciphersuites = "TLS_AES_256_GCM_SHA384"
if not self._ssl_disabled:
cnx_kwargs.update({
'ssl_ca': self._ssl.get('ca'),
'ssl_cert': self._ssl.get('cert'),
'ssl_key': self._ssl.get('key'),
'ssl_cipher_suites': ssl_ciphersuites,
'tls_versions': tls_versions,
'tls_cipher_suites': tls_ciphersuites,
'ssl_verify_cert': self._ssl.get('verify_cert') or False,
'ssl_verify_identity':
self._ssl.get('verify_identity') or False,
'ssl_disabled': self._ssl_disabled
})
try:
self._cmysql.connect(**cnx_kwargs)
self._cmysql.converter_str_fallback = self._converter_str_fallback
if self.converter:
self.converter.str_fallback = self._converter_str_fallback
except MySQLInterfaceError as exc:
raise errors.get_mysql_exception(msg=exc.msg, errno=exc.errno,
sqlstate=exc.sqlstate)
self._do_handshake()
def close(self):
"""Disconnect from the MySQL server"""
if self._cmysql:
try:
self.free_result()
self._cmysql.close()
except MySQLInterfaceError as exc:
raise errors.get_mysql_exception(msg=exc.msg, errno=exc.errno,
sqlstate=exc.sqlstate)
disconnect = close
def is_closed(self):
"""Return True if the connection to MySQL Server is closed."""
return not self._cmysql.connected()
def is_connected(self):
"""Reports whether the connection to MySQL Server is available"""
if self._cmysql:
return self._cmysql.ping()
return False
def ping(self, reconnect=False, attempts=1, delay=0):
"""Check availability of the MySQL server
When reconnect is set to True, one or more attempts are made to try
to reconnect to the MySQL server using the reconnect()-method.
delay is the number of seconds to wait between each retry.
When the connection is not available, an InterfaceError is raised. Use
the is_connected()-method if you just want to check the connection
without raising an error.
Raises InterfaceError on errors.
"""
errmsg = "Connection to MySQL is not available"
try:
connected = self._cmysql.ping()
except AttributeError:
pass # Raise or reconnect later
else:
if connected:
return
if reconnect:
self.reconnect(attempts=attempts, delay=delay)
else:
raise errors.InterfaceError(errmsg)
def set_character_set_name(self, charset):
"""Sets the default character set name for current connection.
"""
self._cmysql.set_character_set(charset)
def info_query(self, query):
"""Send a query which only returns 1 row"""
self._cmysql.query(query)
first_row = ()
if self._cmysql.have_result_set:
first_row = self._cmysql.fetch_row()
if self._cmysql.fetch_row():
self._cmysql.free_result()
raise errors.InterfaceError(
"Query should not return more than 1 row")
self._cmysql.free_result()
return first_row
@property
def connection_id(self):
"""MySQL connection ID"""
try:
return self._cmysql.thread_id()
except MySQLInterfaceError:
pass # Just return None
return None
def get_rows(self, count=None, binary=False, columns=None, raw=None,
prep_stmt=None):
"""Get all or a subset of rows returned by the MySQL server"""
unread_result = prep_stmt.have_result_set if prep_stmt \
else self.unread_result
if not (self._cmysql and unread_result):
raise errors.InternalError("No result set available")
if raw is None:
raw = self._raw
rows = []
if count is not None and count <= 0:
raise AttributeError("count should be 1 or higher, or None")
counter = 0
try:
fetch_row = (
prep_stmt.fetch_row if prep_stmt
else self._cmysql.fetch_row
)
if self.converter:
# When using a converter class, the C extension should not
# convert the values. This can be accomplished by setting
# the raw option to True.
self._cmysql.raw(True)
row = fetch_row()
while row:
if not self._raw and self.converter:
row = list(row)
for i, _ in enumerate(row):
if not raw:
row[i] = self.converter.to_python(self._columns[i],
row[i])
row = tuple(row)
rows.append(row)
counter += 1
if count and counter == count:
break
row = fetch_row()
if not row:
_eof = self.fetch_eof_columns(prep_stmt)['eof']
if prep_stmt:
prep_stmt.free_result()
self._unread_result = False
else:
self.free_result()
else:
_eof = None
except MySQLInterfaceError as exc:
if prep_stmt:
prep_stmt.free_result()
raise errors.InterfaceError(str(exc))
else:
self.free_result()
raise errors.get_mysql_exception(msg=exc.msg, errno=exc.errno,
sqlstate=exc.sqlstate)
return rows, _eof
def get_row(self, binary=False, columns=None, raw=None, prep_stmt=None):
"""Get the next rows returned by the MySQL server"""
try:
rows, eof = self.get_rows(count=1, binary=binary, columns=columns,
raw=raw, prep_stmt=prep_stmt)
if rows:
return (rows[0], eof)
return (None, eof)
except IndexError:
# No row available
return (None, None)
def next_result(self):
"""Reads the next result"""
if self._cmysql:
self._cmysql.consume_result()
return self._cmysql.next_result()
return None
def free_result(self):
"""Frees the result"""
if self._cmysql:
self._cmysql.free_result()
def commit(self):
"""Commit current transaction"""
if self._cmysql:
self._cmysql.commit()
def rollback(self):
"""Rollback current transaction"""
if self._cmysql:
self._cmysql.consume_result()
self._cmysql.rollback()
def cmd_init_db(self, database):
"""Change the current database"""
try:
self._cmysql.select_db(database)
except MySQLInterfaceError as exc:
raise errors.get_mysql_exception(msg=exc.msg, errno=exc.errno,
sqlstate=exc.sqlstate)
def fetch_eof_columns(self, prep_stmt=None):
"""Fetch EOF and column information"""
have_result_set = prep_stmt.have_result_set if prep_stmt \
else self._cmysql.have_result_set
if not have_result_set:
raise errors.InterfaceError("No result set")
fields = prep_stmt.fetch_fields() if prep_stmt \
else self._cmysql.fetch_fields()
self._columns = []
for col in fields:
self._columns.append((
col[4],
int(col[8]),
None,
None,
None,
None,
~int(col[9]) & FieldFlag.NOT_NULL,
int(col[9]),
int(col[6]),
))
return {
'eof': {
'status_flag': self._server_status,
'warning_count': self._cmysql.st_warning_count(),
},
'columns': self._columns,
}
def fetch_eof_status(self):
"""Fetch EOF and status information"""
if self._cmysql:
return {
'warning_count': self._cmysql.st_warning_count(),
'field_count': self._cmysql.st_field_count(),
'insert_id': self._cmysql.insert_id(),
'affected_rows': self._cmysql.affected_rows(),
'server_status': self._server_status,
}
return None
def cmd_stmt_prepare(self, statement):
"""Prepares the SQL statement"""
if not self._cmysql:
raise errors.OperationalError("MySQL Connection not available")
try:
stmt = self._cmysql.stmt_prepare(statement)
stmt.converter_str_fallback = self._converter_str_fallback
return stmt
except MySQLInterfaceError as err:
raise errors.InterfaceError(str(err))
# pylint: disable=W0221
def cmd_stmt_execute(self, prep_stmt, *args):
"""Executes the prepared statement"""
try:
prep_stmt.stmt_execute(*args)
except MySQLInterfaceError as err:
raise errors.InterfaceError(str(err))
self._columns = []
if not prep_stmt.have_result_set:
# No result
self._unread_result = False
return self.fetch_eof_status()
self._unread_result = True
return self.fetch_eof_columns(prep_stmt)
def cmd_stmt_close(self, prep_stmt):
"""Closes the prepared statement"""
if self._unread_result:
raise errors.InternalError("Unread result found")
prep_stmt.stmt_close()
def cmd_stmt_reset(self, prep_stmt):
"""Resets the prepared statement"""
if self._unread_result:
raise errors.InternalError("Unread result found")
prep_stmt.stmt_reset()
# pylint: enable=W0221
def cmd_query(self, query, raw=None, buffered=False, raw_as_string=False):
"""Send a query to the MySQL server"""
self.handle_unread_result()
if raw is None:
raw = self._raw
try:
if not isinstance(query, bytes):
query = query.encode('utf-8')
self._cmysql.query(query,
raw=raw, buffered=buffered,
raw_as_string=raw_as_string,
query_attrs=self._query_attrs)
except MySQLInterfaceError as exc:
raise errors.get_mysql_exception(exc.errno, msg=exc.msg,
sqlstate=exc.sqlstate)
except AttributeError:
if self._unix_socket:
addr = self._unix_socket
else:
addr = self._host + ':' + str(self._port)
raise errors.OperationalError(
errno=2055, values=(addr, 'Connection not available.'))
self._columns = []
if not self._cmysql.have_result_set:
# No result
return self.fetch_eof_status()
return self.fetch_eof_columns()
_execute_query = cmd_query
def cursor(self, buffered=None, raw=None, prepared=None, cursor_class=None,
dictionary=None, named_tuple=None):
"""Instantiates and returns a cursor using C Extension
By default, CMySQLCursor is returned. Depending on the options
while connecting, a buffered and/or raw cursor is instantiated
instead. Also depending upon the cursor options, rows can be
returned as dictionary or named tuple.
Dictionary and namedtuple based cursors are available with buffered
output but not raw.
It is possible to also give a custom cursor through the
cursor_class parameter, but it needs to be a subclass of
mysql.connector.cursor_cext.CMySQLCursor.
Raises ProgrammingError when cursor_class is not a subclass of
CursorBase. Raises ValueError when cursor is not available.
Returns instance of CMySQLCursor or subclass.
:param buffered: Return a buffering cursor
:param raw: Return a raw cursor
:param prepared: Return a cursor which uses prepared statements
:param cursor_class: Use a custom cursor class
:param dictionary: Rows are returned as dictionary
:param named_tuple: Rows are returned as named tuple
:return: Subclass of CMySQLCursor
:rtype: CMySQLCursor or subclass
"""
self.handle_unread_result(prepared)
if not self.is_connected():
raise errors.OperationalError("MySQL Connection not available.")
if cursor_class is not None:
if not issubclass(cursor_class, MySQLCursorAbstract):
raise errors.ProgrammingError(
"Cursor class needs be to subclass"
" of cursor_cext.CMySQLCursor")
return (cursor_class)(self)
buffered = buffered or self._buffered
raw = raw or self._raw
cursor_type = 0
if buffered is True:
cursor_type |= 1
if raw is True:
cursor_type |= 2
if dictionary is True:
cursor_type |= 4
if named_tuple is True:
cursor_type |= 8
if prepared is True:
cursor_type |= 16
types = {
0: CMySQLCursor, # 0
1: CMySQLCursorBuffered,
2: CMySQLCursorRaw,
3: CMySQLCursorBufferedRaw,
4: CMySQLCursorDict,
5: CMySQLCursorBufferedDict,
8: CMySQLCursorNamedTuple,
9: CMySQLCursorBufferedNamedTuple,
16: CMySQLCursorPrepared
}
try:
return (types[cursor_type])(self)
except KeyError:
args = ('buffered', 'raw', 'dictionary', 'named_tuple', 'prepared')
raise ValueError('Cursor not available with given criteria: ' +
', '.join([args[i] for i in range(5)
if cursor_type & (1 << i) != 0]))
@property
def num_rows(self):
"""Returns number of rows of current result set"""
if not self._cmysql.have_result_set:
raise errors.InterfaceError("No result set")
return self._cmysql.num_rows()
@property
def warning_count(self):
"""Returns number of warnings"""
if not self._cmysql:
return 0
return self._cmysql.warning_count()
@property
def result_set_available(self):
"""Check if a result set is available"""
if not self._cmysql:
return False
return self._cmysql.have_result_set
@property
def unread_result(self):
"""Check if there are unread results or rows"""
return self.result_set_available
@property
def more_results(self):
"""Check if there are more results"""
return self._cmysql.more_results()
def prepare_for_mysql(self, params):
"""Prepare parameters for statements
This method is use by cursors to prepared parameters found in the
list (or tuple) params.
Returns dict.
"""
if isinstance(params, (list, tuple)):
if self.converter:
result = [
self.converter.quote(
self.converter.escape(
self.converter.to_mysql(value)
)
) for value in params
]
else:
result = self._cmysql.convert_to_mysql(*params)
elif isinstance(params, dict):
result = {}
if self.converter:
for key, value in params.items():
result[key] = self.converter.quote(
self.converter.escape(
self.converter.to_mysql(value)
)
)
else:
for key, value in params.items():
result[key] = self._cmysql.convert_to_mysql(value)[0]
else:
raise errors.ProgrammingError(
f"Could not process parameters: {type(params).__name__}({params}),"
" it must be of type list, tuple or dict")
return result
def consume_results(self):
"""Consume the current result
This method consume the result by reading (consuming) all rows.
"""
self._cmysql.consume_result()
def cmd_change_user(self, username='', password='', database='',
charset=45, password1='', password2='', password3='',
oci_config_file=None):
"""Change the current logged in user"""
try:
self._cmysql.change_user(
username,
password,
database,
password1,
password2,
password3,
oci_config_file)
except MySQLInterfaceError as exc:
raise errors.get_mysql_exception(msg=exc.msg, errno=exc.errno,
sqlstate=exc.sqlstate)
self._charset_id = charset
self._post_connection()
def cmd_reset_connection(self):
"""Resets the session state without re-authenticating
Works only for MySQL server 5.7.3 or later.
"""
if self._server_version < (5, 7, 3):
raise errors.NotSupportedError("MySQL version 5.7.2 and "
"earlier does not support "
"COM_RESET_CONNECTION.")
try:
self._cmysql.reset_connection()
except MySQLInterfaceError as exc:
raise errors.get_mysql_exception(msg=exc.msg, errno=exc.errno,
sqlstate=exc.sqlstate)
self._post_connection()
def cmd_refresh(self, options):
"""Send the Refresh command to the MySQL server"""
try:
self._cmysql.refresh(options)
except MySQLInterfaceError as exc:
raise errors.get_mysql_exception(msg=exc.msg, errno=exc.errno,
sqlstate=exc.sqlstate)
return self.fetch_eof_status()
def cmd_quit(self):
"""Close the current connection with the server"""
self.close()
def cmd_shutdown(self, shutdown_type=None):
"""Shut down the MySQL Server"""
if not self._cmysql:
raise errors.OperationalError("MySQL Connection not available")
if shutdown_type:
if not ShutdownType.get_info(shutdown_type):
raise errors.InterfaceError("Invalid shutdown type")
level = shutdown_type
else:
level = ShutdownType.SHUTDOWN_DEFAULT
try:
self._cmysql.shutdown(level)
except MySQLInterfaceError as exc:
raise errors.get_mysql_exception(msg=exc.msg, errno=exc.errno,
sqlstate=exc.sqlstate)
self.close()
def cmd_statistics(self):
"""Return statistics from the MySQL server"""
self.handle_unread_result()
try:
stat = self._cmysql.stat()
return MySQLProtocol().parse_statistics(stat, with_header=False)
except (MySQLInterfaceError, errors.InterfaceError) as exc:
raise errors.get_mysql_exception(msg=exc.msg, errno=exc.errno,
sqlstate=exc.sqlstate)
def cmd_process_kill(self, mysql_pid):
"""Kill a MySQL process"""
if not isinstance(mysql_pid, int):
raise ValueError("MySQL PID must be int")
self.info_query("KILL {0}".format(mysql_pid))
def handle_unread_result(self, prepared=False):
"""Check whether there is an unread result"""
unread_result = self._unread_result if prepared is True \
else self.unread_result
if self.can_consume_results:
self.consume_results()
elif unread_result:
raise errors.InternalError("Unread result found")
def reset_session(self, user_variables=None, session_variables=None):
"""Clears the current active session
This method resets the session state, if the MySQL server is 5.7.3
or later active session will be reset without re-authenticating.
For other server versions session will be reset by re-authenticating.
It is possible to provide a sequence of variables and their values to
be set after clearing the session. This is possible for both user
defined variables and session variables.
This method takes two arguments user_variables and session_variables
which are dictionaries.
Raises OperationalError if not connected, InternalError if there are
unread results and InterfaceError on errors.
"""
if not self.is_connected():
raise errors.OperationalError("MySQL Connection not available.")
try:
self.cmd_reset_connection()
except (errors.NotSupportedError, NotImplementedError):
if self._compress:
raise errors.NotSupportedError(
"Reset session is not supported with compression for "
"MySQL server version 5.7.2 or earlier.")
elif self._server_version < (5, 7, 3):
raise errors.NotSupportedError(
"Reset session is not supported with MySQL server "
"version 5.7.2 or earlier.")
else:
self.cmd_change_user(self._user, self._password,
self._database, self._charset_id,
self._password1, self._password2,
self._password3, self._oci_config_file)
if user_variables or session_variables:
cur = self.cursor()
if user_variables:
for key, value in user_variables.items():
cur.execute("SET @`{0}` = %s".format(key), (value,))
if session_variables:
for key, value in session_variables.items():
cur.execute("SET SESSION `{0}` = %s".format(key), (value,))
cur.close()

View File

@@ -0,0 +1,614 @@
# Copyright (c) 2009, 2021, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Converting MySQL and Python types
"""
import datetime
import struct
import time
from decimal import Decimal
from .constants import FieldType, FieldFlag, CharacterSet
from .utils import NUMERIC_TYPES
from .custom_types import HexLiteral
CONVERT_ERROR = "Could not convert '{value}' to python {pytype}"
class MySQLConverterBase(object):
"""Base class for conversion classes
All class dealing with converting to and from MySQL data types must
be a subclass of this class.
"""
def __init__(self, charset='utf8', use_unicode=True, str_fallback=False):
self.python_types = None
self.mysql_types = None
self.charset = None
self.charset_id = 0
self.use_unicode = None
self.set_charset(charset)
self.use_unicode = use_unicode
self.str_fallback = str_fallback
self._cache_field_types = {}
def set_charset(self, charset):
"""Set character set"""
if charset == 'utf8mb4':
charset = 'utf8'
if charset is not None:
self.charset = charset
else:
# default to utf8
self.charset = 'utf8'
self.charset_id = CharacterSet.get_charset_info(self.charset)[0]
def set_unicode(self, value=True):
"""Set whether to use Unicode"""
self.use_unicode = value
def to_mysql(self, value):
"""Convert Python data type to MySQL"""
type_name = value.__class__.__name__.lower()
try:
return getattr(self, "_{0}_to_mysql".format(type_name))(value)
except AttributeError:
return value
def to_python(self, vtype, value):
"""Convert MySQL data type to Python"""
if (value == b'\x00' or value is None) and vtype[1] != FieldType.BIT:
# Don't go further when we hit a NULL value
return None
if not self._cache_field_types:
self._cache_field_types = {}
for name, info in FieldType.desc.items():
try:
self._cache_field_types[info[0]] = getattr(
self, '_{0}_to_python'.format(name))
except AttributeError:
# We ignore field types which has no method
pass
try:
return self._cache_field_types[vtype[1]](value, vtype)
except KeyError:
return value
def escape(self, value):
"""Escape buffer for sending to MySQL"""
return value
def quote(self, buf):
"""Quote buffer for sending to MySQL"""
return str(buf)
class MySQLConverter(MySQLConverterBase):
"""Default conversion class for MySQL Connector/Python.
o escape method: for escaping values send to MySQL
o quoting method: for quoting values send to MySQL in statements
o conversion mapping: maps Python and MySQL data types to
function for converting them.
Whenever one needs to convert values differently, a converter_class
argument can be given while instantiating a new connection like
cnx.connect(converter_class=CustomMySQLConverterClass).
"""
def __init__(self, charset=None, use_unicode=True, str_fallback=False):
MySQLConverterBase.__init__(self, charset, use_unicode, str_fallback)
self._cache_field_types = {}
def escape(self, value):
"""
Escapes special characters as they are expected to by when MySQL
receives them.
As found in MySQL source mysys/charset.c
Returns the value if not a string, or the escaped string.
"""
if value is None:
return value
elif isinstance(value, NUMERIC_TYPES):
return value
if isinstance(value, (bytes, bytearray)):
value = value.replace(b'\\', b'\\\\')
value = value.replace(b'\n', b'\\n')
value = value.replace(b'\r', b'\\r')
value = value.replace(b'\047', b'\134\047') # single quotes
value = value.replace(b'\042', b'\134\042') # double quotes
value = value.replace(b'\032', b'\134\032') # for Win32
else:
value = value.replace('\\', '\\\\')
value = value.replace('\n', '\\n')
value = value.replace('\r', '\\r')
value = value.replace('\047', '\134\047') # single quotes
value = value.replace('\042', '\134\042') # double quotes
value = value.replace('\032', '\134\032') # for Win32
return value
def quote(self, buf):
"""
Quote the parameters for commands. General rules:
o numbers are returns as bytes using ascii codec
o None is returned as bytearray(b'NULL')
o Everything else is single quoted '<buf>'
Returns a bytearray object.
"""
if isinstance(buf, NUMERIC_TYPES):
return str(buf).encode('ascii')
elif isinstance(buf, type(None)):
return bytearray(b"NULL")
return bytearray(b"'" + buf + b"'")
def to_mysql(self, value):
"""Convert Python data type to MySQL"""
type_name = value.__class__.__name__.lower()
try:
return getattr(self, "_{0}_to_mysql".format(type_name))(value)
except AttributeError:
if self.str_fallback:
return str(value).encode()
raise TypeError("Python '{0}' cannot be converted to a "
"MySQL type".format(type_name))
def to_python(self, vtype, value):
"""Convert MySQL data type to Python"""
if value == 0 and vtype[1] != FieldType.BIT: # \x00
# Don't go further when we hit a NULL value
return None
if value is None:
return None
if not self._cache_field_types:
self._cache_field_types = {}
for name, info in FieldType.desc.items():
try:
self._cache_field_types[info[0]] = getattr(
self, '_{0}_to_python'.format(name))
except AttributeError:
# We ignore field types which has no method
pass
try:
return self._cache_field_types[vtype[1]](value, vtype)
except KeyError:
# If one type is not defined, we just return the value as str
try:
return value.decode('utf-8')
except UnicodeDecodeError:
return value
except ValueError as err:
raise ValueError("%s (field %s)" % (err, vtype[0]))
except TypeError as err:
raise TypeError("%s (field %s)" % (err, vtype[0]))
except:
raise
def _int_to_mysql(self, value):
"""Convert value to int"""
return int(value)
def _long_to_mysql(self, value):
"""Convert value to int"""
return int(value)
def _float_to_mysql(self, value):
"""Convert value to float"""
return float(value)
def _str_to_mysql(self, value):
"""Convert value to string"""
return self._unicode_to_mysql(value)
def _unicode_to_mysql(self, value):
"""Convert unicode"""
charset = self.charset
charset_id = self.charset_id
if charset == 'binary':
charset = 'utf8'
charset_id = CharacterSet.get_charset_info(charset)[0]
encoded = value.encode(charset)
if charset_id in CharacterSet.slash_charsets:
if b'\x5c' in encoded:
return HexLiteral(value, charset)
return encoded
def _bytes_to_mysql(self, value):
"""Convert value to bytes"""
return value
def _bytearray_to_mysql(self, value):
"""Convert value to bytes"""
return bytes(value)
def _bool_to_mysql(self, value):
"""Convert value to boolean"""
if value:
return 1
return 0
def _nonetype_to_mysql(self, value):
"""
This would return what None would be in MySQL, but instead we
leave it None and return it right away. The actual conversion
from None to NULL happens in the quoting functionality.
Return None.
"""
return None
def _datetime_to_mysql(self, value):
"""
Converts a datetime instance to a string suitable for MySQL.
The returned string has format: %Y-%m-%d %H:%M:%S[.%f]
If the instance isn't a datetime.datetime type, it return None.
Returns a bytes.
"""
if value.microsecond:
fmt = '{0:04d}-{1:02d}-{2:02d} {3:02d}:{4:02d}:{5:02d}.{6:06d}'
return fmt.format(
value.year, value.month, value.day,
value.hour, value.minute, value.second,
value.microsecond).encode('ascii')
fmt = '{0:04d}-{1:02d}-{2:02d} {3:02d}:{4:02d}:{5:02d}'
return fmt.format(
value.year, value.month, value.day,
value.hour, value.minute, value.second).encode('ascii')
def _date_to_mysql(self, value):
"""
Converts a date instance to a string suitable for MySQL.
The returned string has format: %Y-%m-%d
If the instance isn't a datetime.date type, it return None.
Returns a bytes.
"""
return '{0:04d}-{1:02d}-{2:02d}'.format(value.year, value.month,
value.day).encode('ascii')
def _time_to_mysql(self, value):
"""
Converts a time instance to a string suitable for MySQL.
The returned string has format: %H:%M:%S[.%f]
If the instance isn't a datetime.time type, it return None.
Returns a bytes.
"""
if value.microsecond:
return value.strftime('%H:%M:%S.%f').encode('ascii')
return value.strftime('%H:%M:%S').encode('ascii')
def _struct_time_to_mysql(self, value):
"""
Converts a time.struct_time sequence to a string suitable
for MySQL.
The returned string has format: %Y-%m-%d %H:%M:%S
Returns a bytes or None when not valid.
"""
return time.strftime('%Y-%m-%d %H:%M:%S', value).encode('ascii')
def _timedelta_to_mysql(self, value):
"""
Converts a timedelta instance to a string suitable for MySQL.
The returned string has format: %H:%M:%S
Returns a bytes.
"""
seconds = abs(value.days * 86400 + value.seconds)
if value.microseconds:
fmt = '{0:02d}:{1:02d}:{2:02d}.{3:06d}'
if value.days < 0:
mcs = 1000000 - value.microseconds
seconds -= 1
else:
mcs = value.microseconds
else:
fmt = '{0:02d}:{1:02d}:{2:02d}'
if value.days < 0:
fmt = '-' + fmt
(hours, remainder) = divmod(seconds, 3600)
(mins, secs) = divmod(remainder, 60)
if value.microseconds:
result = fmt.format(hours, mins, secs, mcs)
else:
result = fmt.format(hours, mins, secs)
return result.encode('ascii')
def _decimal_to_mysql(self, value):
"""
Converts a decimal.Decimal instance to a string suitable for
MySQL.
Returns a bytes or None when not valid.
"""
if isinstance(value, Decimal):
return str(value).encode('ascii')
return None
def row_to_python(self, row, fields):
"""Convert a MySQL text result row to Python types
The row argument is a sequence containing text result returned
by a MySQL server. Each value of the row is converted to the
using the field type information in the fields argument.
Returns a tuple.
"""
i = 0
result = [None]*len(fields)
if not self._cache_field_types:
self._cache_field_types = {}
for name, info in FieldType.desc.items():
try:
self._cache_field_types[info[0]] = getattr(
self, '_{0}_to_python'.format(name))
except AttributeError:
# We ignore field types which has no method
pass
for field in fields:
field_type = field[1]
if (row[i] == 0 and field_type != FieldType.BIT) or row[i] is None:
# Don't convert NULL value
i += 1
continue
try:
result[i] = self._cache_field_types[field_type](row[i], field)
except KeyError:
# If one type is not defined, we just return the value as str
try:
result[i] = row[i].decode('utf-8')
except UnicodeDecodeError:
result[i] = row[i]
except (ValueError, TypeError) as err:
err.message = "{0} (field {1})".format(str(err), field[0])
raise
i += 1
return tuple(result)
def _FLOAT_to_python(self, value, desc=None): # pylint: disable=C0103
"""
Returns value as float type.
"""
return float(value)
_DOUBLE_to_python = _FLOAT_to_python
def _INT_to_python(self, value, desc=None): # pylint: disable=C0103
"""
Returns value as int type.
"""
return int(value)
_TINY_to_python = _INT_to_python
_SHORT_to_python = _INT_to_python
_INT24_to_python = _INT_to_python
_LONG_to_python = _INT_to_python
_LONGLONG_to_python = _INT_to_python
def _DECIMAL_to_python(self, value, desc=None): # pylint: disable=C0103
"""
Returns value as a decimal.Decimal.
"""
val = value.decode(self.charset)
return Decimal(val)
_NEWDECIMAL_to_python = _DECIMAL_to_python
def _str(self, value, desc=None):
"""
Returns value as str type.
"""
return str(value)
def _BIT_to_python(self, value, dsc=None): # pylint: disable=C0103
"""Returns BIT columntype as integer"""
int_val = value
if len(int_val) < 8:
int_val = b'\x00' * (8 - len(int_val)) + int_val
return struct.unpack('>Q', int_val)[0]
def _DATE_to_python(self, value, dsc=None): # pylint: disable=C0103
"""Converts TIME column MySQL to a python datetime.datetime type.
Raises ValueError if the value can not be converted.
Returns DATE column type as datetime.date type.
"""
if isinstance(value, datetime.date):
return value
try:
parts = value.split(b'-')
if len(parts) != 3:
raise ValueError("invalid datetime format: {} len: {}"
"".format(parts, len(parts)))
try:
return datetime.date(int(parts[0]), int(parts[1]), int(parts[2]))
except ValueError:
return None
except (IndexError, ValueError):
raise ValueError(
"Could not convert {0} to python datetime.timedelta".format(
value))
_NEWDATE_to_python = _DATE_to_python
def _TIME_to_python(self, value, dsc=None): # pylint: disable=C0103
"""Converts TIME column value to python datetime.time value type.
Converts the TIME column MySQL type passed as bytes to a python
datetime.datetime type.
Raises ValueError if the value can not be converted.
Returns datetime.time type.
"""
try:
(hms, mcs) = value.split(b'.')
mcs = int(mcs.ljust(6, b'0'))
except (TypeError, ValueError):
hms = value
mcs = 0
try:
(hours, mins, secs) = [int(d) for d in hms.split(b':')]
if value[0] == 45 or value[0] == '-':
mins, secs, mcs = -mins, -secs, -mcs
return datetime.timedelta(hours=hours, minutes=mins,
seconds=secs, microseconds=mcs)
except (IndexError, TypeError, ValueError):
raise ValueError(CONVERT_ERROR.format(value=value,
pytype="datetime.timedelta"))
def _DATETIME_to_python(self, value, dsc=None): # pylint: disable=C0103
""""Converts DATETIME column value to python datetime.time value type.
Converts the DATETIME column MySQL type passed as bytes to a python
datetime.datetime type.
Returns: datetime.datetime type.
"""
if isinstance(value, datetime.datetime):
return value
datetime_val = None
try:
(date_, time_) = value.split(b' ')
if len(time_) > 8:
(hms, mcs) = time_.split(b'.')
mcs = int(mcs.ljust(6, b'0'))
else:
hms = time_
mcs = 0
dtval = [int(i) for i in date_.split(b'-')] + \
[int(i) for i in hms.split(b':')] + [mcs, ]
if len(dtval) < 6:
raise ValueError("invalid datetime format: {} len: {}"
"".format(dtval, len(dtval)))
else:
# Note that by default MySQL accepts invalid timestamps
# (this is also backward compatibility).
# Traditionaly C/py returns None for this well formed but
# invalid datetime for python like '0000-00-00 HH:MM:SS'.
try:
datetime_val = datetime.datetime(*dtval)
except ValueError:
return None
except (IndexError, TypeError):
raise ValueError(CONVERT_ERROR.format(value=value,
pytype="datetime.timedelta"))
return datetime_val
_TIMESTAMP_to_python = _DATETIME_to_python
def _YEAR_to_python(self, value, desc=None): # pylint: disable=C0103
"""Returns YEAR column type as integer"""
try:
year = int(value)
except ValueError:
raise ValueError("Failed converting YEAR to int (%s)" % value)
return year
def _SET_to_python(self, value, dsc=None): # pylint: disable=C0103
"""Returns SET column type as set
Actually, MySQL protocol sees a SET as a string type field. So this
code isn't called directly, but used by STRING_to_python() method.
Returns SET column type as a set.
"""
set_type = None
val = value.decode(self.charset)
if not val:
return set()
try:
set_type = set(val.split(','))
except ValueError:
raise ValueError("Could not convert set %s to a sequence." % value)
return set_type
def _STRING_to_python(self, value, dsc=None): # pylint: disable=C0103
"""
Note that a SET is a string too, but using the FieldFlag we can see
whether we have to split it.
Returns string typed columns as string type.
"""
if self.charset == "binary":
return value
if dsc is not None:
if dsc[1] == FieldType.JSON and self.use_unicode:
return value.decode(self.charset)
if dsc[7] & FieldFlag.SET:
return self._SET_to_python(value, dsc)
if dsc[8] == 63: # 'binary' charset
return value
if isinstance(value, (bytes, bytearray)) and self.use_unicode:
return value.decode(self.charset)
return value
_VAR_STRING_to_python = _STRING_to_python
_JSON_to_python = _STRING_to_python
def _BLOB_to_python(self, value, dsc=None): # pylint: disable=C0103
"""Convert BLOB data type to Python."""
if dsc is not None:
if dsc[7] & FieldFlag.BLOB and dsc[7] & FieldFlag.BINARY:
return bytes(value)
return self._STRING_to_python(value, dsc)
_LONG_BLOB_to_python = _BLOB_to_python
_MEDIUM_BLOB_to_python = _BLOB_to_python
_TINY_BLOB_to_python = _BLOB_to_python

View File

@@ -0,0 +1,50 @@
# Copyright (c) 2014, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Custom Python types used by MySQL Connector/Python"""
import sys
class HexLiteral(str):
"""Class holding MySQL hex literals"""
def __new__(cls, str_, charset='utf8'):
if sys.version_info[0] == 2:
hexed = ["%02x" % ord(i) for i in str_.encode(charset)]
else:
hexed = ["%02x" % i for i in str_.encode(charset)]
obj = str.__new__(cls, ''.join(hexed))
obj.charset = charset
obj.original = str_
return obj
def __str__(self):
return '0x' + self

View File

@@ -0,0 +1,80 @@
# Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""
This module implements some constructors and singletons as required by the
DB API v2.0 (PEP-249).
"""
# Python Db API v2
apilevel = '2.0'
threadsafety = 1
paramstyle = 'pyformat'
import time
import datetime
from . import constants
class _DBAPITypeObject(object):
def __init__(self, *values):
self.values = values
def __eq__(self, other):
if other in self.values:
return True
else:
return False
def __ne__(self, other):
if other in self.values:
return False
else:
return True
Date = datetime.date
Time = datetime.time
Timestamp = datetime.datetime
def DateFromTicks(ticks):
return Date(*time.localtime(ticks)[:3])
def TimeFromTicks(ticks):
return Time(*time.localtime(ticks)[3:6])
def TimestampFromTicks(ticks):
return Timestamp(*time.localtime(ticks)[:6])
Binary = bytes
STRING = _DBAPITypeObject(*constants.FieldType.get_string_types())
BINARY = _DBAPITypeObject(*constants.FieldType.get_binary_types())
NUMBER = _DBAPITypeObject(*constants.FieldType.get_number_types())
DATETIME = _DBAPITypeObject(*constants.FieldType.get_timestamp_types())
ROWID = _DBAPITypeObject()

View File

@@ -0,0 +1,538 @@
# Copyright (c) 2020, 2021, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Django database Backend using MySQL Connector/Python.
This Django database backend is heavily based on the MySQL backend from Django.
Changes include:
* Support for microseconds (MySQL 5.6.3 and later)
* Using INFORMATION_SCHEMA where possible
* Using new defaults for, for example SQL_AUTO_IS_NULL
Requires and comes with MySQL Connector/Python v8.0.22 and later:
http://dev.mysql.com/downloads/connector/python/
"""
import warnings
import sys
from datetime import datetime
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db import utils
from django.utils.functional import cached_property
from django.utils import dateparse, timezone
try:
import mysql.connector
from mysql.connector.conversion import MySQLConverter
except ImportError as err:
raise ImproperlyConfigured(
"Error loading mysql.connector module: {0}".format(err))
try:
from _mysql_connector import datetime_to_mysql, time_to_mysql
except ImportError:
HAVE_CEXT = False
else:
HAVE_CEXT = True
from .client import DatabaseClient
from .creation import DatabaseCreation
from .introspection import DatabaseIntrospection
from .validation import DatabaseValidation
from .features import DatabaseFeatures
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
Error = mysql.connector.Error
DatabaseError = mysql.connector.DatabaseError
NotSupportedError = mysql.connector.NotSupportedError
OperationalError = mysql.connector.OperationalError
ProgrammingError = mysql.connector.ProgrammingError
def adapt_datetime_with_timezone_support(value):
# Equivalent to DateTimeField.get_db_prep_value. Used only by raw SQL.
if settings.USE_TZ:
if timezone.is_naive(value):
warnings.warn("MySQL received a naive datetime (%s)"
" while time zone support is active." % value,
RuntimeWarning)
default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone)
value = value.astimezone(timezone.utc).replace(tzinfo=None)
if HAVE_CEXT:
return datetime_to_mysql(value)
else:
return value.strftime("%Y-%m-%d %H:%M:%S.%f")
class CursorWrapper:
"""Wrapper around MySQL Connector/Python's cursor class.
The cursor class is defined by the options passed to MySQL
Connector/Python. If buffered option is True in those options,
MySQLCursorBuffered will be used.
"""
codes_for_integrityerror = (
1048, # Column cannot be null
1690, # BIGINT UNSIGNED value is out of range
3819, # CHECK constraint is violated
4025, # CHECK constraint failed
)
def __init__(self, cursor):
self.cursor = cursor
def _adapt_execute_args_dict(self, args):
if not args:
return args
new_args = dict(args)
for key, value in args.items():
if isinstance(value, datetime):
new_args[key] = adapt_datetime_with_timezone_support(value)
return new_args
def _adapt_execute_args(self, args):
if not args:
return args
new_args = list(args)
for i, arg in enumerate(args):
if isinstance(arg, datetime):
new_args[i] = adapt_datetime_with_timezone_support(arg)
return tuple(new_args)
def execute(self, query, args=None):
"""Executes the given operation
This wrapper method around the execute()-method of the cursor is
mainly needed to re-raise using different exceptions.
"""
if isinstance(args, dict):
new_args = self._adapt_execute_args_dict(args)
else:
new_args = self._adapt_execute_args(args)
try:
return self.cursor.execute(query, new_args)
except mysql.connector.OperationalError as e:
if e.args[0] in self.codes_for_integrityerror:
raise IntegrityError(*tuple(e.args))
raise
def executemany(self, query, args):
"""Executes the given operation
This wrapper method around the executemany()-method of the cursor is
mainly needed to re-raise using different exceptions.
"""
try:
return self.cursor.executemany(query, args)
except mysql.connector.OperationalError as e:
if e.args[0] in self.codes_for_integrityerror:
raise IntegrityError(*tuple(e.args))
raise
def __getattr__(self, attr):
"""Return attribute of wrapped cursor"""
return getattr(self.cursor, attr)
def __iter__(self):
"""Returns iterator over wrapped cursor"""
return iter(self.cursor)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'mysql'
# This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
'AutoField': 'integer AUTO_INCREMENT',
'BigAutoField': 'bigint AUTO_INCREMENT',
'BinaryField': 'longblob',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime(6)',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'bigint',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'char(15)',
'GenericIPAddressField': 'char(39)',
'JSONField': 'json',
'NullBooleanField': 'bool',
'OneToOneField': 'integer',
'PositiveBigIntegerField': 'bigint UNSIGNED',
'PositiveIntegerField': 'integer UNSIGNED',
'PositiveSmallIntegerField': 'smallint UNSIGNED',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'smallint AUTO_INCREMENT',
'SmallIntegerField': 'smallint',
'TextField': 'longtext',
'TimeField': 'time(6)',
'UUIDField': 'char(32)',
}
# For these data types:
# - MySQL < 8.0.13 doesn't accept default values and
# implicitly treat them as nullable
# - all versions of MySQL doesn't support full width database
# indexes
_limited_data_types = (
'tinyblob', 'blob', 'mediumblob', 'longblob', 'tinytext', 'text',
'mediumtext', 'longtext', 'json',
)
operators = {
'exact': '= %s',
'iexact': 'LIKE %s',
'contains': 'LIKE BINARY %s',
'icontains': 'LIKE %s',
'regex': 'REGEXP BINARY %s',
'iregex': 'REGEXP %s',
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': 'LIKE BINARY %s',
'endswith': 'LIKE BINARY %s',
'istartswith': 'LIKE %s',
'iendswith': 'LIKE %s',
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
# escaped on database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
'contains': "LIKE BINARY CONCAT('%%', {}, '%%')",
'icontains': "LIKE CONCAT('%%', {}, '%%')",
'startswith': "LIKE BINARY CONCAT({}, '%%')",
'istartswith': "LIKE CONCAT({}, '%%')",
'endswith': "LIKE BINARY CONCAT('%%', {})",
'iendswith': "LIKE CONCAT('%%', {})",
}
isolation_levels = {
'read uncommitted',
'read committed',
'repeatable read',
'serializable',
}
Database = mysql.connector
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
validation_class = DatabaseValidation
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
try:
self._use_pure = self.settings_dict['OPTIONS']['use_pure']
except KeyError:
self._use_pure = not HAVE_CEXT
self.converter = DjangoMySQLConverter()
def __getattr__(self, attr):
if attr.startswith("mysql_is"):
return False
raise AttributeError
def get_connection_params(self):
kwargs = {
'charset': 'utf8',
'use_unicode': True,
'buffered': False,
'consume_results': True,
}
settings_dict = self.settings_dict
if settings_dict['USER']:
kwargs['user'] = settings_dict['USER']
if settings_dict['NAME']:
kwargs['database'] = settings_dict['NAME']
if settings_dict['PASSWORD']:
kwargs['passwd'] = settings_dict['PASSWORD']
if settings_dict['HOST'].startswith('/'):
kwargs['unix_socket'] = settings_dict['HOST']
elif settings_dict['HOST']:
kwargs['host'] = settings_dict['HOST']
if settings_dict['PORT']:
kwargs['port'] = int(settings_dict['PORT'])
# Raise exceptions for database warnings if DEBUG is on
kwargs['raise_on_warnings'] = settings.DEBUG
kwargs['client_flags'] = [
# Need potentially affected rows on UPDATE
mysql.connector.constants.ClientFlag.FOUND_ROWS,
]
try:
kwargs.update(settings_dict['OPTIONS'])
except KeyError:
# OPTIONS missing is OK
pass
return kwargs
def get_new_connection(self, conn_params):
if not 'converter_class' in conn_params:
conn_params['converter_class'] = DjangoMySQLConverter
cnx = mysql.connector.connect(**conn_params)
return cnx
def init_connection_state(self):
assignments = []
if self.features.is_sql_auto_is_null_enabled:
# SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
# a recently inserted row will return when the field is tested
# for NULL. Disabling this brings this aspect of MySQL in line
# with SQL standards.
assignments.append('SET SQL_AUTO_IS_NULL = 0')
if assignments:
with self.cursor() as cursor:
cursor.execute('; '.join(assignments))
if 'AUTOCOMMIT' in self.settings_dict:
try:
self.set_autocommit(self.settings_dict['AUTOCOMMIT'])
except AttributeError:
self._set_autocommit(self.settings_dict['AUTOCOMMIT'])
def create_cursor(self, name=None):
cursor = self.connection.cursor()
return CursorWrapper(cursor)
def _rollback(self):
try:
BaseDatabaseWrapper._rollback(self)
except NotSupportedError:
pass
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit = autocommit
def disable_constraint_checking(self):
"""
Disable foreign key checks, primarily for use in adding rows with
forward references. Always return True to indicate constraint checks
need to be re-enabled.
"""
with self.cursor() as cursor:
cursor.execute('SET foreign_key_checks=0')
return True
def enable_constraint_checking(self):
"""
Re-enable foreign key checks after they have been disabled.
"""
# Override needs_rollback in case constraint_checks_disabled is
# nested inside transaction.atomic.
self.needs_rollback, needs_rollback = False, self.needs_rollback
try:
with self.cursor() as cursor:
cursor.execute('SET foreign_key_checks=1')
finally:
self.needs_rollback = needs_rollback
def check_constraints(self, table_names=None):
"""
Check each table name in `table_names` for rows with invalid foreign
key references. This method is intended to be used in conjunction with
`disable_constraint_checking()` and `enable_constraint_checking()`, to
determine if rows with invalid references were entered while constraint
checks were off.
"""
with self.cursor() as cursor:
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name:
continue
key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
""" % (
primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not "
"have a corresponding value in %s.%s."
% (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
def is_usable(self):
try:
self.connection.ping()
except Error:
return False
else:
return True
@cached_property
def display_name(self):
return 'MySQL'
@cached_property
def data_type_check_constraints(self):
if self.features.supports_column_check_constraints:
check_constraints = {
'PositiveBigIntegerField': '`%(column)s` >= 0',
'PositiveIntegerField': '`%(column)s` >= 0',
'PositiveSmallIntegerField': '`%(column)s` >= 0',
}
return check_constraints
return {}
@cached_property
def mysql_server_data(self):
with self.temporary_connection() as cursor:
# Select some server variables and test if the time zone
# definitions are installed. CONVERT_TZ returns NULL if 'UTC'
# timezone isn't loaded into the mysql.time_zone table.
cursor.execute("""
SELECT VERSION(),
@@sql_mode,
@@default_storage_engine,
@@sql_auto_is_null,
@@lower_case_table_names,
CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
""")
row = cursor.fetchone()
return {
'version': row[0],
'sql_mode': row[1],
'default_storage_engine': row[2],
'sql_auto_is_null': bool(row[3]),
'lower_case_table_names': bool(row[4]),
'has_zoneinfo_database': bool(row[5]),
}
@cached_property
def mysql_server_info(self):
with self.temporary_connection() as cursor:
cursor.execute('SELECT VERSION()')
return cursor.fetchone()[0]
@cached_property
def mysql_version(self):
config = self.get_connection_params()
with mysql.connector.connect(**config) as conn:
server_version = conn.get_server_version()
return server_version
@cached_property
def sql_mode(self):
with self.cursor() as cursor:
cursor.execute('SELECT @@sql_mode')
sql_mode = cursor.fetchone()
return set(sql_mode[0].split(',') if sql_mode else ())
@property
def use_pure(self):
return self._use_pure
class DjangoMySQLConverter(MySQLConverter):
"""Custom converter for Django."""
def _TIME_to_python(self, value, dsc=None):
"""Return MySQL TIME data type as datetime.time()
Returns datetime.time()
"""
return dateparse.parse_time(value.decode('utf-8'))
def __DATETIME_to_python(self, value, dsc=None):
"""Connector/Python always returns naive datetime.datetime
Connector/Python always returns naive timestamps since MySQL has
no time zone support. Since Django needs non-naive, we need to add
the UTC time zone.
Returns datetime.datetime()
"""
if not value:
return None
dt = MySQLConverter._DATETIME_to_python(self, value)
if dt is None:
return None
if settings.USE_TZ and timezone.is_naive(dt):
dt = dt.replace(tzinfo=timezone.utc)
return dt
def _safestring_to_mysql(self, value):
return self._str_to_mysql(value)
def _safetext_to_mysql(self, value):
return self._str_to_mysql(value)
def _safebytes_to_mysql(self, value):
return self._bytes_to_mysql(value)

View File

@@ -0,0 +1,79 @@
# Copyright (c) 2020, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
import subprocess
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = 'mysql'
@classmethod
def settings_to_cmd_args(cls, settings_dict):
args = [cls.executable_name]
db = settings_dict['OPTIONS'].get('database', settings_dict['NAME'])
user = settings_dict['OPTIONS'].get('user',
settings_dict['USER'])
passwd = settings_dict['OPTIONS'].get('password',
settings_dict['PASSWORD'])
host = settings_dict['OPTIONS'].get('host', settings_dict['HOST'])
port = settings_dict['OPTIONS'].get('port', settings_dict['PORT'])
defaults_file = settings_dict['OPTIONS'].get('read_default_file')
# --defaults-file should always be the first option
if defaults_file:
args.append('--defaults-file={0}'.format(defaults_file))
# We force SQL_MODE to TRADITIONAL
args.append('--init-command=SET @@session.SQL_MODE=TRADITIONAL')
if user:
args.append('--user={0}'.format(user))
if passwd:
args.append('--password={0}'.format(passwd))
if host:
if '/' in host:
args.append('--socket={0}'.format(host))
else:
args.append('--host={0}'.format(host))
if port:
args.append('--port={0}'.format(port))
if db:
args.append('--database={0}'.format(db))
return args
def runshell(self):
args = DatabaseClient.settings_to_cmd_args(
self.connection.settings_dict)
subprocess.call(args)

View File

@@ -0,0 +1,35 @@
# Copyright (c) 2020, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
from django.db.backends.mysql.compiler import (
SQLCompiler,
SQLInsertCompiler,
SQLDeleteCompiler,
SQLUpdateCompiler,
SQLAggregateCompiler
)

View File

@@ -0,0 +1,29 @@
# Copyright (c) 2020, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
from django.db.backends.mysql.creation import DatabaseCreation

View File

@@ -0,0 +1,44 @@
# Copyright (c) 2020, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
from django.db.backends.mysql.features import DatabaseFeatures as MySQLDatabaseFeatures
from django.utils.functional import cached_property
class DatabaseFeatures(MySQLDatabaseFeatures):
empty_fetchmany_value = []
@cached_property
def can_introspect_check_constraints(self):
return self.connection.mysql_version >= (8, 0, 16)
@cached_property
def supports_microsecond_precision(self):
if self.connection.mysql_version >= (5, 6, 3):
return True
return False

View File

@@ -0,0 +1,380 @@
# Copyright (c) 2020, 2021, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
from collections import namedtuple
import sqlparse
from mysql.connector.constants import FieldType
from django import VERSION as DJANGO_VERSION
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
from django.db.models import Index
from django.utils.datastructures import OrderedSet
FieldInfo = namedtuple(
'FieldInfo',
BaseFieldInfo._fields + ('extra', 'is_unsigned', 'has_json_constraint')
)
if DJANGO_VERSION < (3, 2, 0):
InfoLine = namedtuple(
'InfoLine',
'col_name data_type max_len num_prec num_scale extra column_default '
'is_unsigned'
)
else:
InfoLine = namedtuple(
'InfoLine',
'col_name data_type max_len num_prec num_scale extra column_default '
'collation is_unsigned'
)
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = {
FieldType.BLOB: 'TextField',
FieldType.DECIMAL: 'DecimalField',
FieldType.NEWDECIMAL: 'DecimalField',
FieldType.DATE: 'DateField',
FieldType.DATETIME: 'DateTimeField',
FieldType.DOUBLE: 'FloatField',
FieldType.FLOAT: 'FloatField',
FieldType.INT24: 'IntegerField',
FieldType.LONG: 'IntegerField',
FieldType.LONGLONG: 'BigIntegerField',
FieldType.SHORT: 'SmallIntegerField',
FieldType.STRING: 'CharField',
FieldType.TIME: 'TimeField',
FieldType.TIMESTAMP: 'DateTimeField',
FieldType.TINY: 'IntegerField',
FieldType.TINY_BLOB: 'TextField',
FieldType.MEDIUM_BLOB: 'TextField',
FieldType.LONG_BLOB: 'TextField',
FieldType.VAR_STRING: 'CharField',
}
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if 'auto_increment' in description.extra:
if field_type == 'IntegerField':
return 'AutoField'
elif field_type == 'BigIntegerField':
return 'BigAutoField'
elif field_type == 'SmallIntegerField':
return 'SmallAutoField'
if description.is_unsigned:
if field_type == 'BigIntegerField':
return 'PositiveBigIntegerField'
elif field_type == 'IntegerField':
return 'PositiveIntegerField'
elif field_type == 'SmallIntegerField':
return 'PositiveSmallIntegerField'
# JSON data type is an alias for LONGTEXT in MariaDB, use check
# constraints clauses to introspect JSONField.
if description.has_json_constraint:
return 'JSONField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute("SHOW FULL TABLES")
return [TableInfo(row[0], {'BASE TABLE': 't', 'VIEW': 'v'}.get(row[1]))
for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface."
"""
json_constraints = {}
# A default collation for the given table.
cursor.execute("""
SELECT table_collation
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = %s
""", [table_name])
row = cursor.fetchone()
default_column_collation = row[0] if row else ''
# information_schema database gives more accurate results for some figures:
# - varchar length returned by cursor.description is an internal length,
# not visible length (#5725)
# - precision and scale (for decimal fields) (#5014)
# - auto_increment is not available in cursor.description
if DJANGO_VERSION < (3, 2, 0):
cursor.execute("""
SELECT
column_name, data_type, character_maximum_length,
numeric_precision, numeric_scale, extra, column_default,
CASE
WHEN column_type LIKE '%% unsigned' THEN 1
ELSE 0
END AS is_unsigned
FROM information_schema.columns
WHERE table_name = %s AND table_schema = DATABASE()
""", [table_name])
else:
cursor.execute("""
SELECT
column_name, data_type, character_maximum_length,
numeric_precision, numeric_scale, extra, column_default,
CASE
WHEN collation_name = %s THEN NULL
ELSE collation_name
END AS collation_name,
CASE
WHEN column_type LIKE '%% unsigned' THEN 1
ELSE 0
END AS is_unsigned
FROM information_schema.columns
WHERE table_name = %s AND table_schema = DATABASE()
""", [default_column_collation, table_name])
field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
def to_int(i):
return int(i) if i is not None else i
fields = []
for line in cursor.description:
info = field_info[line[0]]
if DJANGO_VERSION < (3, 2, 0):
fields.append(FieldInfo(
*line[:3],
to_int(info.max_len) or line[3],
to_int(info.num_prec) or line[4],
to_int(info.num_scale) or line[5],
line[6],
info.column_default,
info.extra,
info.is_unsigned,
line[0] in json_constraints
))
else:
fields.append(FieldInfo(
*line[:3],
to_int(info.max_len) or line[3],
to_int(info.num_prec) or line[4],
to_int(info.num_scale) or line[5],
line[6],
info.column_default,
info.collation,
info.extra,
info.is_unsigned,
line[0] in json_constraints,
))
return fields
def get_indexes(self, cursor, table_name):
cursor.execute("SHOW INDEX FROM {0}"
"".format(self.connection.ops.quote_name(table_name)))
# Do a two-pass search for indexes: on first pass check which indexes
# are multicolumn, on second pass check which single-column indexes
# are present.
rows = list(cursor.fetchall())
multicol_indexes = set()
for row in rows:
if row[3] > 1:
multicol_indexes.add(row[2])
indexes = {}
for row in rows:
if row[2] in multicol_indexes:
continue
if row[4] not in indexes:
indexes[row[4]] = {'primary_key': False, 'unique': False}
# It's possible to have the unique and PK constraints in
# separate indexes.
if row[2] == 'PRIMARY':
indexes[row[4]]['primary_key'] = True
if not row[1]:
indexes[row[4]]['unique'] = True
return indexes
def get_primary_key_column(self, cursor, table_name):
"""
Returns the name of the primary key column for the given table
"""
for column in self.get_indexes(cursor, table_name).items():
if column[1]['primary_key']:
return column[0]
return None
def get_sequences(self, cursor, table_name, table_fields=()):
for field_info in self.get_table_description(cursor, table_name):
if 'auto_increment' in field_info.extra:
# MySQL allows only one auto-increment column per table.
return [{'table': table_name, 'column': field_info.name}]
return []
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table.
"""
constraints = self.get_key_columns(cursor, table_name)
relations = {}
for my_fieldname, other_table, other_field in constraints:
relations[my_fieldname] = (other_field, other_table)
return relations
def get_key_columns(self, cursor, table_name):
"""
Return a list of (column_name, referenced_table_name, referenced_column_name)
for all key columns in the given table.
"""
key_columns = []
cursor.execute("""
SELECT column_name, referenced_table_name, referenced_column_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL""", [table_name])
key_columns.extend(cursor.fetchall())
return key_columns
def get_storage_engine(self, cursor, table_name):
"""
Retrieve the storage engine for a given table. Return the default
storage engine if the table doesn't exist.
"""
cursor.execute(
"SELECT engine "
"FROM information_schema.tables "
"WHERE table_name = %s", [table_name])
result = cursor.fetchone()
if not result:
return self.connection.features._mysql_storage_engine
return result[0]
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns.
"""
constraints = {}
# Get the actual constraint names and columns
name_query = """
SELECT kc.`constraint_name`, kc.`column_name`,
kc.`referenced_table_name`, kc.`referenced_column_name`
FROM information_schema.key_column_usage AS kc
WHERE
kc.table_schema = DATABASE() AND
kc.table_name = %s
ORDER BY kc.`ordinal_position`
"""
cursor.execute(name_query, [table_name])
for constraint, column, ref_table, ref_column in cursor.fetchall():
if constraint not in constraints:
constraints[constraint] = {
'columns': OrderedSet(),
'primary_key': False,
'unique': False,
'index': False,
'check': False,
'foreign_key': (ref_table, ref_column) if ref_column else None,
}
if self.connection.features.supports_index_column_ordering:
constraints[constraint]['orders'] = []
constraints[constraint]['columns'].add(column)
# Now get the constraint types
type_query = """
SELECT c.constraint_name, c.constraint_type
FROM information_schema.table_constraints AS c
WHERE
c.table_schema = DATABASE() AND
c.table_name = %s
"""
cursor.execute(type_query, [table_name])
for constraint, kind in cursor.fetchall():
if kind.lower() == "primary key":
constraints[constraint]['primary_key'] = True
constraints[constraint]['unique'] = True
elif kind.lower() == "unique":
constraints[constraint]['unique'] = True
# Add check constraints.
if self.connection.features.can_introspect_check_constraints:
unnamed_constraints_index = 0
columns = {info.name for info in self.get_table_description(cursor, table_name)}
type_query = """
SELECT cc.constraint_name, cc.check_clause
FROM
information_schema.check_constraints AS cc,
information_schema.table_constraints AS tc
WHERE
cc.constraint_schema = DATABASE() AND
tc.table_schema = cc.constraint_schema AND
cc.constraint_name = tc.constraint_name AND
tc.constraint_type = 'CHECK' AND
tc.table_name = %s
"""
cursor.execute(type_query, [table_name])
for constraint, check_clause in cursor.fetchall():
constraint_columns = self._parse_constraint_columns(check_clause, columns)
# Ensure uniqueness of unnamed constraints. Unnamed unique
# and check columns constraints have the same name as
# a column.
if set(constraint_columns) == {constraint}:
unnamed_constraints_index += 1
constraint = '__unnamed_constraint_%s__' % unnamed_constraints_index
constraints[constraint] = {
'columns': constraint_columns,
'primary_key': False,
'unique': False,
'index': False,
'check': True,
'foreign_key': None,
}
# Now add in the indexes
cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name))
for table, non_unique, index, colseq, column, order, type_ in [
x[:6] + (x[10],) for x in cursor.fetchall()
]:
if index not in constraints:
constraints[index] = {
'columns': OrderedSet(),
'primary_key': False,
'unique': False,
'check': False,
'foreign_key': None,
}
if self.connection.features.supports_index_column_ordering:
constraints[index]['orders'] = []
constraints[index]['index'] = True
constraints[index]['type'] = Index.suffix if type_ == 'BTREE' else type_.lower()
constraints[index]['columns'].add(column)
if self.connection.features.supports_index_column_ordering:
constraints[index]['orders'].append('DESC' if order == 'D' else 'ASC')
# Convert the sorted sets to lists
for constraint in constraints.values():
constraint['columns'] = list(constraint['columns'])
return constraints

View File

@@ -0,0 +1,87 @@
# Copyright (c) 2020, 2021, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
from django.db.backends.mysql.operations import DatabaseOperations as MySQLDatabaseOperations
from django.conf import settings
from django.utils import timezone
try:
from _mysql_connector import datetime_to_mysql, time_to_mysql
except ImportError:
HAVE_CEXT = False
else:
HAVE_CEXT = True
class DatabaseOperations(MySQLDatabaseOperations):
compiler_module = "mysql.connector.django.compiler"
def regex_lookup(self, lookup_type):
if self.connection.mysql_version < (8, 0, 0):
if lookup_type == 'regex':
return '%s REGEXP BINARY %s'
return '%s REGEXP %s'
match_option = 'c' if lookup_type == 'regex' else 'i'
return "REGEXP_LIKE(%s, %s, '%s')" % match_option
def adapt_datetimefield_value(self, value):
return self.value_to_db_datetime(value)
def value_to_db_datetime(self, value):
if value is None:
return None
# MySQL doesn't support tz-aware times
if timezone.is_aware(value):
if settings.USE_TZ:
value = value.astimezone(timezone.utc).replace(tzinfo=None)
else:
raise ValueError(
"MySQL backend does not support timezone-aware times."
)
if not self.connection.features.supports_microsecond_precision:
value = value.replace(microsecond=0)
if not self.connection.use_pure:
return datetime_to_mysql(value)
return self.connection.converter.to_mysql(value)
def adapt_timefield_value(self, value):
return self.value_to_db_time(value)
def value_to_db_time(self, value):
if value is None:
return None
# MySQL doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("MySQL backend does not support timezone-aware "
"times.")
if not self.connection.use_pure:
return time_to_mysql(value)
return self.connection.converter.to_mysql(value)

View File

@@ -0,0 +1,41 @@
# Copyright (c) 2020, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
from django.db.backends.mysql.schema import DatabaseSchemaEditor as MySQLDatabaseSchemaEditor
class DatabaseSchemaEditor(MySQLDatabaseSchemaEditor):
def quote_value(self, value):
self.connection.ensure_connection()
if isinstance(value, str):
value = value.replace('%', '%%')
quoted = self.connection.connection.converter.escape(value)
if isinstance(value, str) and isinstance(quoted, bytes):
quoted = quoted.decode()
return quoted

View File

@@ -0,0 +1,29 @@
# Copyright (c) 2020, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
from django.db.backends.mysql.validation import DatabaseValidation

View File

@@ -0,0 +1,306 @@
# Copyright (c) 2009, 2020, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Python exceptions
"""
from . import utils
from .locales import get_client_error
# _CUSTOM_ERROR_EXCEPTIONS holds custom exceptions and is ued by the
# function custom_error_exception. _ERROR_EXCEPTIONS (at bottom of module)
# is similar, but hardcoded exceptions.
_CUSTOM_ERROR_EXCEPTIONS = {}
def custom_error_exception(error=None, exception=None):
"""Define custom exceptions for MySQL server errors
This function defines custom exceptions for MySQL server errors and
returns the current set customizations.
If error is a MySQL Server error number, then you have to pass also the
exception class.
The error argument can also be a dictionary in which case the key is
the server error number, and value the exception to be raised.
If none of the arguments are given, then custom_error_exception() will
simply return the current set customizations.
To reset the customizations, simply supply an empty dictionary.
Examples:
import mysql.connector
from mysql.connector import errorcode
# Server error 1028 should raise a DatabaseError
mysql.connector.custom_error_exception(
1028, mysql.connector.DatabaseError)
# Or using a dictionary:
mysql.connector.custom_error_exception({
1028: mysql.connector.DatabaseError,
1029: mysql.connector.OperationalError,
})
# Reset
mysql.connector.custom_error_exception({})
Returns a dictionary.
"""
global _CUSTOM_ERROR_EXCEPTIONS # pylint: disable=W0603
if isinstance(error, dict) and not error:
_CUSTOM_ERROR_EXCEPTIONS = {}
return _CUSTOM_ERROR_EXCEPTIONS
if not error and not exception:
return _CUSTOM_ERROR_EXCEPTIONS
if not isinstance(error, (int, dict)):
raise ValueError(
"The error argument should be either an integer or dictionary")
if isinstance(error, int):
error = {error: exception}
for errno, _exception in error.items():
if not isinstance(errno, int):
raise ValueError("error number should be an integer")
try:
if not issubclass(_exception, Exception):
raise TypeError
except TypeError:
raise ValueError("exception should be subclass of Exception")
_CUSTOM_ERROR_EXCEPTIONS[errno] = _exception
return _CUSTOM_ERROR_EXCEPTIONS
def get_mysql_exception(errno, msg=None, sqlstate=None):
"""Get the exception matching the MySQL error
This function will return an exception based on the SQLState. The given
message will be passed on in the returned exception.
The exception returned can be customized using the
mysql.connector.custom_error_exception() function.
Returns an Exception
"""
try:
return _CUSTOM_ERROR_EXCEPTIONS[errno](
msg=msg, errno=errno, sqlstate=sqlstate)
except KeyError:
# Error was not mapped to particular exception
pass
try:
return _ERROR_EXCEPTIONS[errno](
msg=msg, errno=errno, sqlstate=sqlstate)
except KeyError:
# Error was not mapped to particular exception
pass
if not sqlstate:
return DatabaseError(msg=msg, errno=errno)
try:
return _SQLSTATE_CLASS_EXCEPTION[sqlstate[0:2]](
msg=msg, errno=errno, sqlstate=sqlstate)
except KeyError:
# Return default InterfaceError
return DatabaseError(msg=msg, errno=errno, sqlstate=sqlstate)
def get_exception(packet):
"""Returns an exception object based on the MySQL error
Returns an exception object based on the MySQL error in the given
packet.
Returns an Error-Object.
"""
errno = errmsg = None
try:
if packet[4] != 255:
raise ValueError("Packet is not an error packet")
except IndexError as err:
return InterfaceError("Failed getting Error information (%r)" % err)
sqlstate = None
try:
packet = packet[5:]
(packet, errno) = utils.read_int(packet, 2)
if packet[0] != 35:
# Error without SQLState
if isinstance(packet, (bytes, bytearray)):
errmsg = packet.decode('utf8')
else:
errmsg = packet
else:
(packet, sqlstate) = utils.read_bytes(packet[1:], 5)
sqlstate = sqlstate.decode('utf8')
errmsg = packet.decode('utf8')
except Exception as err: # pylint: disable=W0703
return InterfaceError("Failed getting Error information (%r)" % err)
else:
return get_mysql_exception(errno, errmsg, sqlstate)
class Error(Exception):
"""Exception that is base class for all other error exceptions"""
def __init__(self, msg=None, errno=None, values=None, sqlstate=None):
super(Error, self).__init__()
self.msg = msg
self._full_msg = self.msg
self.errno = errno or -1
self.sqlstate = sqlstate
if not self.msg and (2000 <= self.errno < 3000):
self.msg = get_client_error(self.errno)
if values is not None:
try:
self.msg = self.msg % values
except TypeError as err:
self.msg = "{0} (Warning: {1})".format(self.msg, str(err))
elif not self.msg:
self._full_msg = self.msg = 'Unknown error'
if self.msg and self.errno != -1:
fields = {
'errno': self.errno,
'msg': self.msg
}
if self.sqlstate:
fmt = '{errno} ({state}): {msg}'
fields['state'] = self.sqlstate
else:
fmt = '{errno}: {msg}'
self._full_msg = fmt.format(**fields)
self.args = (self.errno, self._full_msg, self.sqlstate)
def __str__(self):
return self._full_msg
class Warning(Exception): # pylint: disable=W0622
"""Exception for important warnings"""
pass
class InterfaceError(Error):
"""Exception for errors related to the interface"""
pass
class DatabaseError(Error):
"""Exception for errors related to the database"""
pass
class InternalError(DatabaseError):
"""Exception for errors internal database errors"""
pass
class OperationalError(DatabaseError):
"""Exception for errors related to the database's operation"""
pass
class ProgrammingError(DatabaseError):
"""Exception for errors programming errors"""
pass
class IntegrityError(DatabaseError):
"""Exception for errors regarding relational integrity"""
pass
class DataError(DatabaseError):
"""Exception for errors reporting problems with processed data"""
pass
class NotSupportedError(DatabaseError):
"""Exception for errors when an unsupported database feature was used"""
pass
class PoolError(Error):
"""Exception for errors relating to connection pooling"""
pass
_SQLSTATE_CLASS_EXCEPTION = {
'02': DataError, # no data
'07': DatabaseError, # dynamic SQL error
'08': OperationalError, # connection exception
'0A': NotSupportedError, # feature not supported
'21': DataError, # cardinality violation
'22': DataError, # data exception
'23': IntegrityError, # integrity constraint violation
'24': ProgrammingError, # invalid cursor state
'25': ProgrammingError, # invalid transaction state
'26': ProgrammingError, # invalid SQL statement name
'27': ProgrammingError, # triggered data change violation
'28': ProgrammingError, # invalid authorization specification
'2A': ProgrammingError, # direct SQL syntax error or access rule violation
'2B': DatabaseError, # dependent privilege descriptors still exist
'2C': ProgrammingError, # invalid character set name
'2D': DatabaseError, # invalid transaction termination
'2E': DatabaseError, # invalid connection name
'33': DatabaseError, # invalid SQL descriptor name
'34': ProgrammingError, # invalid cursor name
'35': ProgrammingError, # invalid condition number
'37': ProgrammingError, # dynamic SQL syntax error or access rule violation
'3C': ProgrammingError, # ambiguous cursor name
'3D': ProgrammingError, # invalid catalog name
'3F': ProgrammingError, # invalid schema name
'40': InternalError, # transaction rollback
'42': ProgrammingError, # syntax error or access rule violation
'44': InternalError, # with check option violation
'HZ': OperationalError, # remote database access
'XA': IntegrityError,
'0K': OperationalError,
'HY': DatabaseError, # default when no SQLState provided by MySQL server
}
_ERROR_EXCEPTIONS = {
1243: ProgrammingError,
1210: ProgrammingError,
2002: InterfaceError,
2013: OperationalError,
2049: NotSupportedError,
2055: OperationalError,
2061: InterfaceError,
2026: InterfaceError,
}

View File

@@ -0,0 +1,75 @@
# Copyright (c) 2012, 2017, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Translations
"""
__all__ = [
'get_client_error'
]
from .. import errorcode
def get_client_error(error, language='eng'):
"""Lookup client error
This function will lookup the client error message based on the given
error and return the error message. If the error was not found,
None will be returned.
Error can be either an integer or a string. For example:
error: 2000
error: CR_UNKNOWN_ERROR
The language attribute can be used to retrieve a localized message, when
available.
Returns a string or None.
"""
try:
tmp = __import__('mysql.connector.locales.{0}'.format(language),
globals(), locals(), ['client_error'])
except ImportError:
raise ImportError("No localization support for language '{0}'".format(
language))
client_error = tmp.client_error
if isinstance(error, int):
errno = error
for key, value in errorcode.__dict__.items():
if value == errno:
error = key
break
if isinstance(error, (str)):
try:
return getattr(client_error, error)
except AttributeError:
return None
raise ValueError("error argument needs to be either an integer or string")

View File

@@ -0,0 +1,30 @@
# Copyright (c) 2012, 2017, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""English Content
"""

View File

@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2013, 2021, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
# This file was auto-generated.
_GENERATED_ON = '2021-08-11'
_MYSQL_VERSION = (8, 0, 27)
# Start MySQL Error messages
CR_UNKNOWN_ERROR = u"Unknown MySQL error"
CR_SOCKET_CREATE_ERROR = u"Can't create UNIX socket (%s)"
CR_CONNECTION_ERROR = u"Can't connect to local MySQL server through socket '%-.100s' (%s)"
CR_CONN_HOST_ERROR = u"Can't connect to MySQL server on '%-.100s:%u' (%s)"
CR_IPSOCK_ERROR = u"Can't create TCP/IP socket (%s)"
CR_UNKNOWN_HOST = u"Unknown MySQL server host '%-.100s' (%s)"
CR_SERVER_GONE_ERROR = u"MySQL server has gone away"
CR_VERSION_ERROR = u"Protocol mismatch; server version = %s, client version = %s"
CR_OUT_OF_MEMORY = u"MySQL client ran out of memory"
CR_WRONG_HOST_INFO = u"Wrong host info"
CR_LOCALHOST_CONNECTION = u"Localhost via UNIX socket"
CR_TCP_CONNECTION = u"%-.100s via TCP/IP"
CR_SERVER_HANDSHAKE_ERR = u"Error in server handshake"
CR_SERVER_LOST = u"Lost connection to MySQL server during query"
CR_COMMANDS_OUT_OF_SYNC = u"Commands out of sync; you can't run this command now"
CR_NAMEDPIPE_CONNECTION = u"Named pipe: %-.32s"
CR_NAMEDPIPEWAIT_ERROR = u"Can't wait for named pipe to host: %-.64s pipe: %-.32s (%s)"
CR_NAMEDPIPEOPEN_ERROR = u"Can't open named pipe to host: %-.64s pipe: %-.32s (%s)"
CR_NAMEDPIPESETSTATE_ERROR = u"Can't set state of named pipe to host: %-.64s pipe: %-.32s (%s)"
CR_CANT_READ_CHARSET = u"Can't initialize character set %-.32s (path: %-.100s)"
CR_NET_PACKET_TOO_LARGE = u"Got packet bigger than 'max_allowed_packet' bytes"
CR_EMBEDDED_CONNECTION = u"Embedded server"
CR_PROBE_SLAVE_STATUS = u"Error on SHOW SLAVE STATUS:"
CR_PROBE_SLAVE_HOSTS = u"Error on SHOW SLAVE HOSTS:"
CR_PROBE_SLAVE_CONNECT = u"Error connecting to slave:"
CR_PROBE_MASTER_CONNECT = u"Error connecting to master:"
CR_SSL_CONNECTION_ERROR = u"SSL connection error: %-.100s"
CR_MALFORMED_PACKET = u"Malformed packet"
CR_WRONG_LICENSE = u"This client library is licensed only for use with MySQL servers having '%s' license"
CR_NULL_POINTER = u"Invalid use of null pointer"
CR_NO_PREPARE_STMT = u"Statement not prepared"
CR_PARAMS_NOT_BOUND = u"No data supplied for parameters in prepared statement"
CR_DATA_TRUNCATED = u"Data truncated"
CR_NO_PARAMETERS_EXISTS = u"No parameters exist in the statement"
CR_INVALID_PARAMETER_NO = u"Invalid parameter number"
CR_INVALID_BUFFER_USE = u"Can't send long data for non-string/non-binary data types (parameter: %s)"
CR_UNSUPPORTED_PARAM_TYPE = u"Using unsupported buffer type: %s (parameter: %s)"
CR_SHARED_MEMORY_CONNECTION = u"Shared memory: %-.100s"
CR_SHARED_MEMORY_CONNECT_REQUEST_ERROR = u"Can't open shared memory; client could not create request event (%s)"
CR_SHARED_MEMORY_CONNECT_ANSWER_ERROR = u"Can't open shared memory; no answer event received from server (%s)"
CR_SHARED_MEMORY_CONNECT_FILE_MAP_ERROR = u"Can't open shared memory; server could not allocate file mapping (%s)"
CR_SHARED_MEMORY_CONNECT_MAP_ERROR = u"Can't open shared memory; server could not get pointer to file mapping (%s)"
CR_SHARED_MEMORY_FILE_MAP_ERROR = u"Can't open shared memory; client could not allocate file mapping (%s)"
CR_SHARED_MEMORY_MAP_ERROR = u"Can't open shared memory; client could not get pointer to file mapping (%s)"
CR_SHARED_MEMORY_EVENT_ERROR = u"Can't open shared memory; client could not create %s event (%s)"
CR_SHARED_MEMORY_CONNECT_ABANDONED_ERROR = u"Can't open shared memory; no answer from server (%s)"
CR_SHARED_MEMORY_CONNECT_SET_ERROR = u"Can't open shared memory; cannot send request event to server (%s)"
CR_CONN_UNKNOW_PROTOCOL = u"Wrong or unknown protocol"
CR_INVALID_CONN_HANDLE = u"Invalid connection handle"
CR_UNUSED_1 = u"Connection using old (pre-4.1.1) authentication protocol refused (client option 'secure_auth' enabled)"
CR_FETCH_CANCELED = u"Row retrieval was canceled by mysql_stmt_close() call"
CR_NO_DATA = u"Attempt to read column without prior row fetch"
CR_NO_STMT_METADATA = u"Prepared statement contains no metadata"
CR_NO_RESULT_SET = u"Attempt to read a row while there is no result set associated with the statement"
CR_NOT_IMPLEMENTED = u"This feature is not implemented yet"
CR_SERVER_LOST_EXTENDED = u"Lost connection to MySQL server at '%s', system error: %s"
CR_STMT_CLOSED = u"Statement closed indirectly because of a preceding %s() call"
CR_NEW_STMT_METADATA = u"The number of columns in the result set differs from the number of bound buffers. You must reset the statement, rebind the result set columns, and execute the statement again"
CR_ALREADY_CONNECTED = u"This handle is already connected. Use a separate handle for each connection."
CR_AUTH_PLUGIN_CANNOT_LOAD = u"Authentication plugin '%s' cannot be loaded: %s"
CR_DUPLICATE_CONNECTION_ATTR = u"There is an attribute with the same name already"
CR_AUTH_PLUGIN_ERR = u"Authentication plugin '%s' reported error: %s"
CR_INSECURE_API_ERR = u"Insecure API function call: '%s' Use instead: '%s'"
CR_FILE_NAME_TOO_LONG = u"File name is too long"
CR_SSL_FIPS_MODE_ERR = u"Set FIPS mode ON/STRICT failed"
CR_DEPRECATED_COMPRESSION_NOT_SUPPORTED = u"Compression protocol not supported with asynchronous protocol"
CR_COMPRESSION_WRONGLY_CONFIGURED = u"Connection failed due to wrongly configured compression algorithm"
CR_KERBEROS_USER_NOT_FOUND = u"SSO user not found, Please perform SSO authentication using kerberos."
CR_LOAD_DATA_LOCAL_INFILE_REJECTED = u"LOAD DATA LOCAL INFILE file request rejected due to restrictions on access."
CR_LOAD_DATA_LOCAL_INFILE_REALPATH_FAIL = u"Determining the real path for '%s' failed with error (%s): %s"
CR_DNS_SRV_LOOKUP_FAILED = u"DNS SRV lookup failed with error : %s"
CR_MANDATORY_TRACKER_NOT_FOUND = u"Client does not recognise tracker type %s marked as mandatory by server."
CR_INVALID_FACTOR_NO = u"Invalid first argument for MYSQL_OPT_USER_PASSWORD option. Valid value should be between 1 and 3 inclusive."
# End MySQL Error messages

View File

@@ -0,0 +1,584 @@
# Copyright (c) 2012, 2020, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Module implementing low-level socket communication with MySQL servers.
"""
from collections import deque
import os
import socket
import struct
import sys
import zlib
try:
import ssl
TLS_VERSIONS = {
"TLSv1": ssl.PROTOCOL_TLSv1,
"TLSv1.1": ssl.PROTOCOL_TLSv1_1,
"TLSv1.2": ssl.PROTOCOL_TLSv1_2}
# TLSv1.3 included in PROTOCOL_TLS, but PROTOCOL_TLS is not included on 3.4
if hasattr(ssl, "PROTOCOL_TLS"):
TLS_VERSIONS["TLSv1.3"] = ssl.PROTOCOL_TLS # pylint: disable=E1101
else:
TLS_VERSIONS["TLSv1.3"] = ssl.PROTOCOL_SSLv23 # Alias of PROTOCOL_TLS
if hasattr(ssl, "HAS_TLSv1_3") and ssl.HAS_TLSv1_3:
TLS_V1_3_SUPPORTED = True
else:
TLS_V1_3_SUPPORTED = False
except:
# If import fails, we don't have SSL support.
TLS_V1_3_SUPPORTED = False
pass
from . import constants, errors
from .errors import InterfaceError
from .utils import init_bytearray
def _strioerror(err):
"""Reformat the IOError error message
This function reformats the IOError error message.
"""
if not err.errno:
return str(err)
return '{errno} {strerr}'.format(errno=err.errno, strerr=err.strerror)
def _prepare_packets(buf, pktnr):
"""Prepare a packet for sending to the MySQL server"""
pkts = []
pllen = len(buf)
maxpktlen = constants.MAX_PACKET_LENGTH
while pllen > maxpktlen:
pkts.append(b'\xff\xff\xff' + struct.pack('<B', pktnr)
+ buf[:maxpktlen])
buf = buf[maxpktlen:]
pllen = len(buf)
pktnr = pktnr + 1
pkts.append(struct.pack('<I', pllen)[0:3]
+ struct.pack('<B', pktnr) + buf)
return pkts
class BaseMySQLSocket(object):
"""Base class for MySQL socket communication
This class should not be used directly but overloaded, changing the
at least the open_connection()-method. Examples of subclasses are
mysql.connector.network.MySQLTCPSocket
mysql.connector.network.MySQLUnixSocket
"""
def __init__(self):
self.sock = None # holds the socket connection
self._connection_timeout = None
self._packet_number = -1
self._compressed_packet_number = -1
self._packet_queue = deque()
self.recvsize = 8192
@property
def next_packet_number(self):
"""Increments the packet number"""
self._packet_number = self._packet_number + 1
if self._packet_number > 255:
self._packet_number = 0
return self._packet_number
@property
def next_compressed_packet_number(self):
"""Increments the compressed packet number"""
self._compressed_packet_number = self._compressed_packet_number + 1
if self._compressed_packet_number > 255:
self._compressed_packet_number = 0
return self._compressed_packet_number
def open_connection(self):
"""Open the socket"""
raise NotImplementedError
def get_address(self):
"""Get the location of the socket"""
raise NotImplementedError
def shutdown(self):
"""Shut down the socket before closing it"""
try:
self.sock.shutdown(socket.SHUT_RDWR)
self.sock.close()
del self._packet_queue
except (socket.error, AttributeError):
pass
def close_connection(self):
"""Close the socket"""
try:
self.sock.close()
del self._packet_queue
except (socket.error, AttributeError):
pass
def __del__(self):
self.shutdown()
def send_plain(self, buf, packet_number=None,
compressed_packet_number=None):
"""Send packets to the MySQL server"""
if packet_number is None:
self.next_packet_number # pylint: disable=W0104
else:
self._packet_number = packet_number
packets = _prepare_packets(buf, self._packet_number)
for packet in packets:
try:
self.sock.sendall(packet)
except IOError as err:
raise errors.OperationalError(
errno=2055, values=(self.get_address(), _strioerror(err)))
except AttributeError:
raise errors.OperationalError(errno=2006)
send = send_plain
def send_compressed(self, buf, packet_number=None,
compressed_packet_number=None):
"""Send compressed packets to the MySQL server"""
if packet_number is None:
self.next_packet_number # pylint: disable=W0104
else:
self._packet_number = packet_number
if compressed_packet_number is None:
self.next_compressed_packet_number # pylint: disable=W0104
else:
self._compressed_packet_number = compressed_packet_number
pktnr = self._packet_number
pllen = len(buf)
zpkts = []
maxpktlen = constants.MAX_PACKET_LENGTH
if pllen > maxpktlen:
pkts = _prepare_packets(buf, pktnr)
tmpbuf = b''.join(pkts)
del pkts
zbuf = zlib.compress(tmpbuf[:16384])
header = (struct.pack('<I', len(zbuf))[0:3]
+ struct.pack('<B', self._compressed_packet_number)
+ b'\x00\x40\x00')
zpkts.append(header + zbuf)
tmpbuf = tmpbuf[16384:]
pllen = len(tmpbuf)
self.next_compressed_packet_number # pylint: disable=W0104
while pllen > maxpktlen:
zbuf = zlib.compress(tmpbuf[:maxpktlen])
header = (struct.pack('<I', len(zbuf))[0:3]
+ struct.pack('<B', self._compressed_packet_number)
+ b'\xff\xff\xff')
zpkts.append(header + zbuf)
tmpbuf = tmpbuf[maxpktlen:]
pllen = len(tmpbuf)
self.next_compressed_packet_number # pylint: disable=W0104
if tmpbuf:
zbuf = zlib.compress(tmpbuf)
header = (struct.pack('<I', len(zbuf))[0:3]
+ struct.pack('<B', self._compressed_packet_number)
+ struct.pack('<I', pllen)[0:3])
zpkts.append(header + zbuf)
del tmpbuf
else:
pkt = (struct.pack('<I', pllen)[0:3] +
struct.pack('<B', pktnr) + buf)
pllen = len(pkt)
if pllen > 50:
zbuf = zlib.compress(pkt)
zpkts.append(struct.pack('<I', len(zbuf))[0:3]
+ struct.pack('<B', self._compressed_packet_number)
+ struct.pack('<I', pllen)[0:3]
+ zbuf)
else:
header = (struct.pack('<I', pllen)[0:3]
+ struct.pack('<B', self._compressed_packet_number)
+ struct.pack('<I', 0)[0:3])
zpkts.append(header + pkt)
for zip_packet in zpkts:
try:
self.sock.sendall(zip_packet)
except IOError as err:
raise errors.OperationalError(
errno=2055, values=(self.get_address(), _strioerror(err)))
except AttributeError:
raise errors.OperationalError(errno=2006)
def recv_plain(self):
"""Receive packets from the MySQL server"""
try:
# Read the header of the MySQL packet, 4 bytes
packet = bytearray(b'')
packet_len = 0
while packet_len < 4:
chunk = self.sock.recv(4 - packet_len)
if not chunk:
raise errors.InterfaceError(errno=2013)
packet += chunk
packet_len = len(packet)
# Save the packet number and payload length
self._packet_number = packet[3]
payload_len = struct.unpack("<I", packet[0:3] + b'\x00')[0]
# Read the payload
rest = payload_len
packet.extend(bytearray(payload_len))
packet_view = memoryview(packet) # pylint: disable=E0602
packet_view = packet_view[4:]
while rest:
read = self.sock.recv_into(packet_view, rest)
if read == 0 and rest > 0:
raise errors.InterfaceError(errno=2013)
packet_view = packet_view[read:]
rest -= read
return packet
except IOError as err:
raise errors.OperationalError(
errno=2055, values=(self.get_address(), _strioerror(err)))
def recv_py26_plain(self):
"""Receive packets from the MySQL server"""
try:
# Read the header of the MySQL packet, 4 bytes
header = bytearray(b'')
header_len = 0
while header_len < 4:
chunk = self.sock.recv(4 - header_len)
if not chunk:
raise errors.InterfaceError(errno=2013)
header += chunk
header_len = len(header)
# Save the packet number and payload length
self._packet_number = header[3]
payload_len = struct.unpack("<I", header[0:3] + b'\x00')[0]
# Read the payload
rest = payload_len
payload = init_bytearray(b'')
while rest > 0:
chunk = self.sock.recv(rest)
if not chunk:
raise errors.InterfaceError(errno=2013)
payload += chunk
rest = payload_len - len(payload)
return header + payload
except IOError as err:
raise errors.OperationalError(
errno=2055, values=(self.get_address(), _strioerror(err)))
if sys.version_info[0:2] == (2, 6):
recv = recv_py26_plain
recv_plain = recv_py26_plain
else:
recv = recv_plain
def _split_zipped_payload(self, packet_bunch):
"""Split compressed payload"""
while packet_bunch:
payload_length = struct.unpack("<I", packet_bunch[0:3] + b'\x00')[0]
self._packet_queue.append(packet_bunch[0:payload_length + 4])
packet_bunch = packet_bunch[payload_length + 4:]
def recv_compressed(self):
"""Receive compressed packets from the MySQL server"""
try:
pkt = self._packet_queue.popleft()
self._packet_number = pkt[3]
return pkt
except IndexError:
pass
header = bytearray(b'')
packets = []
try:
abyte = self.sock.recv(1)
while abyte and len(header) < 7:
header += abyte
abyte = self.sock.recv(1)
while header:
if len(header) < 7:
raise errors.InterfaceError(errno=2013)
# Get length of compressed packet
zip_payload_length = struct.unpack("<I",
header[0:3] + b'\x00')[0]
self._compressed_packet_number = header[3]
# Get payload length before compression
payload_length = struct.unpack("<I", header[4:7] + b'\x00')[0]
zip_payload = init_bytearray(abyte)
while len(zip_payload) < zip_payload_length:
chunk = self.sock.recv(zip_payload_length
- len(zip_payload))
if not chunk:
raise errors.InterfaceError(errno=2013)
zip_payload = zip_payload + chunk
# Payload was not compressed
if payload_length == 0:
self._split_zipped_payload(zip_payload)
pkt = self._packet_queue.popleft()
self._packet_number = pkt[3]
return pkt
packets.append((payload_length, zip_payload))
if zip_payload_length <= 16384:
# We received the full compressed packet
break
# Get next compressed packet
header = init_bytearray(b'')
abyte = self.sock.recv(1)
while abyte and len(header) < 7:
header += abyte
abyte = self.sock.recv(1)
except IOError as err:
raise errors.OperationalError(
errno=2055, values=(self.get_address(), _strioerror(err)))
# Compressed packet can contain more than 1 MySQL packets
# We decompress and make one so we can split it up
tmp = init_bytearray(b'')
for payload_length, payload in packets:
# payload_length can not be 0; this was previously handled
tmp += zlib.decompress(payload)
self._split_zipped_payload(tmp)
del tmp
try:
pkt = self._packet_queue.popleft()
self._packet_number = pkt[3]
return pkt
except IndexError:
pass
def set_connection_timeout(self, timeout):
"""Set the connection timeout"""
self._connection_timeout = timeout
if self.sock:
self.sock.settimeout(timeout)
# pylint: disable=C0103,E1101
def switch_to_ssl(self, ca, cert, key, verify_cert=False,
verify_identity=False, cipher_suites=None,
tls_versions=None):
"""Switch the socket to use SSL"""
if not self.sock:
raise errors.InterfaceError(errno=2048)
try:
if verify_cert:
cert_reqs = ssl.CERT_REQUIRED
elif verify_identity:
cert_reqs = ssl.CERT_OPTIONAL
else:
cert_reqs = ssl.CERT_NONE
if tls_versions is None or not tls_versions:
context = ssl.create_default_context()
if not verify_identity:
context.check_hostname = False
else:
tls_versions.sort(reverse=True)
tls_version = tls_versions[0]
if not TLS_V1_3_SUPPORTED and \
tls_version == "TLSv1.3" and len(tls_versions) > 1:
tls_version = tls_versions[1]
ssl_protocol = TLS_VERSIONS[tls_version]
context = ssl.SSLContext(ssl_protocol)
if tls_version == "TLSv1.3":
if "TLSv1.2" not in tls_versions:
context.options |= ssl.OP_NO_TLSv1_2
if "TLSv1.1" not in tls_versions:
context.options |= ssl.OP_NO_TLSv1_1
if "TLSv1" not in tls_versions:
context.options |= ssl.OP_NO_TLSv1
context.check_hostname = False
context.verify_mode = cert_reqs
context.load_default_certs()
if ca:
try:
context.load_verify_locations(ca)
except (IOError, ssl.SSLError) as err:
self.sock.close()
raise InterfaceError(
"Invalid CA Certificate: {}".format(err))
if cert:
try:
context.load_cert_chain(cert, key)
except (IOError, ssl.SSLError) as err:
self.sock.close()
raise InterfaceError(
"Invalid Certificate/Key: {}".format(err))
if cipher_suites:
context.set_ciphers(cipher_suites)
if hasattr(self, "server_host"):
self.sock = context.wrap_socket(
self.sock, server_hostname=self.server_host)
else:
self.sock = context.wrap_socket(self.sock)
if verify_identity:
context.check_hostname = True
hostnames = [self.server_host]
if os.name == 'nt' and self.server_host == 'localhost':
hostnames = ['localhost', '127.0.0.1']
aliases = socket.gethostbyaddr(self.server_host)
hostnames.extend([aliases[0]] + aliases[1])
match_found = False
errs = []
for hostname in hostnames:
try:
ssl.match_hostname(self.sock.getpeercert(), hostname)
except ssl.CertificateError as err:
errs.append(str(err))
else:
match_found = True
break
if not match_found:
self.sock.close()
raise InterfaceError("Unable to verify server identity: {}"
"".format(", ".join(errs)))
except NameError:
raise errors.NotSupportedError(
"Python installation has no SSL support")
except (ssl.SSLError, IOError) as err:
raise errors.InterfaceError(
errno=2055, values=(self.get_address(), _strioerror(err)))
except ssl.CertificateError as err:
raise errors.InterfaceError(str(err))
except NotImplementedError as err:
raise errors.InterfaceError(str(err))
# pylint: enable=C0103,E1101
class MySQLUnixSocket(BaseMySQLSocket):
"""MySQL socket class using UNIX sockets
Opens a connection through the UNIX socket of the MySQL Server.
"""
def __init__(self, unix_socket='/tmp/mysql.sock'):
super(MySQLUnixSocket, self).__init__()
self.unix_socket = unix_socket
def get_address(self):
return self.unix_socket
def open_connection(self):
try:
self.sock = socket.socket(socket.AF_UNIX, # pylint: disable=E1101
socket.SOCK_STREAM)
self.sock.settimeout(self._connection_timeout)
self.sock.connect(self.unix_socket)
except IOError as err:
raise errors.InterfaceError(
errno=2002, values=(self.get_address(), _strioerror(err)))
except Exception as err:
raise errors.InterfaceError(str(err))
class MySQLTCPSocket(BaseMySQLSocket):
"""MySQL socket class using TCP/IP
Opens a TCP/IP connection to the MySQL Server.
"""
def __init__(self, host='127.0.0.1', port=3306, force_ipv6=False):
super(MySQLTCPSocket, self).__init__()
self.server_host = host
self.server_port = port
self.force_ipv6 = force_ipv6
self._family = 0
def get_address(self):
return "{0}:{1}".format(self.server_host, self.server_port)
def open_connection(self):
"""Open the TCP/IP connection to the MySQL server
"""
# Get address information
addrinfo = [None] * 5
try:
addrinfos = socket.getaddrinfo(self.server_host,
self.server_port,
0, socket.SOCK_STREAM,
socket.SOL_TCP)
# If multiple results we favor IPv4, unless IPv6 was forced.
for info in addrinfos:
if self.force_ipv6 and info[0] == socket.AF_INET6:
addrinfo = info
break
elif info[0] == socket.AF_INET:
addrinfo = info
break
if self.force_ipv6 and addrinfo[0] is None:
raise errors.InterfaceError(
"No IPv6 address found for {0}".format(self.server_host))
if addrinfo[0] is None:
addrinfo = addrinfos[0]
except IOError as err:
raise errors.InterfaceError(
errno=2003, values=(self.get_address(), _strioerror(err)))
else:
(self._family, socktype, proto, _, sockaddr) = addrinfo
# Instanciate the socket and connect
try:
self.sock = socket.socket(self._family, socktype, proto)
self.sock.settimeout(self._connection_timeout)
self.sock.connect(sockaddr)
except IOError as err:
raise errors.InterfaceError(
errno=2003, values=(
self.server_host,
self.server_port,
_strioerror(err),
)
)
except Exception as err:
raise errors.OperationalError(str(err))

View File

@@ -0,0 +1,345 @@
# Copyright (c) 2014, 2021, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Implements parser to parse MySQL option files.
"""
import codecs
import io
import os
import re
from configparser import (
ConfigParser as SafeConfigParser,
MissingSectionHeaderError
)
from .constants import DEFAULT_CONFIGURATION, CNX_POOL_ARGS
DEFAULT_EXTENSIONS = {
'nt': ('ini', 'cnf'),
'posix': ('cnf',)
}
def read_option_files(**config):
"""
Read option files for connection parameters.
Checks if connection arguments contain option file arguments, and then
reads option files accordingly.
"""
if 'option_files' in config:
try:
if isinstance(config['option_groups'], str):
config['option_groups'] = [config['option_groups']]
groups = config['option_groups']
del config['option_groups']
except KeyError:
groups = ['client', 'connector_python']
if isinstance(config['option_files'], str):
config['option_files'] = [config['option_files']]
option_parser = MySQLOptionsParser(list(config['option_files']),
keep_dashes=False)
del config['option_files']
config_from_file = option_parser.get_groups_as_dict_with_priority(
*groups)
config_options = {}
for group in groups:
try:
for option, value in config_from_file[group].items():
try:
if option == 'socket':
option = 'unix_socket'
if (option not in CNX_POOL_ARGS and
option != 'failover'):
# pylint: disable=W0104
DEFAULT_CONFIGURATION[option]
# pylint: enable=W0104
if (option not in config_options or
config_options[option][1] <= value[1]):
config_options[option] = value
except KeyError:
if group == 'connector_python':
raise AttributeError("Unsupported argument "
"'{0}'".format(option))
except KeyError:
continue
not_evaluate = ('password', 'passwd')
for option, value in config_options.items():
if option not in config:
try:
if option in not_evaluate:
config[option] = value[0]
else:
config[option] = eval(value[0]) # pylint: disable=W0123
except (NameError, SyntaxError):
config[option] = value[0]
return config
class MySQLOptionsParser(SafeConfigParser): # pylint: disable=R0901
"""This class implements methods to parse MySQL option files"""
def __init__(self, files=None, keep_dashes=True): # pylint: disable=W0231
"""Initialize
If defaults is True, default option files are read first
Raises ValueError if defaults is set to True but defaults files
cannot be found.
"""
# Regular expression to allow options with no value(For Python v2.6)
self.OPTCRE = re.compile( # pylint: disable=C0103
r'(?P<option>[^:=\s][^:=]*)'
r'\s*(?:'
r'(?P<vi>[:=])\s*'
r'(?P<value>.*))?$'
)
self._options_dict = {}
SafeConfigParser.__init__(self, strict=False)
self.default_extension = DEFAULT_EXTENSIONS[os.name]
self.keep_dashes = keep_dashes
if not files:
raise ValueError('files argument should be given')
if isinstance(files, str):
self.files = [files]
else:
self.files = files
self._parse_options(list(self.files))
self._sections = self.get_groups_as_dict()
def optionxform(self, optionstr):
"""Converts option strings
Converts option strings to lower case and replaces dashes(-) with
underscores(_) if keep_dashes variable is set.
"""
if not self.keep_dashes:
optionstr = optionstr.replace('-', '_')
return optionstr.lower()
def _parse_options(self, files):
"""Parse options from files given as arguments.
This method checks for !include or !inculdedir directives and if there
is any, those files included by these directives are also parsed
for options.
Raises ValueError if any of the included or file given in arguments
is not readable.
"""
initial_files = files[:]
files = []
index = 0
err_msg = "Option file '{0}' being included again in file '{1}'"
for file_ in initial_files:
try:
if file_ in initial_files[index+1:]:
raise ValueError("Same option file '{0}' occurring more "
"than once in the list".format(file_))
with open(file_, 'r') as op_file:
for line in op_file.readlines():
if line.startswith('!includedir'):
_, dir_path = line.split(None, 1)
dir_path = dir_path.strip()
for entry in os.listdir(dir_path):
entry = os.path.join(dir_path, entry)
if entry in files:
raise ValueError(err_msg.format(
entry, file_))
if (os.path.isfile(entry) and
entry.endswith(self.default_extension)):
files.append(entry)
elif line.startswith('!include'):
_, filename = line.split(None, 1)
filename = filename.strip()
if filename in files:
raise ValueError(err_msg.format(
filename, file_))
files.append(filename)
index += 1
files.append(file_)
except (IOError, OSError) as exc:
raise ValueError("Failed reading file '{0}': {1}".format(
file_, str(exc)))
read_files = self.read(files)
not_read_files = set(files) - set(read_files)
if not_read_files:
raise ValueError("File(s) {0} could not be read.".format(
', '.join(not_read_files)))
def read(self, filenames): # pylint: disable=W0221
"""Read and parse a filename or a list of filenames.
Overridden from ConfigParser and modified so as to allow options
which are not inside any section header
Return list of successfully read files.
"""
if isinstance(filenames, str):
filenames = [filenames]
read_ok = []
for priority, filename in enumerate(filenames):
try:
out_file = io.StringIO()
for line in codecs.open(filename, encoding='utf-8'):
line = line.strip()
# Skip lines that begin with "!includedir" or "!include"
if line.startswith('!include'):
continue
match_obj = self.OPTCRE.match(line)
if not self.SECTCRE.match(line) and match_obj:
optname, delimiter, optval = match_obj.group('option',
'vi',
'value')
if optname and not optval and not delimiter:
out_file.write(line + "=\n")
else:
out_file.write(line + '\n')
else:
out_file.write(line + '\n')
out_file.seek(0)
except IOError:
continue
try:
self._read(out_file, filename)
for group in self._sections.keys():
try:
self._options_dict[group]
except KeyError:
self._options_dict[group] = {}
for option, value in self._sections[group].items():
self._options_dict[group][option] = (value, priority)
self._sections = self._dict()
except MissingSectionHeaderError:
self._read(out_file, filename)
out_file.close()
read_ok.append(filename)
return read_ok
def get_groups(self, *args):
"""Returns options as a dictionary.
Returns options from all the groups specified as arguments, returns
the options from all groups if no argument provided. Options are
overridden when they are found in the next group.
Returns a dictionary
"""
if not args:
args = self._options_dict.keys()
options = {}
priority = {}
for group in args:
try:
for option, value in [(key, value,) for key, value in
self._options_dict[group].items() if
key != "__name__" and
not key.startswith("!")]:
if option not in options or priority[option] <= value[1]:
priority[option] = value[1]
options[option] = value[0]
except KeyError:
pass
return options
def get_groups_as_dict_with_priority(self, *args): # pylint: disable=C0103
"""Returns options as dictionary of dictionaries.
Returns options from all the groups specified as arguments. For each
group the option are contained in a dictionary. The order in which
the groups are specified is unimportant. Also options are not
overridden in between the groups.
The value is a tuple with two elements, first being the actual value
and second is the priority of the value which is higher for a value
read from a higher priority file.
Returns an dictionary of dictionaries
"""
if not args:
args = self._options_dict.keys()
options = dict()
for group in args:
try:
options[group] = dict((key, value,) for key, value in
self._options_dict[group].items() if
key != "__name__" and
not key.startswith("!"))
except KeyError:
pass
return options
def get_groups_as_dict(self, *args):
"""Returns options as dictionary of dictionaries.
Returns options from all the groups specified as arguments. For each
group the option are contained in a dictionary. The order in which
the groups are specified is unimportant. Also options are not
overridden in between the groups.
Returns an dictionary of dictionaries
"""
if not args:
args = self._options_dict.keys()
options = dict()
for group in args:
try:
options[group] = dict((key, value[0],) for key, value in
self._options_dict[group].items() if
key != "__name__" and
not key.startswith("!"))
except KeyError:
pass
return options

View File

@@ -0,0 +1,373 @@
# Copyright (c) 2013, 2021, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Implementing pooling of connections to MySQL servers.
"""
import re
from uuid import uuid4
# pylint: disable=F0401
try:
import queue
except ImportError:
# Python v2
import Queue as queue
# pylint: enable=F0401
import threading
try:
from mysql.connector.connection_cext import (CMySQLConnection)
except ImportError:
CMySQLConnection = None
from . import errors
from . import Connect
from .connection import MySQLConnection
CONNECTION_POOL_LOCK = threading.RLock()
CNX_POOL_MAXSIZE = 32
CNX_POOL_MAXNAMESIZE = 64
CNX_POOL_NAMEREGEX = re.compile(r'[^a-zA-Z0-9._:\-*$#]')
MYSQL_CNX_CLASS = ((MySQLConnection) if CMySQLConnection is None else
(MySQLConnection, CMySQLConnection))
def generate_pool_name(**kwargs):
"""Generate a pool name
This function takes keyword arguments, usually the connection
arguments for MySQLConnection, and tries to generate a name for
a pool.
Raises PoolError when no name can be generated.
Returns a string.
"""
parts = []
for key in ('host', 'port', 'user', 'database'):
try:
parts.append(str(kwargs[key]))
except KeyError:
pass
if not parts:
raise errors.PoolError(
"Failed generating pool name; specify pool_name")
return '_'.join(parts)
class PooledMySQLConnection(object):
"""Class holding a MySQL Connection in a pool
PooledMySQLConnection is used by MySQLConnectionPool to return an
instance holding a MySQL connection. It works like a MySQLConnection
except for methods like close() and config().
The close()-method will add the connection back to the pool rather
than disconnecting from the MySQL server.
Configuring the connection have to be done through the MySQLConnectionPool
method set_config(). Using config() on pooled connection will raise a
PoolError.
"""
def __init__(self, pool, cnx):
"""Initialize
The pool argument must be an instance of MySQLConnectionPoll. cnx
if an instance of MySQLConnection.
"""
if not isinstance(pool, MySQLConnectionPool):
raise AttributeError(
"pool should be a MySQLConnectionPool")
if not isinstance(cnx, MYSQL_CNX_CLASS):
raise AttributeError(
"cnx should be a MySQLConnection")
self._cnx_pool = pool
self._cnx = cnx
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __getattr__(self, attr):
"""Calls attributes of the MySQLConnection instance"""
return getattr(self._cnx, attr)
def close(self):
"""Do not close, but add connection back to pool
The close() method does not close the connection with the
MySQL server. The connection is added back to the pool so it
can be reused.
When the pool is configured to reset the session, the session
state will be cleared by re-authenticating the user.
"""
try:
cnx = self._cnx
if self._cnx_pool.reset_session:
cnx.reset_session()
finally:
self._cnx_pool.add_connection(cnx)
self._cnx = None
def config(self, **kwargs):
"""Configuration is done through the pool"""
raise errors.PoolError(
"Configuration for pooled connections should "
"be done through the pool itself."
)
@property
def pool_name(self):
"""Return the name of the connection pool"""
return self._cnx_pool.pool_name
class MySQLConnectionPool(object):
"""Class defining a pool of MySQL connections"""
def __init__(self, pool_size=5, pool_name=None, pool_reset_session=True,
**kwargs):
"""Initialize
Initialize a MySQL connection pool with a maximum number of
connections set to pool_size. The rest of the keywords
arguments, kwargs, are configuration arguments for MySQLConnection
instances.
"""
self._pool_size = None
self._pool_name = None
self._reset_session = pool_reset_session
self._set_pool_size(pool_size)
self._set_pool_name(pool_name or generate_pool_name(**kwargs))
self._cnx_config = {}
self._cnx_queue = queue.Queue(self._pool_size)
self._config_version = uuid4()
if kwargs:
self.set_config(**kwargs)
cnt = 0
while cnt < self._pool_size:
self.add_connection()
cnt += 1
@property
def pool_name(self):
"""Return the name of the connection pool"""
return self._pool_name
@property
def pool_size(self):
"""Return number of connections managed by the pool"""
return self._pool_size
@property
def reset_session(self):
"""Return whether to reset session"""
return self._reset_session
def set_config(self, **kwargs):
"""Set the connection configuration for MySQLConnection instances
This method sets the configuration used for creating MySQLConnection
instances. See MySQLConnection for valid connection arguments.
Raises PoolError when a connection argument is not valid, missing
or not supported by MySQLConnection.
"""
if not kwargs:
return
with CONNECTION_POOL_LOCK:
try:
test_cnx = Connect()
test_cnx.config(**kwargs)
self._cnx_config = kwargs
self._config_version = uuid4()
except AttributeError as err:
raise errors.PoolError(
"Connection configuration not valid: {0}".format(err))
def _set_pool_size(self, pool_size):
"""Set the size of the pool
This method sets the size of the pool but it will not resize the pool.
Raises an AttributeError when the pool_size is not valid. Invalid size
is 0, negative or higher than pooling.CNX_POOL_MAXSIZE.
"""
if pool_size <= 0 or pool_size > CNX_POOL_MAXSIZE:
raise AttributeError(
"Pool size should be higher than 0 and "
"lower or equal to {0}".format(CNX_POOL_MAXSIZE))
self._pool_size = pool_size
def _set_pool_name(self, pool_name):
r"""Set the name of the pool
This method checks the validity and sets the name of the pool.
Raises an AttributeError when pool_name contains illegal characters
([^a-zA-Z0-9._\-*$#]) or is longer than pooling.CNX_POOL_MAXNAMESIZE.
"""
if CNX_POOL_NAMEREGEX.search(pool_name):
raise AttributeError(
"Pool name '{0}' contains illegal characters".format(pool_name))
if len(pool_name) > CNX_POOL_MAXNAMESIZE:
raise AttributeError(
"Pool name '{0}' is too long".format(pool_name))
self._pool_name = pool_name
def _queue_connection(self, cnx):
"""Put connection back in the queue
This method is putting a connection back in the queue. It will not
acquire a lock as the methods using _queue_connection() will have it
set.
Raises PoolError on errors.
"""
if not isinstance(cnx, MYSQL_CNX_CLASS):
raise errors.PoolError(
"Connection instance not subclass of MySQLConnection.")
try:
self._cnx_queue.put(cnx, block=False)
except queue.Full:
raise errors.PoolError("Failed adding connection; queue is full")
def add_connection(self, cnx=None):
"""Add a connection to the pool
This method instantiates a MySQLConnection using the configuration
passed when initializing the MySQLConnectionPool instance or using
the set_config() method.
If cnx is a MySQLConnection instance, it will be added to the
queue.
Raises PoolError when no configuration is set, when no more
connection can be added (maximum reached) or when the connection
can not be instantiated.
"""
with CONNECTION_POOL_LOCK:
if not self._cnx_config:
raise errors.PoolError(
"Connection configuration not available")
if self._cnx_queue.full():
raise errors.PoolError(
"Failed adding connection; queue is full")
if not cnx:
cnx = Connect(**self._cnx_config)
try:
if (self._reset_session and self._cnx_config['compress']
and cnx.get_server_version() < (5, 7, 3)):
raise errors.NotSupportedError("Pool reset session is "
"not supported with "
"compression for MySQL "
"server version 5.7.2 "
"or earlier.")
except KeyError:
pass
# pylint: disable=W0201,W0212
cnx._pool_config_version = self._config_version
# pylint: enable=W0201,W0212
else:
if not isinstance(cnx, MYSQL_CNX_CLASS):
raise errors.PoolError(
"Connection instance not subclass of MySQLConnection.")
self._queue_connection(cnx)
def get_connection(self):
"""Get a connection from the pool
This method returns an PooledMySQLConnection instance which
has a reference to the pool that created it, and the next available
MySQL connection.
When the MySQL connection is not connect, a reconnect is attempted.
Raises PoolError on errors.
Returns a PooledMySQLConnection instance.
"""
with CONNECTION_POOL_LOCK:
try:
cnx = self._cnx_queue.get(block=False)
except queue.Empty:
raise errors.PoolError(
"Failed getting connection; pool exhausted")
# pylint: disable=W0201,W0212
if not cnx.is_connected() \
or self._config_version != cnx._pool_config_version:
cnx.config(**self._cnx_config)
try:
cnx.reconnect()
except errors.InterfaceError:
# Failed to reconnect, give connection back to pool
self._queue_connection(cnx)
raise
cnx._pool_config_version = self._config_version
# pylint: enable=W0201,W0212
return PooledMySQLConnection(self, cnx)
def _remove_connections(self):
"""Close all connections
This method closes all connections. It returns the number
of connections it closed.
Used mostly for tests.
Returns int.
"""
with CONNECTION_POOL_LOCK:
cnt = 0
cnxq = self._cnx_queue
while cnxq.qsize():
try:
cnx = cnxq.get(block=False)
cnx.disconnect()
cnt += 1
except queue.Empty:
return cnt
except errors.PoolError:
raise
except errors.Error:
# Any other error when closing means connection is closed
pass
return cnt

View File

@@ -0,0 +1,818 @@
# Copyright (c) 2009, 2021, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Implements the MySQL Client/Server protocol
"""
import struct
import datetime
from decimal import Decimal
from .constants import (
FieldFlag, ServerCmd, FieldType, ClientFlag, PARAMETER_COUNT_AVAILABLE)
from . import errors, utils
from .authentication import get_auth_plugin
from .errors import DatabaseError, get_exception
PROTOCOL_VERSION = 10
class MySQLProtocol(object):
"""Implements MySQL client/server protocol
Create and parses MySQL packets.
"""
def _connect_with_db(self, client_flags, database):
"""Prepare database string for handshake response"""
if client_flags & ClientFlag.CONNECT_WITH_DB and database:
return database.encode('utf8') + b'\x00'
return b'\x00'
def _auth_response(self, client_flags, username, password, database,
auth_plugin, auth_data, ssl_enabled):
"""Prepare the authentication response"""
if not password:
return b'\x00'
try:
auth = get_auth_plugin(auth_plugin)(
auth_data,
username=username, password=password, database=database,
ssl_enabled=ssl_enabled)
plugin_auth_response = auth.auth_response()
except (TypeError, errors.InterfaceError) as exc:
raise errors.InterfaceError(
"Failed authentication: {0}".format(str(exc)))
if client_flags & ClientFlag.SECURE_CONNECTION:
resplen = len(plugin_auth_response)
auth_response = struct.pack('<B', resplen) + plugin_auth_response
else:
auth_response = plugin_auth_response + b'\x00'
return auth_response
def make_auth(self, handshake, username=None, password=None, database=None,
charset=45, client_flags=0,
max_allowed_packet=1073741824, ssl_enabled=False,
auth_plugin=None, conn_attrs=None):
"""Make a MySQL Authentication packet"""
try:
auth_data = handshake['auth_data']
auth_plugin = auth_plugin or handshake['auth_plugin']
except (TypeError, KeyError) as exc:
raise errors.ProgrammingError(
"Handshake misses authentication info ({0})".format(exc))
if not username:
username = b''
try:
username_bytes = username.encode('utf8') # pylint: disable=E1103
except AttributeError:
# Username is already bytes
username_bytes = username
packet = struct.pack('<IIH{filler}{usrlen}sx'.format(
filler='x' * 22, usrlen=len(username_bytes)),
client_flags, max_allowed_packet, charset,
username_bytes)
packet += self._auth_response(client_flags, username, password,
database,
auth_plugin,
auth_data, ssl_enabled)
packet += self._connect_with_db(client_flags, database)
if client_flags & ClientFlag.PLUGIN_AUTH:
packet += auth_plugin.encode('utf8') + b'\x00'
if (client_flags & ClientFlag.CONNECT_ARGS) and conn_attrs is not None:
packet += self.make_conn_attrs(conn_attrs)
return packet
def make_conn_attrs(self, conn_attrs):
"""Encode the connection attributes"""
for attr_name in conn_attrs:
if conn_attrs[attr_name] is None:
conn_attrs[attr_name] = ""
conn_attrs_len = (
sum([len(x) + len(conn_attrs[x]) for x in conn_attrs]) +
len(conn_attrs.keys()) + len(conn_attrs.values()))
conn_attrs_packet = struct.pack('<B', conn_attrs_len)
for attr_name in conn_attrs:
conn_attrs_packet += struct.pack('<B', len(attr_name))
conn_attrs_packet += attr_name.encode('utf8')
conn_attrs_packet += struct.pack('<B', len(conn_attrs[attr_name]))
conn_attrs_packet += conn_attrs[attr_name].encode('utf8')
return conn_attrs_packet
def make_auth_ssl(self, charset=45, client_flags=0,
max_allowed_packet=1073741824):
"""Make a SSL authentication packet"""
return utils.int4store(client_flags) + \
utils.int4store(max_allowed_packet) + \
utils.int2store(charset) + \
b'\x00' * 22
def make_command(self, command, argument=None):
"""Make a MySQL packet containing a command"""
data = utils.int1store(command)
if argument is not None:
data += argument
return data
def make_stmt_fetch(self, statement_id, rows=1):
"""Make a MySQL packet with Fetch Statement command"""
return utils.int4store(statement_id) + utils.int4store(rows)
def make_change_user(self, handshake, username=None, password=None,
database=None, charset=45, client_flags=0,
ssl_enabled=False, auth_plugin=None, conn_attrs=None):
"""Make a MySQL packet with the Change User command"""
try:
auth_data = handshake['auth_data']
auth_plugin = auth_plugin or handshake['auth_plugin']
except (TypeError, KeyError) as exc:
raise errors.ProgrammingError(
"Handshake misses authentication info ({0})".format(exc))
if not username:
username = b''
try:
username_bytes = username.encode('utf8') # pylint: disable=E1103
except AttributeError:
# Username is already bytes
username_bytes = username
packet = struct.pack('<B{usrlen}sx'.format(usrlen=len(username_bytes)),
ServerCmd.CHANGE_USER, username_bytes)
packet += self._auth_response(client_flags, username, password,
database,
auth_plugin,
auth_data, ssl_enabled)
packet += self._connect_with_db(client_flags, database)
packet += struct.pack('<H', charset)
if client_flags & ClientFlag.PLUGIN_AUTH:
packet += auth_plugin.encode('utf8') + b'\x00'
if (client_flags & ClientFlag.CONNECT_ARGS) and conn_attrs is not None:
packet += self.make_conn_attrs(conn_attrs)
return packet
def parse_handshake(self, packet):
"""Parse a MySQL Handshake-packet"""
res = {}
res['protocol'] = struct.unpack('<xxxxB', packet[0:5])[0]
if res["protocol"] != PROTOCOL_VERSION:
raise DatabaseError("Protocol mismatch; server version = {}, "
"client version = {}".format(res["protocol"],
PROTOCOL_VERSION))
(packet, res['server_version_original']) = utils.read_string(
packet[5:], end=b'\x00')
(res['server_threadid'],
auth_data1,
capabilities1,
res['charset'],
res['server_status'],
capabilities2,
auth_data_length
) = struct.unpack('<I8sx2sBH2sBxxxxxxxxxx', packet[0:31])
res['server_version_original'] = res['server_version_original'].decode()
packet = packet[31:]
capabilities = utils.intread(capabilities1 + capabilities2)
auth_data2 = b''
if capabilities & ClientFlag.SECURE_CONNECTION:
size = min(13, auth_data_length - 8) if auth_data_length else 13
auth_data2 = packet[0:size]
packet = packet[size:]
if auth_data2[-1] == 0:
auth_data2 = auth_data2[:-1]
if capabilities & ClientFlag.PLUGIN_AUTH:
if (b'\x00' not in packet
and res['server_version_original'].startswith("5.5.8")):
# MySQL server 5.5.8 has a bug where end byte is not send
(packet, res['auth_plugin']) = (b'', packet)
else:
(packet, res['auth_plugin']) = utils.read_string(
packet, end=b'\x00')
res['auth_plugin'] = res['auth_plugin'].decode('utf-8')
else:
res['auth_plugin'] = 'mysql_native_password'
res['auth_data'] = auth_data1 + auth_data2
res['capabilities'] = capabilities
return res
def parse_auth_next_factor(self, packet):
"""Parse a MySQL AuthNextFactor packet."""
packet, status = utils.read_int(packet, 1)
if not status == 2:
raise errors.InterfaceError(
"Failed parsing AuthNextFactor packet (invalid)"
)
packet, auth_plugin = utils.read_string(packet, end=b"\x00")
return packet, auth_plugin.decode("utf-8")
def parse_ok(self, packet):
"""Parse a MySQL OK-packet"""
if not packet[4] == 0:
raise errors.InterfaceError("Failed parsing OK packet (invalid).")
ok_packet = {}
try:
ok_packet['field_count'] = struct.unpack('<xxxxB', packet[0:5])[0]
(packet, ok_packet['affected_rows']) = utils.read_lc_int(packet[5:])
(packet, ok_packet['insert_id']) = utils.read_lc_int(packet)
(ok_packet['status_flag'],
ok_packet['warning_count']) = struct.unpack('<HH', packet[0:4])
packet = packet[4:]
if packet:
(packet, ok_packet['info_msg']) = utils.read_lc_string(packet)
ok_packet['info_msg'] = ok_packet['info_msg'].decode('utf-8')
except ValueError:
raise errors.InterfaceError("Failed parsing OK packet.")
return ok_packet
def parse_column_count(self, packet):
"""Parse a MySQL packet with the number of columns in result set"""
try:
count = utils.read_lc_int(packet[4:])[1]
return count
except (struct.error, ValueError):
raise errors.InterfaceError("Failed parsing column count")
def parse_column(self, packet, encoding='utf-8'):
"""Parse a MySQL column-packet"""
(packet, _) = utils.read_lc_string(packet[4:]) # catalog
(packet, _) = utils.read_lc_string(packet) # db
(packet, _) = utils.read_lc_string(packet) # table
(packet, _) = utils.read_lc_string(packet) # org_table
(packet, name) = utils.read_lc_string(packet) # name
(packet, _) = utils.read_lc_string(packet) # org_name
try:
(
charset,
_,
column_type,
flags,
_,
) = struct.unpack('<xHIBHBxx', packet)
except struct.error:
raise errors.InterfaceError("Failed parsing column information")
return (
name.decode(encoding),
column_type,
None, # display_size
None, # internal_size
None, # precision
None, # scale
~flags & FieldFlag.NOT_NULL, # null_ok
flags, # MySQL specific
charset,
)
def parse_eof(self, packet):
"""Parse a MySQL EOF-packet"""
if packet[4] == 0:
# EOF packet deprecation
return self.parse_ok(packet)
err_msg = "Failed parsing EOF packet."
res = {}
try:
unpacked = struct.unpack('<xxxBBHH', packet)
except struct.error:
raise errors.InterfaceError(err_msg)
if not (unpacked[1] == 254 and len(packet) <= 9):
raise errors.InterfaceError(err_msg)
res['warning_count'] = unpacked[2]
res['status_flag'] = unpacked[3]
return res
def parse_statistics(self, packet, with_header=True):
"""Parse the statistics packet"""
errmsg = "Failed getting COM_STATISTICS information"
res = {}
# Information is separated by 2 spaces
if with_header:
pairs = packet[4:].split(b'\x20\x20')
else:
pairs = packet.split(b'\x20\x20')
for pair in pairs:
try:
(lbl, val) = [v.strip() for v in pair.split(b':', 2)]
except:
raise errors.InterfaceError(errmsg)
# It's either an integer or a decimal
lbl = lbl.decode('utf-8')
try:
res[lbl] = int(val)
except:
try:
res[lbl] = Decimal(val.decode('utf-8'))
except:
raise errors.InterfaceError(
"{0} ({1}:{2}).".format(errmsg, lbl, val))
return res
def read_text_result(self, sock, version, count=1):
"""Read MySQL text result
Reads all or given number of rows from the socket.
Returns a tuple with 2 elements: a list with all rows and
the EOF packet.
"""
rows = []
eof = None
rowdata = None
i = 0
while True:
if eof or i == count:
break
packet = sock.recv()
if packet.startswith(b'\xff\xff\xff'):
datas = [packet[4:]]
packet = sock.recv()
while packet.startswith(b'\xff\xff\xff'):
datas.append(packet[4:])
packet = sock.recv()
datas.append(packet[4:])
rowdata = utils.read_lc_string_list(bytearray(b'').join(datas))
elif packet[4] == 254 and packet[0] < 7:
eof = self.parse_eof(packet)
rowdata = None
else:
eof = None
rowdata = utils.read_lc_string_list(packet[4:])
if eof is None and rowdata is not None:
rows.append(rowdata)
elif eof is None and rowdata is None:
raise get_exception(packet)
i += 1
return rows, eof
def _parse_binary_integer(self, packet, field):
"""Parse an integer from a binary packet"""
if field[1] == FieldType.TINY:
format_ = '<b'
length = 1
elif field[1] == FieldType.SHORT:
format_ = '<h'
length = 2
elif field[1] in (FieldType.INT24, FieldType.LONG):
format_ = '<i'
length = 4
elif field[1] == FieldType.LONGLONG:
format_ = '<q'
length = 8
if field[7] & FieldFlag.UNSIGNED:
format_ = format_.upper()
return (packet[length:], struct.unpack(format_, packet[0:length])[0])
def _parse_binary_float(self, packet, field):
"""Parse a float/double from a binary packet"""
if field[1] == FieldType.DOUBLE:
length = 8
format_ = '<d'
else:
length = 4
format_ = '<f'
return (packet[length:], struct.unpack(format_, packet[0:length])[0])
def _parse_binary_new_decimal(self, packet, charset='utf8'):
"""Parse a New Decimal from a binary packet"""
(packet, value) = utils.read_lc_string(packet)
return (packet, Decimal(value.decode(charset)))
def _parse_binary_timestamp(self, packet, field):
"""Parse a timestamp from a binary packet"""
length = packet[0]
value = None
if length == 4:
value = datetime.date(
year=struct.unpack('<H', packet[1:3])[0],
month=packet[3],
day=packet[4])
elif length >= 7:
mcs = 0
if length == 11:
mcs = struct.unpack('<I', packet[8:length + 1])[0]
value = datetime.datetime(
year=struct.unpack('<H', packet[1:3])[0],
month=packet[3],
day=packet[4],
hour=packet[5],
minute=packet[6],
second=packet[7],
microsecond=mcs)
return (packet[length + 1:], value)
def _parse_binary_time(self, packet, field):
"""Parse a time value from a binary packet"""
length = packet[0]
data = packet[1:length + 1]
mcs = 0
if length > 8:
mcs = struct.unpack('<I', data[8:])[0]
days = struct.unpack('<I', data[1:5])[0]
if data[0] == 1:
days *= -1
tmp = datetime.timedelta(days=days,
seconds=data[7],
microseconds=mcs,
minutes=data[6],
hours=data[5])
return (packet[length + 1:], tmp)
def _parse_binary_values(self, fields, packet, charset='utf-8'):
"""Parse values from a binary result packet"""
null_bitmap_length = (len(fields) + 7 + 2) // 8
null_bitmap = [int(i) for i in packet[0:null_bitmap_length]]
packet = packet[null_bitmap_length:]
values = []
for pos, field in enumerate(fields):
if null_bitmap[int((pos+2)/8)] & (1 << (pos + 2) % 8):
values.append(None)
continue
elif field[1] in (FieldType.TINY, FieldType.SHORT,
FieldType.INT24,
FieldType.LONG, FieldType.LONGLONG):
(packet, value) = self._parse_binary_integer(packet, field)
values.append(value)
elif field[1] in (FieldType.DOUBLE, FieldType.FLOAT):
(packet, value) = self._parse_binary_float(packet, field)
values.append(value)
elif field[1] == FieldType.NEWDECIMAL:
(packet, value) = self._parse_binary_new_decimal(packet, charset)
values.append(value)
elif field[1] in (FieldType.DATETIME, FieldType.DATE,
FieldType.TIMESTAMP):
(packet, value) = self._parse_binary_timestamp(packet, field)
values.append(value)
elif field[1] == FieldType.TIME:
(packet, value) = self._parse_binary_time(packet, field)
values.append(value)
else:
(packet, value) = utils.read_lc_string(packet)
values.append(value.decode(charset))
return tuple(values)
def read_binary_result(self, sock, columns, count=1, charset='utf-8'):
"""Read MySQL binary protocol result
Reads all or given number of binary resultset rows from the socket.
"""
rows = []
eof = None
values = None
i = 0
while True:
if eof is not None:
break
if i == count:
break
packet = sock.recv()
if packet[4] == 254:
eof = self.parse_eof(packet)
values = None
elif packet[4] == 0:
eof = None
values = self._parse_binary_values(columns, packet[5:], charset)
if eof is None and values is not None:
rows.append(values)
elif eof is None and values is None:
raise get_exception(packet)
i += 1
return (rows, eof)
def parse_binary_prepare_ok(self, packet):
"""Parse a MySQL Binary Protocol OK packet"""
if not packet[4] == 0:
raise errors.InterfaceError("Failed parsing Binary OK packet")
ok_pkt = {}
try:
(packet, ok_pkt['statement_id']) = utils.read_int(packet[5:], 4)
(packet, ok_pkt['num_columns']) = utils.read_int(packet, 2)
(packet, ok_pkt['num_params']) = utils.read_int(packet, 2)
packet = packet[1:] # Filler 1 * \x00
(packet, ok_pkt['warning_count']) = utils.read_int(packet, 2)
except ValueError:
raise errors.InterfaceError("Failed parsing Binary OK packet")
return ok_pkt
def _prepare_binary_integer(self, value):
"""Prepare an integer for the MySQL binary protocol"""
field_type = None
flags = 0
if value < 0:
if value >= -128:
format_ = '<b'
field_type = FieldType.TINY
elif value >= -32768:
format_ = '<h'
field_type = FieldType.SHORT
elif value >= -2147483648:
format_ = '<i'
field_type = FieldType.LONG
else:
format_ = '<q'
field_type = FieldType.LONGLONG
else:
flags = 128
if value <= 255:
format_ = '<B'
field_type = FieldType.TINY
elif value <= 65535:
format_ = '<H'
field_type = FieldType.SHORT
elif value <= 4294967295:
format_ = '<I'
field_type = FieldType.LONG
else:
field_type = FieldType.LONGLONG
format_ = '<Q'
return (struct.pack(format_, value), field_type, flags)
def _prepare_binary_timestamp(self, value):
"""Prepare a timestamp object for the MySQL binary protocol
This method prepares a timestamp of type datetime.datetime or
datetime.date for sending over the MySQL binary protocol.
A tuple is returned with the prepared value and field type
as elements.
Raises ValueError when the argument value is of invalid type.
Returns a tuple.
"""
if isinstance(value, datetime.datetime):
field_type = FieldType.DATETIME
elif isinstance(value, datetime.date):
field_type = FieldType.DATE
else:
raise ValueError(
"Argument must a datetime.datetime or datetime.date")
packed = (utils.int2store(value.year) +
utils.int1store(value.month) +
utils.int1store(value.day))
if isinstance(value, datetime.datetime):
packed = (packed + utils.int1store(value.hour) +
utils.int1store(value.minute) +
utils.int1store(value.second))
if value.microsecond > 0:
packed += utils.int4store(value.microsecond)
packed = utils.int1store(len(packed)) + packed
return (packed, field_type)
def _prepare_binary_time(self, value):
"""Prepare a time object for the MySQL binary protocol
This method prepares a time object of type datetime.timedelta or
datetime.time for sending over the MySQL binary protocol.
A tuple is returned with the prepared value and field type
as elements.
Raises ValueError when the argument value is of invalid type.
Returns a tuple.
"""
if not isinstance(value, (datetime.timedelta, datetime.time)):
raise ValueError(
"Argument must a datetime.timedelta or datetime.time")
field_type = FieldType.TIME
negative = 0
mcs = None
packed = b''
if isinstance(value, datetime.timedelta):
if value.days < 0:
negative = 1
(hours, remainder) = divmod(value.seconds, 3600)
(mins, secs) = divmod(remainder, 60)
packed += (utils.int4store(abs(value.days)) +
utils.int1store(hours) +
utils.int1store(mins) +
utils.int1store(secs))
mcs = value.microseconds
else:
packed += (utils.int4store(0) +
utils.int1store(value.hour) +
utils.int1store(value.minute) +
utils.int1store(value.second))
mcs = value.microsecond
if mcs:
packed += utils.int4store(mcs)
packed = utils.int1store(negative) + packed
packed = utils.int1store(len(packed)) + packed
return (packed, field_type)
def _prepare_stmt_send_long_data(self, statement, param, data):
"""Prepare long data for prepared statements
Returns a string.
"""
packet = (
utils.int4store(statement) +
utils.int2store(param) +
data)
return packet
def make_stmt_execute(self, statement_id, data=(), parameters=(),
flags=0, long_data_used=None, charset='utf8',
query_attrs=None, converter_str_fallback=False):
"""Make a MySQL packet with the Statement Execute command"""
iteration_count = 1
null_bitmap = [0] * ((len(data) + 7) // 8)
values = []
types = []
packed = b''
data_len = len(data)
query_attr_names = []
flags = flags if not query_attrs else flags + PARAMETER_COUNT_AVAILABLE
if charset == 'utf8mb4':
charset = 'utf8'
if long_data_used is None:
long_data_used = {}
if query_attrs:
data = list(data)
for _, attr_val in query_attrs:
data.append(attr_val)
null_bitmap = [0] * ((len(data) + 7) // 8)
if parameters or data:
if data_len != len(parameters):
raise errors.InterfaceError(
"Failed executing prepared statement: data values does not"
" match number of parameters")
for pos, _ in enumerate(data):
value = data[pos]
_flags = 0
if value is None:
null_bitmap[(pos // 8)] |= 1 << (pos % 8)
types.append(utils.int1store(FieldType.NULL) +
utils.int1store(_flags))
continue
elif pos in long_data_used:
if long_data_used[pos][0]:
# We suppose binary data
field_type = FieldType.BLOB
else:
# We suppose text data
field_type = FieldType.STRING
elif isinstance(value, int):
(packed, field_type,
_flags) = self._prepare_binary_integer(value)
values.append(packed)
elif isinstance(value, str):
value = value.encode(charset)
values.append(utils.lc_int(len(value)) + value)
field_type = FieldType.VARCHAR
elif isinstance(value, bytes):
values.append(utils.lc_int(len(value)) + value)
field_type = FieldType.BLOB
elif isinstance(value, Decimal):
values.append(
utils.lc_int(len(str(value).encode(
charset))) + str(value).encode(charset))
field_type = FieldType.DECIMAL
elif isinstance(value, float):
values.append(struct.pack('<d', value))
field_type = FieldType.DOUBLE
elif isinstance(value, (datetime.datetime, datetime.date)):
(packed, field_type) = self._prepare_binary_timestamp(
value)
values.append(packed)
elif isinstance(value, (datetime.timedelta, datetime.time)):
(packed, field_type) = self._prepare_binary_time(value)
values.append(packed)
elif converter_str_fallback:
value = str(value).encode(charset)
values.append(utils.lc_int(len(value)) + value)
field_type = FieldType.STRING
else:
raise errors.ProgrammingError(
"MySQL binary protocol can not handle "
"'{classname}' objects".format(
classname=value.__class__.__name__))
types.append(utils.int1store(field_type) +
utils.int1store(_flags))
if query_attrs and pos+1 > data_len:
name = query_attrs[pos - data_len][0].encode(charset)
query_attr_names.append(
utils.lc_int(len(name)) + name)
packet = (
utils.int4store(statement_id) +
utils.int1store(flags) +
utils.int4store(iteration_count))
# if (num_params > 0 || (CLIENT_QUERY_ATTRIBUTES \
# && (flags & PARAMETER_COUNT_AVAILABLE)) {
if query_attrs is not None:
parameter_count = data_len + len(query_attrs)
else:
parameter_count = data_len
if parameter_count:
# if CLIENT_QUERY_ATTRIBUTES is on
if query_attrs is not None:
packet += utils.lc_int(parameter_count)
packet += (
b''.join([struct.pack('B', bit) for bit in null_bitmap]) +
utils.int1store(1))
count = 0
for a_type in types:
packet += a_type
# if CLIENT_QUERY_ATTRIBUTES is on {
# string<lenenc> parameter_name Name of the parameter
# or empty if not present
# } if CLIENT_QUERY_ATTRIBUTES is on
if query_attrs is not None:
if count+1 > data_len:
packet += query_attr_names[count - data_len]
else:
packet += b'\x00'
count+=1
for a_value in values:
packet += a_value
return packet
def parse_auth_switch_request(self, packet):
"""Parse a MySQL AuthSwitchRequest-packet"""
if not packet[4] == 254:
raise errors.InterfaceError(
"Failed parsing AuthSwitchRequest packet")
(packet, plugin_name) = utils.read_string(packet[5:], end=b'\x00')
if packet and packet[-1] == 0:
packet = packet[:-1]
return plugin_name.decode('utf8'), packet
def parse_auth_more_data(self, packet):
"""Parse a MySQL AuthMoreData-packet"""
if not packet[4] == 1:
raise errors.InterfaceError(
"Failed parsing AuthMoreData packet")
return packet[5:]

View File

@@ -0,0 +1,634 @@
# Copyright (c) 2009, 2021, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Utilities
"""
import os
import subprocess
from stringprep import (in_table_a1, in_table_b1, in_table_c11, in_table_c12,
in_table_c21_c22, in_table_c3, in_table_c4, in_table_c5,
in_table_c6, in_table_c7, in_table_c8, in_table_c9,
in_table_c12, in_table_d1, in_table_d2)
import platform
import struct
import sys
import unicodedata
from decimal import Decimal
from functools import lru_cache
from .custom_types import HexLiteral
__MYSQL_DEBUG__ = False
NUMERIC_TYPES = (int, float, Decimal, HexLiteral)
def intread(buf):
"""Unpacks the given buffer to an integer"""
try:
if isinstance(buf, int):
return buf
length = len(buf)
if length == 1:
return buf[0]
elif length <= 4:
tmp = buf + b'\x00'*(4-length)
return struct.unpack('<I', tmp)[0]
tmp = buf + b'\x00'*(8-length)
return struct.unpack('<Q', tmp)[0]
except:
raise
def int1store(i):
"""
Takes an unsigned byte (1 byte) and packs it as a bytes-object.
Returns string.
"""
if i < 0 or i > 255:
raise ValueError('int1store requires 0 <= i <= 255')
else:
return bytearray(struct.pack('<B', i))
def int2store(i):
"""
Takes an unsigned short (2 bytes) and packs it as a bytes-object.
Returns string.
"""
if i < 0 or i > 65535:
raise ValueError('int2store requires 0 <= i <= 65535')
else:
return bytearray(struct.pack('<H', i))
def int3store(i):
"""
Takes an unsigned integer (3 bytes) and packs it as a bytes-object.
Returns string.
"""
if i < 0 or i > 16777215:
raise ValueError('int3store requires 0 <= i <= 16777215')
else:
return bytearray(struct.pack('<I', i)[0:3])
def int4store(i):
"""
Takes an unsigned integer (4 bytes) and packs it as a bytes-object.
Returns string.
"""
if i < 0 or i > 4294967295:
raise ValueError('int4store requires 0 <= i <= 4294967295')
else:
return bytearray(struct.pack('<I', i))
def int8store(i):
"""
Takes an unsigned integer (8 bytes) and packs it as string.
Returns string.
"""
if i < 0 or i > 18446744073709551616:
raise ValueError('int8store requires 0 <= i <= 2^64')
else:
return bytearray(struct.pack('<Q', i))
def intstore(i):
"""
Takes an unsigned integers and packs it as a bytes-object.
This function uses int1store, int2store, int3store,
int4store or int8store depending on the integer value.
returns string.
"""
if i < 0 or i > 18446744073709551616:
raise ValueError('intstore requires 0 <= i <= 2^64')
if i <= 255:
formed_string = int1store
elif i <= 65535:
formed_string = int2store
elif i <= 16777215:
formed_string = int3store
elif i <= 4294967295:
formed_string = int4store
else:
formed_string = int8store
return formed_string(i)
def lc_int(i):
"""
Takes an unsigned integer and packs it as bytes,
with the information of how much bytes the encoded int takes.
"""
if i < 0 or i > 18446744073709551616:
raise ValueError('Requires 0 <= i <= 2^64')
if i < 251:
return bytearray(struct.pack('<B', i))
elif i <= 65535:
return b'\xfc' + bytearray(struct.pack('<H', i))
elif i <= 16777215:
return b'\xfd' + bytearray(struct.pack('<I', i)[0:3])
return b'\xfe' + bytearray(struct.pack('<Q', i))
def read_bytes(buf, size):
"""
Reads bytes from a buffer.
Returns a tuple with buffer less the read bytes, and the bytes.
"""
res = buf[0:size]
return (buf[size:], res)
def read_lc_string(buf):
"""
Takes a buffer and reads a length coded string from the start.
This is how Length coded strings work
If the string is 250 bytes long or smaller, then it looks like this:
<-- 1b -->
+----------+-------------------------
| length | a string goes here
+----------+-------------------------
If the string is bigger than 250, then it looks like this:
<- 1b -><- 2/3/8 ->
+------+-----------+-------------------------
| type | length | a string goes here
+------+-----------+-------------------------
if type == \xfc:
length is code in next 2 bytes
elif type == \xfd:
length is code in next 3 bytes
elif type == \xfe:
length is code in next 8 bytes
NULL has a special value. If the buffer starts with \xfb then
it's a NULL and we return None as value.
Returns a tuple (trucated buffer, bytes).
"""
if buf[0] == 251: # \xfb
# NULL value
return (buf[1:], None)
length = lsize = 0
fst = buf[0]
if fst <= 250: # \xFA
length = fst
return (buf[1 + length:], buf[1:length + 1])
elif fst == 252:
lsize = 2
elif fst == 253:
lsize = 3
if fst == 254:
lsize = 8
length = intread(buf[1:lsize + 1])
return (buf[lsize + length + 1:], buf[lsize + 1:length + lsize + 1])
def read_lc_string_list(buf):
"""Reads all length encoded strings from the given buffer
Returns a list of bytes
"""
byteslst = []
sizes = {252: 2, 253: 3, 254: 8}
buf_len = len(buf)
pos = 0
while pos < buf_len:
first = buf[pos]
if first == 255:
# Special case when MySQL error 1317 is returned by MySQL.
# We simply return None.
return None
if first == 251:
# NULL value
byteslst.append(None)
pos += 1
else:
if first <= 250:
length = first
byteslst.append(buf[(pos + 1):length + (pos + 1)])
pos += 1 + length
else:
lsize = 0
try:
lsize = sizes[first]
except KeyError:
return None
length = intread(buf[(pos + 1):lsize + (pos + 1)])
byteslst.append(
buf[pos + 1 + lsize:length + lsize + (pos + 1)])
pos += 1 + lsize + length
return tuple(byteslst)
def read_string(buf, end=None, size=None):
"""
Reads a string up until a character or for a given size.
Returns a tuple (trucated buffer, string).
"""
if end is None and size is None:
raise ValueError('read_string() needs either end or size')
if end is not None:
try:
idx = buf.index(end)
except ValueError:
raise ValueError("end byte not present in buffer")
return (buf[idx + 1:], buf[0:idx])
elif size is not None:
return read_bytes(buf, size)
raise ValueError('read_string() needs either end or size (weird)')
def read_int(buf, size):
"""Read an integer from buffer
Returns a tuple (truncated buffer, int)
"""
try:
res = intread(buf[0:size])
except:
raise
return (buf[size:], res)
def read_lc_int(buf):
"""
Takes a buffer and reads an length code string from the start.
Returns a tuple with buffer less the integer and the integer read.
"""
if not buf:
raise ValueError("Empty buffer.")
lcbyte = buf[0]
if lcbyte == 251:
return (buf[1:], None)
elif lcbyte < 251:
return (buf[1:], int(lcbyte))
elif lcbyte == 252:
return (buf[3:], struct.unpack('<xH', buf[0:3])[0])
elif lcbyte == 253:
return (buf[4:], struct.unpack('<I', buf[1:4] + b'\x00')[0])
elif lcbyte == 254:
return (buf[9:], struct.unpack('<xQ', buf[0:9])[0])
else:
raise ValueError("Failed reading length encoded integer")
#
# For debugging
#
def _digest_buffer(buf):
"""Debug function for showing buffers"""
if not isinstance(buf, str):
return ''.join(["\\x%02x" % c for c in buf])
return ''.join(["\\x%02x" % ord(c) for c in buf])
def print_buffer(abuffer, prefix=None, limit=30):
"""Debug function printing output of _digest_buffer()"""
if prefix:
if limit and limit > 0:
digest = _digest_buffer(abuffer[0:limit])
else:
digest = _digest_buffer(abuffer)
print(prefix + ': ' + digest)
else:
print(_digest_buffer(abuffer))
def _parse_os_release():
"""Parse the contents of /etc/os-release file.
Returns:
A dictionary containing release information.
"""
distro = {}
os_release_file = os.path.join("/etc", "os-release")
if not os.path.exists(os_release_file):
return distro
with open(os_release_file) as file_obj:
for line in file_obj:
key_value = line.split("=")
if len(key_value) != 2:
continue
key = key_value[0].lower()
value = key_value[1].rstrip("\n").strip('"')
distro[key] = value
return distro
def _parse_lsb_release():
"""Parse the contents of /etc/lsb-release file.
Returns:
A dictionary containing release information.
"""
distro = {}
lsb_release_file = os.path.join("/etc", "lsb-release")
if os.path.exists(lsb_release_file):
with open(lsb_release_file) as file_obj:
for line in file_obj:
key_value = line.split("=")
if len(key_value) != 2:
continue
key = key_value[0].lower()
value = key_value[1].rstrip("\n").strip('"')
distro[key] = value
return distro
def _parse_lsb_release_command():
"""Parse the output of the lsb_release command.
Returns:
A dictionary containing release information.
"""
distro = {}
with open(os.devnull, "w") as devnull:
try:
stdout = subprocess.check_output(
("lsb_release", "-a"), stderr=devnull)
except OSError:
return None
lines = stdout.decode(sys.getfilesystemencoding()).splitlines()
for line in lines:
key_value = line.split(":")
if len(key_value) != 2:
continue
key = key_value[0].replace(" ", "_").lower()
value = key_value[1].strip("\t")
distro[key] = value
return distro
def linux_distribution():
"""Tries to determine the name of the Linux OS distribution name.
First tries to get information from ``/etc/os-release`` file.
If fails, tries to get the information of ``/etc/lsb-release`` file.
And finally the information of ``lsb-release`` command.
Returns:
A tuple with (`name`, `version`, `codename`)
"""
distro = _parse_lsb_release()
if distro:
return (distro.get("distrib_id", ""),
distro.get("distrib_release", ""),
distro.get("distrib_codename", ""))
distro = _parse_lsb_release_command()
if distro:
return (distro.get("distributor_id", ""),
distro.get("release", ""),
distro.get("codename", ""))
distro = _parse_os_release()
if distro:
return (distro.get("name", ""),
distro.get("version_id", ""),
distro.get("version_codename", ""))
return ("", "", "")
def _get_unicode_read_direction(unicode_str):
"""Get the readiness direction of the unicode string.
We assume that the direction is "L-to-R" if the first character does not
indicate the direction is "R-to-L" or an "AL" (Arabic Letter).
"""
if unicode_str and unicodedata.bidirectional(unicode_str[0]) in ("R", "AL"):
return "R-to-L"
return "L-to-R"
def _get_unicode_direction_rule(unicode_str):
"""
1) The characters in section 5.8 MUST be prohibited.
2) If a string contains any RandALCat character, the string MUST NOT
contain any LCat character.
3) If a string contains any RandALCat character, a RandALCat
character MUST be the first character of the string, and a
RandALCat character MUST be the last character of the string.
"""
read_dir = _get_unicode_read_direction(unicode_str)
# point 3)
if read_dir == "R-to-L":
if not (in_table_d1(unicode_str[0]) and in_table_d1(unicode_str[-1])):
raise ValueError("Invalid unicode Bidirectional sequence, if the "
"first character is RandALCat, the final character"
"must be RandALCat too.")
# characters from in_table_d2 are prohibited.
return {"Bidirectional Characters requirement 2 [StringPrep, d2]":
in_table_d2}
# characters from in_table_d1 are prohibited.
return {"Bidirectional Characters requirement 2 [StringPrep, d2]":
in_table_d1}
def validate_normalized_unicode_string(normalized_str):
"""Check for Prohibited Output according to rfc4013 profile.
This profile specifies the following characters as prohibited input:
- Non-ASCII space characters [StringPrep, C.1.2]
- ASCII control characters [StringPrep, C.2.1]
- Non-ASCII control characters [StringPrep, C.2.2]
- Private Use characters [StringPrep, C.3]
- Non-character code points [StringPrep, C.4]
- Surrogate code points [StringPrep, C.5]
- Inappropriate for plain text characters [StringPrep, C.6]
- Inappropriate for canonical representation characters [StringPrep, C.7]
- Change display properties or deprecated characters [StringPrep, C.8]
- Tagging characters [StringPrep, C.9]
In addition of checking of Bidirectional Characters [StringPrep, Section 6]
and the Unassigned Code Points [StringPrep, A.1].
Returns:
A tuple with ("probited character", "breaked_rule")
"""
rules = {
"Space characters that contains the ASCII code points": in_table_c11,
"Space characters non-ASCII code points": in_table_c12,
"Unassigned Code Points [StringPrep, A.1]": in_table_a1,
"Non-ASCII space characters [StringPrep, C.1.2]": in_table_c12,
"ASCII control characters [StringPrep, C.2.1]": in_table_c21_c22,
"Private Use characters [StringPrep, C.3]": in_table_c3,
"Non-character code points [StringPrep, C.4]": in_table_c4,
"Surrogate code points [StringPrep, C.5]": in_table_c5,
"Inappropriate for plain text characters [StringPrep, C.6]": in_table_c6,
"Inappropriate for canonical representation characters [StringPrep, C.7]": in_table_c7,
"Change display properties or deprecated characters [StringPrep, C.8]": in_table_c8,
"Tagging characters [StringPrep, C.9]": in_table_c9
}
try:
rules.update(_get_unicode_direction_rule(normalized_str))
except ValueError as err:
return normalized_str, str(err)
for char in normalized_str:
for rule in rules:
if rules[rule](char) and char != u' ':
return char, rule
return None
def normalize_unicode_string(a_string):
"""normalizes a unicode string according to rfc4013
Normalization of a unicode string according to rfc4013: The SASLprep profile
of the "stringprep" algorithm.
Normalization Unicode equivalence is the specification by the Unicode
character encoding standard that some sequences of code points represent
essentially the same character.
This method normalizes using the Normalization Form Compatibility
Composition (NFKC), as described in rfc4013 2.2.
Returns:
Normalized unicode string according to rfc4013.
"""
# Per rfc4013 2.1. Mapping
# non-ASCII space characters [StringPrep, C.1.2] are mapped to ' ' (U+0020)
# "commonly mapped to nothing" characters [StringPrep, B.1] are mapped to ''
nstr_list = [
u' ' if in_table_c12(char) else u'' if in_table_b1(char) else char
for char in a_string]
nstr = u''.join(nstr_list)
# Per rfc4013 2.2. Use NFKC Normalization Form Compatibility Composition
# Characters are decomposed by compatibility, then recomposed by canonical
# equivalence.
nstr = unicodedata.normalize('NFKC', nstr)
if not nstr:
# Normilization results in empty string.
return u''
return nstr
def make_abc(base_class):
"""Decorator used to create a abstract base class.
We use this decorator to create abstract base classes instead of
using the abc-module. The decorator makes it possible to do the
same in both Python v2 and v3 code.
"""
def wrapper(class_):
"""Wrapper"""
attrs = class_.__dict__.copy()
for attr in '__dict__', '__weakref__':
attrs.pop(attr, None) # ignore missing attributes
bases = class_.__bases__
bases = (class_,) + bases
return base_class(class_.__name__, bases, attrs)
return wrapper
def init_bytearray(payload=b'', encoding='utf-8'):
"""Initialize a bytearray from the payload."""
if isinstance(payload, bytearray):
return payload
if isinstance(payload, int):
return bytearray(payload)
if not isinstance(payload, bytes):
try:
return bytearray(payload.encode(encoding=encoding))
except AttributeError:
raise ValueError("payload must be a str or bytes")
return bytearray(payload)
@lru_cache()
def get_platform():
"""Return a dict with the platform arch and OS version."""
plat = {"arch": None, "version": None}
if os.name == "nt":
if "64" in platform.architecture()[0]:
plat["arch"] = "x86_64"
elif "32" in platform.architecture()[0]:
plat["arch"] = "i386"
else:
plat["arch"] = platform.architecture()
plat["version"] = "Windows-{}".format(platform.win32_ver()[1])
else:
plat["arch"] = platform.machine()
if platform.system() == "Darwin":
plat["version"] = "{}-{}".format("macOS", platform.mac_ver()[0])
else:
plat["version"] = "-".join(linux_distribution()[0:2])
return plat

View File

@@ -0,0 +1,44 @@
# Copyright (c) 2012, 2021, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""MySQL Connector/Python version information
The file version.py gets installed and is available after installation
as mysql.connector.version.
"""
VERSION = (8, 0, 28, '', 1)
if VERSION[3] and VERSION[4]:
VERSION_TEXT = '{0}.{1}.{2}{3}{4}'.format(*VERSION)
else:
VERSION_TEXT = '{0}.{1}.{2}'.format(*VERSION[0:3])
VERSION_EXTRA = ''
LICENSE = 'GPLv2 with FOSS License Exception'
EDITION = '' # Added in package names, after the version

Some files were not shown because too many files have changed in this diff Show More