Bring MySQL and Postgresql up to date

This commit is contained in:
Doug Blank 2016-04-23 11:46:22 -04:00
parent 230c9d6bd0
commit cfd686ff59
4 changed files with 75 additions and 24 deletions

View File

@ -170,7 +170,8 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase):
Return all seconday fields and their types
"""
from .handle import HandleClass
return ([(key, value) for (key, value) in cls.get_schema().items()
return ([(key.lower(), value)
for (key, value) in cls.get_schema().items()
if value in [str, int, float, bool] or
isinstance(value, HandleClass)] +
cls.get_extra_secondary_fields())

View File

@ -1080,7 +1080,7 @@ class DBAPI(DbGeneric):
"""
# first build sort order:
sorted_items = []
query = "SELECT blob_data FROM %s;" % class_.__name__
query = "SELECT blob_data FROM %s;" % class_.__name__.lower()
self.dbapi.execute(query)
rows = self.dbapi.fetchall()
for row in rows:
@ -1108,11 +1108,11 @@ class DBAPI(DbGeneric):
return
## Continue with dbapi select
if order_by is None:
query = "SELECT blob_data FROM %s;" % class_.__name__
query = "SELECT blob_data FROM %s;" % class_.__name__.lower()
else:
order_phrases = ["%s %s" % (self._hash_name(class_.__name__, class_.get_field_alias(field)), direction)
for (field, direction) in order_by]
query = "SELECT blob_data FROM %s ORDER BY %s;" % (class_.__name__, ", ".join(order_phrases))
query = "SELECT blob_data FROM %s ORDER BY %s;" % (class_.__name__.lower(), ", ".join(order_phrases))
self.dbapi.execute(query)
rows = self.dbapi.fetchall()
for row in rows:
@ -1613,11 +1613,12 @@ class DBAPI(DbGeneric):
if not hasattr(self.get_table_func(table,"class_func"), "get_secondary_fields"):
continue
# do a select on all; if it works, then it is ok; else, check them all
table_name = table.lower()
try:
fields = [self._hash_name(table, field) for (field, ptype) in
self.get_table_func(table,"class_func").get_secondary_fields()]
if fields:
self.dbapi.execute("select %s from %s limit 1;" % (", ".join(fields), table))
self.dbapi.execute("select %s from %s limit 1;" % (", ".join(fields), table_name))
# if no error, continue
LOG.info("Table %s is up to date" % table)
continue
@ -1631,12 +1632,12 @@ class DBAPI(DbGeneric):
sql_type = self._sql_type(python_type)
try:
# test to see if it exists:
self.dbapi.execute("SELECT %s FROM %s LIMIT 1;" % (field, table))
self.dbapi.execute("SELECT %s FROM %s LIMIT 1;" % (field, table_name))
LOG.info(" Table %s, field %s is up to date" % (table, field))
except:
# if not, let's add it
LOG.info(" Table %s, field %s was added" % (table, field))
self.dbapi.execute("ALTER TABLE %s ADD COLUMN %s %s;" % (table, field, sql_type))
self.dbapi.execute("ALTER TABLE %s ADD COLUMN %s %s;" % (table_name, field, sql_type))
altered = True
if altered:
LOG.info("Table %s is being committed, rebuilt, and indexed..." % table)
@ -1656,10 +1657,11 @@ class DBAPI(DbGeneric):
"""
Create secondary indexes for just this table.
"""
table_name = table.lower()
for fields in self.get_table_func(table,"class_func").get_index_fields():
for field in fields:
field = self._hash_name(table, field)
self.dbapi.try_execute("CREATE INDEX %s_%s ON %s(%s);" % (table, field, table, field))
self.dbapi.try_execute("CREATE INDEX %s_%s ON %s(%s);" % (table, field, table_name, field))
def update_secondary_values_all(self):
"""
@ -1696,8 +1698,16 @@ class DBAPI(DbGeneric):
sets.append("%s = ?" % field)
values.append(value)
if len(values) > 0:
self.dbapi.execute("UPDATE %s SET %s where handle = ?;" % (table, ", ".join(sets)),
values + [item.handle])
table_name = table.lower()
self.dbapi.execute("UPDATE %s SET %s where handle = ?;" % (table_name, ", ".join(sets)),
self._sql_cast_list(table, sets, values) + [item.handle])
def _sql_cast_list(self, table, fields, values):
"""
Given a list of field names and values, return the values
in the appropriate type.
"""
return [v if type(v) is not bool else int(v) for v in values]
def _sql_repr(self, value):
"""
@ -1823,6 +1833,7 @@ class DBAPI(DbGeneric):
["handle"]) # handle is a sql field, but not listed in secondaries
# If no fields, then we need objects:
# Check to see if where matches SQL fields:
table_name = table.lower()
if ((not self._check_where_fields(table, where, secondary_fields)) or
(not self._check_order_by_fields(table, order_by, secondary_fields))):
# If not, then need to do select via Python:
@ -1849,11 +1860,11 @@ class DBAPI(DbGeneric):
select_fields = ["1"]
if start:
query = "SELECT %s FROM %s %s %s LIMIT %s, %s" % (
", ".join(select_fields), table, where_clause, order_clause, start, limit
", ".join(select_fields), table_name, where_clause, order_clause, start, limit
)
else:
query = "SELECT %s FROM %s %s %s LIMIT %s" % (
", ".join(select_fields), table, where_clause, order_clause, limit
", ".join(select_fields), table_name, where_clause, order_clause, limit
)
if get_count_only:
self.dbapi.execute("SELECT count(1) from (%s);" % query)

View File

@ -20,11 +20,27 @@ class MySQL(object):
def __init__(self, *args, **kwargs):
self.connection = MySQLdb.connect(*args, **kwargs)
self.connection.autocommit(True)
self.cursor = self.connection.cursor()
def execute(self, query, args=[]):
## Workaround: no qmark support
def _hack_query(self, query):
## Workaround: no qmark support:
query = query.replace("?", "%s")
query = query.replace("INTEGER", "INT")
query = query.replace("REAL", "DOUBLE")
query = query.replace("change", "change_")
query = query.replace("desc", "desc_")
## LIMIT offset, count
## count can be -1, for all
## LIMIT -1
## LIMIT offset, -1
query = query.replace("LIMIT -1",
"LIMIT 18446744073709551615") ## largest maxint
#query = query.replace("LIMIT -1", "")
return query
def execute(self, query, args=[]):
query = self._hack_query(query)
self.cursor.execute(query, args)
def fetchone(self):
@ -34,12 +50,16 @@ class MySQL(object):
return self.cursor.fetchall()
def commit(self):
self.connection.commit()
self.cursor.execute("COMMIT;");
def begin(self):
self.cursor.execute("BEGIN;");
def rollback(self):
self.connection.rollback()
def try_execute(self, sql):
query = self._hack_query(sql)
try:
self.cursor.execute(sql)
except Exception as exc:

View File

@ -1,6 +1,6 @@
import pg8000
import psycopg2
pg8000.paramstyle = 'qmark'
psycopg2.paramstyle = 'format'
class Postgresql(object):
@classmethod
@ -12,18 +12,31 @@ class Postgresql(object):
summary = {
"DB-API version": "2.0",
"Database SQL type": cls.__name__,
"Database SQL module": "pg8000",
"Database SQL module version": pg8000.__version__,
"Database SQL module location": pg8000.__file__,
"Database SQL module": "psycopg2",
"Database SQL module version": psycopg2.__version__,
"Database SQL module location": psycopg2.__file__,
}
return summary
def __init__(self, *args, **kwargs):
self.connection = pg8000.connect(*args, **kwargs)
self.connection = psycopg2.connect(*args, **kwargs)
self.connection.autocommit = True
self.cursor = self.connection.cursor()
def execute(self, *args, **kwargs):
self.cursor.execute(*args, **kwargs)
sql = args[0]
sql = sql.replace("?", "%s")
sql = sql.replace("REGEXP", "~")
sql = sql.replace("desc", "desc_")
if len(args) > 1:
args = args[1]
else:
args = None
try:
self.cursor.execute(sql, args, **kwargs)
except:
self.cursor.execute("rollback")
raise
def fetchone(self):
return self.cursor.fetchone()
@ -31,18 +44,24 @@ class Postgresql(object):
def fetchall(self):
return self.cursor.fetchall()
def begin(self):
self.cursor.execute("BEGIN;")
def commit(self):
self.connection.commit()
self.cursor.execute("COMMIT;")
def rollback(self):
self.connection.rollback()
def try_execute(self, sql):
sql = sql.replace("?", "%s")
sql = sql.replace("BLOB", "bytea")
sql = sql.replace("desc", "desc_")
try:
self.cursor.execute(sql)
except Exception as exc:
pass
self.cursor.execute("rollback")
#print("ERROR:", sql)
#print(str(exc))
def close(self):