Refinements in db.select; allow joins

This commit is contained in:
Doug Blank 2016-01-10 10:55:45 -05:00
parent 177e30ee62
commit 22fa6ed63a
4 changed files with 178 additions and 74 deletions

View File

@ -31,6 +31,7 @@ from this class.
# #
#------------------------------------------------------------------------- #-------------------------------------------------------------------------
import re import re
import time
#------------------------------------------------------------------------- #-------------------------------------------------------------------------
# #
@ -1887,23 +1888,67 @@ class DbWriteBase(DbReadBase):
sort - use sort order (argument to DB.get_X_handles) sort - use sort order (argument to DB.get_X_handles)
start - position to start start - position to start
limit - count to get; -1 for all limit - count to get; -1 for all
filter - {field: (SQL string_operator, value), } filter - [["AND", [(field, SQL string_operator, value),
(field, SQL string_operator, value), ...],
["OR", [(field, SQL string_operator, value),
(field, SQL string_operator, value), ...]]
handles all SQL except for NOT expression, eg NOT x = y handles all SQL except for NOT expression, eg NOT x = y
""" """
class Result(list): class Result(list):
""" """
A list rows of just matching for this page, with total = all. A list rows of just matching for this page, with total = all,
and time = time to select.
""" """
total = 0 total = 0
def hash_name(name): def hash_name(table, name):
""" """
Used in filter to eval expressions involving selected Used in filter to eval expressions involving selected
data. data.
""" """
name = self._tables[table]["class_func"].get_field_alias(name)
return (name return (name
.replace(".", "_D_") .replace(".", "_D_")
.replace("(", "_P_") .replace("(", "_P_")
.replace(")", "_P_")) .replace(")", "_P_"))
def compare(v, op, value):
    """
    Compare values in a SQL-like way.

    v - the value taken from the item being tested
    op - operator string; one of "=", ">", ">=", "<", "<=", "IN",
         "NI", "IS", "IS NOT", "IS NULL", "IS NOT NULL", "BETWEEN",
         "<>", "!=" or "LIKE"
    value - the operand from the filter expression: a container for
            "IN", a (low, high) pair for "BETWEEN", a pattern string
            for "LIKE"; not used by "IS NULL" / "IS NOT NULL"

    Returns True when the comparison holds, otherwise False.
    Raises Exception when op is not one of the operators above.
    """
    if op == "=":
        matched = v == value
    elif op == ">":
        matched = v > value
    elif op == ">=":
        matched = v >= value
    elif op == "<":
        matched = v < value
    elif op == "<=":
        matched = v <= value
    elif op == "IN":
        matched = v in value
    elif op == "NI":
        # "IN" reversed: the filter operand must be contained in v
        matched = value in v
    elif op == "IS":
        matched = v is value
    elif op == "IS NOT":
        matched = v is not value
    elif op == "IS NULL":
        matched = v is None
    elif op == "IS NOT NULL":
        matched = v is not None
    elif op == "BETWEEN":
        # inclusive at both ends
        matched = value[0] <= v <= value[1]
    elif op in ["<>", "!="]:
        matched = v != value
    elif op == "LIKE":
        if value and v:
            # Only "%" (any run of characters) and "_" (exactly one
            # character) are wildcards. Everything else is escaped so
            # that regex metacharacters in the pattern, such as "."
            # or "(", are matched literally instead of being
            # interpreted (or raising re.error).
            pattern = "".join(
                "(.*)" if char == "%" else
                "." if char == "_" else
                re.escape(char)
                for char in value)
            # \Z anchors at the true end of the string ("$" would also
            # accept a trailing newline); DOTALL lets the wildcards
            # span embedded newlines.
            matched = re.match(pattern + r"\Z", v, re.DOTALL)
        else:
            # an empty/None pattern or value never matches
            matched = False
    else:
        raise Exception("invalid select operator: '%s'" % op)
    return bool(matched)
# Fields is None or list, maybe containing "*": # Fields is None or list, maybe containing "*":
if fields is None: if fields is None:
fields = ["*"] fields = ["*"]
@ -1920,55 +1965,43 @@ class DbWriteBase(DbReadBase):
position = 0 position = 0
selected = 0 selected = 0
result = Result() result = Result()
start_time = time.time()
if filter: if filter:
for handle in data: for handle in data:
# have to evaluate all, because there is a filter # have to evaluate all, because there is a filter
item = self._tables[table]["handle_func"](handle) item = self._tables[table]["handle_func"](handle)
row = {} row = {}
env = {} env = {}
for field in filter.keys(): for (connector, exprs) in filter:
# just the ones we need for filter for (name, op, value) in exprs:
value = item.get_field(field) # just the ones we need for filter
env[hash_name(field)] = value value = item.get_field(name, self, ignore_errors=True)
matched = True env[hash_name(table, name)] = value
for name, (op, value) in filter.items(): for (connector, exprs) in filter:
v = eval(hash_name(name), env) if connector == "AND":
if op == "=": matched = True
matched = v == value # all must match to be true
elif op == ">": for (name, op, value) in exprs:
matched = v > value v = eval(hash_name(table, name), env)
elif op == ">=": matched = compare(v, op, value)
matched = v >= value if not matched:
elif op == "<": break
matched = v < value elif connector == "OR":
elif op == "<=": matched = False
matched = v <= value # any must match to be true
elif op == "IN": for (name, op, value) in exprs:
matched = v in value v = eval(hash_name(table, name), env)
elif op == "IS": matched = compare(v, op, value)
matched = v is value if matched:
elif op == "IS NOT": break
matched = v is not value
elif op == "IS NULL":
matched = v is None
elif op == "IS NOT NULL":
matched = v is not None
elif op == "BETWEEN":
matched = value[0] <= v <= value[1]
elif op in ["<>", "!="]:
matched = v != value
elif op == "LIKE":
value = value.replace("%", "(.*)").replace("_", ".")
matched = re.match(value, v)
else:
raise Exception("invalid select operator: '%s'" % op)
if not matched: if not matched:
break break
# else, keep going
if matched: if matched:
if selected < limit and start <= position: if selected < limit and start <= position:
# now, we get all of the fields # now, we get all of the fields
for field in fields: for field in fields:
value = item.get_field(field) value = item.get_field(field, self, ignore_errors=True)
row[field] = value row[field] = value
selected += 1 selected += 1
result.append(row) result.append(row)
@ -1982,10 +2015,11 @@ class DbWriteBase(DbReadBase):
item = self._tables[table]["handle_func"](handle) item = self._tables[table]["handle_func"](handle)
row = {} row = {}
for field in fields: for field in fields:
value = item.get_field(field) value = item.get_field(field, self, ignore_errors=True)
row[field] = value row[field] = value
result.append(row) result.append(row)
selected += 1 selected += 1
position += 1 position += 1
result.total = self._tables[table]["count_func"]() result.total = self._tables[table]["count_func"]()
result.time = time.time() - start_time
return result return result

View File

@ -22,11 +22,35 @@ class HandleClass(str):
def __init__(self, handle): def __init__(self, handle):
super(HandleClass, self).__init__() super(HandleClass, self).__init__()
def Handle(classname, handle): def join(self, database, handle):
return database._tables[self.classname]["handle_func"](handle)
@classmethod
def get_schema(cls):
    """
    Return the schema of the object class named by cls.classname.

    The name is looked up among the nine classes imported below and
    that class's own get_schema() result is returned. A classname
    outside that set raises KeyError.
    """
    # NOTE(review): imported inside the method, presumably to avoid a
    # circular import with gramps.gen.lib -- confirm.
    from gramps.gen.lib import (Person, Family, Event, Place, Source,
                                MediaObject, Repository, Note, Citation)
    target = {
        "Person": Person,
        "Family": Family,
        "Event": Event,
        "Place": Place,
        "Source": Source,
        "MediaObject": MediaObject,
        "Repository": Repository,
        "Note": Note,
        "Citation": Citation,
    }[cls.classname]
    return target.get_schema()
def Handle(_classname, handle):
if handle is None: if handle is None:
return None return None
h = HandleClass(handle) class MyHandleClass(HandleClass):
h.classname = classname """
Class created to have classname attribute.
"""
classname = _classname
h = MyHandleClass(handle)
return h return h
def __from_struct(struct): def __from_struct(struct):

View File

@ -207,6 +207,18 @@ class Person(CitationBase, NoteBase, AttributeBase, MediaBase,
for pr in self.person_ref_list] # 20 for pr in self.person_ref_list] # 20
} }
@classmethod
def field_aliases(cls):
    """
    Map each short alias accepted for this object class to the full
    dotted field name it stands for.
    """
    aliases = {}
    aliases["given"] = "primary_name.first_name"
    # "surname" carries a list index (0) in its path; "surnames" does not
    aliases["surname"] = "primary_name.surname_list.0.surname"
    aliases["surnames"] = "primary_name.surname_list.surname"
    return aliases
@classmethod @classmethod
def get_labels(cls, _): def get_labels(cls, _):
return { return {

View File

@ -127,6 +127,21 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase):
""" """
return {} return {}
@classmethod
def field_aliases(cls):
    """
    Map alias names to full field names for this object class.

    This implementation defines no aliases and returns a new, empty
    dictionary on every call.
    """
    no_aliases = dict()
    return no_aliases
@classmethod
def get_field_alias(cls, alias):
    """
    Translate an alias into its full field name.

    Looks the name up in cls.field_aliases(); a name that is not a
    registered alias is returned unchanged.
    """
    known = cls.field_aliases()
    if alias in known:
        return known[alias]
    return alias
@classmethod @classmethod
def get_schema(cls): def get_schema(cls):
""" """
@ -155,6 +170,7 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase):
Get the associated label given a field name of this object. Get the associated label given a field name of this object.
No index positions allowed on lists. No index positions allowed on lists.
""" """
field = cls.get_field_alias(field)
chain = field.split(".") chain = field.split(".")
ftype = cls._follow_schema_path(chain) ftype = cls._follow_schema_path(chain)
return ftype return ftype
@ -172,55 +188,73 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase):
elif part in schema.keys(): elif part in schema.keys():
path = schema[part] path = schema[part]
else: else:
raise Exception("No such %s in %s" % (part, schema)) raise Exception("No such '%s' in %s" % (part, list(schema.keys())))
if isinstance(path, (list, tuple)): if isinstance(path, (list, tuple)):
path = path[0] path = path[0]
return path return path
def get_field(self, field): def get_field(self, field, db=None, ignore_errors=False):
""" """
Get the value of a field. Get the value of a field.
""" """
field = self.__class__.get_field_alias(field)
chain = field.split(".") chain = field.split(".")
path = self._follow_field_path(chain) path = self._follow_field_path(chain, db, ignore_errors)
return path return path
def _follow_field_path(self, chain, ignore_errors=False): def _follow_field_path(self, chain, db=None, ignore_errors=False):
""" """
Follow a list of items. Return endpoint. Follow a list of items. Return endpoint.
With the db argument, can do joins across tables.
self - current object
""" """
path = self from .handle import HandleClass
current = self
path_to = []
parent = self
for part in chain: for part in chain:
class_ = None path_to.append(part)
if hasattr(path, part): # attribute if hasattr(current, part): # attribute
path = getattr(path, part) current = getattr(current, part)
elif part.isdigit(): # index into list elif part.isdigit(): # index into list
path = path[int(part)] current = current[int(part)]
elif part.endswith(")"): # callable continue
# parse elif isinstance(current, (list, tuple)):
function, sargs = part.split("(", 1) current = [getattr(attr, part) for attr in current]
sargs = sargs[:-1] # remove right-parent else: # part not found on this self
# eval arguments # current is a handle
args = [] # part is something on joined object
for sarg in sargs.split(","): ptype = parent.__class__.get_field_type(".".join(path_to[:-1]))
if sarg: if isinstance(ptype, HandleClass):
args.append(eval(sarg.strip())) if db:
# call # start over here:
path = getattr(path, function)(*args) try:
elif ignore_errors: parent = ptype.join(db, current)
return current = getattr(parent, part)
else: path_to = []
raise Exception("%s is not a valid field of %s; use %s" % continue
(part, path, dir(path))) except:
return path if ignore_errors:
return
else:
raise
else:
raise Exception("Can't join without database")
if ignore_errors:
return
else:
raise Exception("%s is not a valid field of %s; use %s" %
(part, current, dir(current)))
return current
def set_field(self, field, value, ignore_errors=False): def set_field(self, field, value, db=None, ignore_errors=False):
""" """
Set the value of a basic field (str, int, float, or bool). Set the value of a basic field (str, int, float, or bool).
value can be a string or actual value. value can be a string or actual value.
""" """
field = self.__class__.get_field_alias(field)
chain = field.split(".") chain = field.split(".")
path = self._follow_field_path(chain[:-1], ignore_errors) path = self._follow_field_path(chain[:-1], db, ignore_errors)
ftype = self.get_field_type(field) ftype = self.get_field_type(field)
# ftype is str, bool, float, or int # ftype is str, bool, float, or int
value = (value in ['True', True]) if ftype is bool else value value = (value in ['True', True]) if ftype is bool else value