Refinements in db.select; allow joins
This commit is contained in:
parent
177e30ee62
commit
22fa6ed63a
@ -31,6 +31,7 @@ from this class.
|
||||
#
|
||||
#-------------------------------------------------------------------------
|
||||
import re
|
||||
import time
|
||||
|
||||
#-------------------------------------------------------------------------
|
||||
#
|
||||
@ -1887,23 +1888,67 @@ class DbWriteBase(DbReadBase):
|
||||
sort - use sort order (argument to DB.get_X_handles)
|
||||
start - position to start
|
||||
limit - count to get; -1 for all
|
||||
filter - {field: (SQL string_operator, value), }
|
||||
filter - [["AND", [(field, SQL string_operator, value),
|
||||
(field, SQL string_operator, value), ...],
|
||||
["OR", [(field, SQL string_operator, value),
|
||||
(field, SQL string_operator, value), ...]]
|
||||
handles all SQL except for NOT expression, eg NOT x = y
|
||||
"""
|
||||
class Result(list):
|
||||
"""
|
||||
A list rows of just matching for this page, with total = all.
|
||||
A list rows of just matching for this page, with total = all,
|
||||
and time = time to select.
|
||||
"""
|
||||
total = 0
|
||||
def hash_name(name):
|
||||
def hash_name(table, name):
|
||||
"""
|
||||
Used in filter to eval expressions involving selected
|
||||
data.
|
||||
"""
|
||||
name = self._tables[table]["class_func"].get_field_alias(name)
|
||||
return (name
|
||||
.replace(".", "_D_")
|
||||
.replace("(", "_P_")
|
||||
.replace(")", "_P_"))
|
||||
def compare(v, op, value):
|
||||
"""
|
||||
Compare values in a SQL-like way
|
||||
"""
|
||||
if op == "=":
|
||||
matched = v == value
|
||||
elif op == ">":
|
||||
matched = v > value
|
||||
elif op == ">=":
|
||||
matched = v >= value
|
||||
elif op == "<":
|
||||
matched = v < value
|
||||
elif op == "<=":
|
||||
matched = v <= value
|
||||
elif op == "IN":
|
||||
matched = v in value
|
||||
elif op == "NI":
|
||||
matched = value in v
|
||||
elif op == "IS":
|
||||
matched = v is value
|
||||
elif op == "IS NOT":
|
||||
matched = v is not value
|
||||
elif op == "IS NULL":
|
||||
matched = v is None
|
||||
elif op == "IS NOT NULL":
|
||||
matched = v is not None
|
||||
elif op == "BETWEEN":
|
||||
matched = value[0] <= v <= value[1]
|
||||
elif op in ["<>", "!="]:
|
||||
matched = v != value
|
||||
elif op == "LIKE":
|
||||
if value and v:
|
||||
value = value.replace("%", "(.*)").replace("_", ".")
|
||||
matched = re.match("^" + value + "$", v)
|
||||
else:
|
||||
matched = False
|
||||
else:
|
||||
raise Exception("invalid select operator: '%s'" % op)
|
||||
return True if matched else False
|
||||
# Fields is None or list, maybe containing "*":
|
||||
if fields is None:
|
||||
fields = ["*"]
|
||||
@ -1920,55 +1965,43 @@ class DbWriteBase(DbReadBase):
|
||||
position = 0
|
||||
selected = 0
|
||||
result = Result()
|
||||
start_time = time.time()
|
||||
if filter:
|
||||
for handle in data:
|
||||
# have to evaluate all, because there is a filter
|
||||
item = self._tables[table]["handle_func"](handle)
|
||||
row = {}
|
||||
env = {}
|
||||
for field in filter.keys():
|
||||
# just the ones we need for filter
|
||||
value = item.get_field(field)
|
||||
env[hash_name(field)] = value
|
||||
matched = True
|
||||
for name, (op, value) in filter.items():
|
||||
v = eval(hash_name(name), env)
|
||||
if op == "=":
|
||||
matched = v == value
|
||||
elif op == ">":
|
||||
matched = v > value
|
||||
elif op == ">=":
|
||||
matched = v >= value
|
||||
elif op == "<":
|
||||
matched = v < value
|
||||
elif op == "<=":
|
||||
matched = v <= value
|
||||
elif op == "IN":
|
||||
matched = v in value
|
||||
elif op == "IS":
|
||||
matched = v is value
|
||||
elif op == "IS NOT":
|
||||
matched = v is not value
|
||||
elif op == "IS NULL":
|
||||
matched = v is None
|
||||
elif op == "IS NOT NULL":
|
||||
matched = v is not None
|
||||
elif op == "BETWEEN":
|
||||
matched = value[0] <= v <= value[1]
|
||||
elif op in ["<>", "!="]:
|
||||
matched = v != value
|
||||
elif op == "LIKE":
|
||||
value = value.replace("%", "(.*)").replace("_", ".")
|
||||
matched = re.match(value, v)
|
||||
else:
|
||||
raise Exception("invalid select operator: '%s'" % op)
|
||||
for (connector, exprs) in filter:
|
||||
for (name, op, value) in exprs:
|
||||
# just the ones we need for filter
|
||||
value = item.get_field(name, self, ignore_errors=True)
|
||||
env[hash_name(table, name)] = value
|
||||
for (connector, exprs) in filter:
|
||||
if connector == "AND":
|
||||
matched = True
|
||||
# all must match to be true
|
||||
for (name, op, value) in exprs:
|
||||
v = eval(hash_name(table, name), env)
|
||||
matched = compare(v, op, value)
|
||||
if not matched:
|
||||
break
|
||||
elif connector == "OR":
|
||||
matched = False
|
||||
# any must match to be true
|
||||
for (name, op, value) in exprs:
|
||||
v = eval(hash_name(table, name), env)
|
||||
matched = compare(v, op, value)
|
||||
if matched:
|
||||
break
|
||||
if not matched:
|
||||
break
|
||||
# else, keep going
|
||||
if matched:
|
||||
if selected < limit and start <= position:
|
||||
# now, we get all of the fields
|
||||
for field in fields:
|
||||
value = item.get_field(field)
|
||||
for field in fields:
|
||||
value = item.get_field(field, self, ignore_errors=True)
|
||||
row[field] = value
|
||||
selected += 1
|
||||
result.append(row)
|
||||
@ -1981,11 +2014,12 @@ class DbWriteBase(DbReadBase):
|
||||
break
|
||||
item = self._tables[table]["handle_func"](handle)
|
||||
row = {}
|
||||
for field in fields:
|
||||
value = item.get_field(field)
|
||||
for field in fields:
|
||||
value = item.get_field(field, self, ignore_errors=True)
|
||||
row[field] = value
|
||||
result.append(row)
|
||||
selected += 1
|
||||
position += 1
|
||||
result.total = self._tables[table]["count_func"]()
|
||||
result.time = time.time() - start_time
|
||||
return result
|
||||
|
@ -22,11 +22,35 @@ class HandleClass(str):
|
||||
def __init__(self, handle):
|
||||
super(HandleClass, self).__init__()
|
||||
|
||||
def Handle(classname, handle):
|
||||
def join(self, database, handle):
|
||||
return database._tables[self.classname]["handle_func"](handle)
|
||||
|
||||
@classmethod
|
||||
def get_schema(cls):
|
||||
from gramps.gen.lib import (Person, Family, Event, Place, Source,
|
||||
MediaObject, Repository, Note, Citation)
|
||||
tables = {
|
||||
"Person": Person,
|
||||
"Family": Family,
|
||||
"Event": Event,
|
||||
"Place": Place,
|
||||
"Source": Source,
|
||||
"MediaObject": MediaObject,
|
||||
"Repository": Repository,
|
||||
"Note": Note,
|
||||
"Citation": Citation,
|
||||
}
|
||||
return tables[cls.classname].get_schema()
|
||||
|
||||
def Handle(_classname, handle):
|
||||
if handle is None:
|
||||
return None
|
||||
h = HandleClass(handle)
|
||||
h.classname = classname
|
||||
class MyHandleClass(HandleClass):
|
||||
"""
|
||||
Class created to have classname attribute.
|
||||
"""
|
||||
classname = _classname
|
||||
h = MyHandleClass(handle)
|
||||
return h
|
||||
|
||||
def __from_struct(struct):
|
||||
|
@ -207,6 +207,18 @@ class Person(CitationBase, NoteBase, AttributeBase, MediaBase,
|
||||
for pr in self.person_ref_list] # 20
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def field_aliases(cls):
|
||||
"""
|
||||
Return dictionary of alias to full field names
|
||||
for this object class.
|
||||
"""
|
||||
return {
|
||||
"given": "primary_name.first_name",
|
||||
"surname": "primary_name.surname_list.0.surname",
|
||||
"surnames": "primary_name.surname_list.surname",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_labels(cls, _):
|
||||
return {
|
||||
|
@ -127,6 +127,21 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase):
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def field_aliases(cls):
|
||||
"""
|
||||
Return dictionary of alias to full field names
|
||||
for this object class.
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_field_alias(cls, alias):
|
||||
"""
|
||||
Return full field name for an alias, if one.
|
||||
"""
|
||||
return cls.field_aliases().get(alias, alias)
|
||||
|
||||
@classmethod
|
||||
def get_schema(cls):
|
||||
"""
|
||||
@ -155,6 +170,7 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase):
|
||||
Get the associated label given a field name of this object.
|
||||
No index positions allowed on lists.
|
||||
"""
|
||||
field = cls.get_field_alias(field)
|
||||
chain = field.split(".")
|
||||
ftype = cls._follow_schema_path(chain)
|
||||
return ftype
|
||||
@ -172,55 +188,73 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase):
|
||||
elif part in schema.keys():
|
||||
path = schema[part]
|
||||
else:
|
||||
raise Exception("No such %s in %s" % (part, schema))
|
||||
raise Exception("No such '%s' in %s" % (part, list(schema.keys())))
|
||||
if isinstance(path, (list, tuple)):
|
||||
path = path[0]
|
||||
return path
|
||||
|
||||
def get_field(self, field):
|
||||
def get_field(self, field, db=None, ignore_errors=False):
|
||||
"""
|
||||
Get the value of a field.
|
||||
"""
|
||||
field = self.__class__.get_field_alias(field)
|
||||
chain = field.split(".")
|
||||
path = self._follow_field_path(chain)
|
||||
path = self._follow_field_path(chain, db, ignore_errors)
|
||||
return path
|
||||
|
||||
def _follow_field_path(self, chain, ignore_errors=False):
|
||||
def _follow_field_path(self, chain, db=None, ignore_errors=False):
|
||||
"""
|
||||
Follow a list of items. Return endpoint.
|
||||
With the db argument, can do joins across tables.
|
||||
self - current object
|
||||
"""
|
||||
path = self
|
||||
from .handle import HandleClass
|
||||
current = self
|
||||
path_to = []
|
||||
parent = self
|
||||
for part in chain:
|
||||
class_ = None
|
||||
if hasattr(path, part): # attribute
|
||||
path = getattr(path, part)
|
||||
path_to.append(part)
|
||||
if hasattr(current, part): # attribute
|
||||
current = getattr(current, part)
|
||||
elif part.isdigit(): # index into list
|
||||
path = path[int(part)]
|
||||
elif part.endswith(")"): # callable
|
||||
# parse
|
||||
function, sargs = part.split("(", 1)
|
||||
sargs = sargs[:-1] # remove right-parent
|
||||
# eval arguments
|
||||
args = []
|
||||
for sarg in sargs.split(","):
|
||||
if sarg:
|
||||
args.append(eval(sarg.strip()))
|
||||
# call
|
||||
path = getattr(path, function)(*args)
|
||||
elif ignore_errors:
|
||||
return
|
||||
else:
|
||||
raise Exception("%s is not a valid field of %s; use %s" %
|
||||
(part, path, dir(path)))
|
||||
return path
|
||||
current = current[int(part)]
|
||||
continue
|
||||
elif isinstance(current, (list, tuple)):
|
||||
current = [getattr(attr, part) for attr in current]
|
||||
else: # part not found on this self
|
||||
# current is a handle
|
||||
# part is something on joined object
|
||||
ptype = parent.__class__.get_field_type(".".join(path_to[:-1]))
|
||||
if isinstance(ptype, HandleClass):
|
||||
if db:
|
||||
# start over here:
|
||||
try:
|
||||
parent = ptype.join(db, current)
|
||||
current = getattr(parent, part)
|
||||
path_to = []
|
||||
continue
|
||||
except:
|
||||
if ignore_errors:
|
||||
return
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
raise Exception("Can't join without database")
|
||||
if ignore_errors:
|
||||
return
|
||||
else:
|
||||
raise Exception("%s is not a valid field of %s; use %s" %
|
||||
(part, current, dir(current)))
|
||||
return current
|
||||
|
||||
def set_field(self, field, value, ignore_errors=False):
|
||||
def set_field(self, field, value, db=None, ignore_errors=False):
|
||||
"""
|
||||
Set the value of a basic field (str, int, float, or bool).
|
||||
value can be a string or actual value.
|
||||
"""
|
||||
field = self.__class__.get_field_alias(field)
|
||||
chain = field.split(".")
|
||||
path = self._follow_field_path(chain[:-1], ignore_errors)
|
||||
path = self._follow_field_path(chain[:-1], db, ignore_errors)
|
||||
ftype = self.get_field_type(field)
|
||||
# ftype is str, bool, float, or int
|
||||
value = (value in ['True', True]) if ftype is bool else value
|
||||
|
Loading…
Reference in New Issue
Block a user