From 22fa6ed63af903adb31ad2c6f7908efdea471cc2 Mon Sep 17 00:00:00 2001 From: Doug Blank Date: Sun, 10 Jan 2016 10:55:45 -0500 Subject: [PATCH] Refinements in db.select; allow joins --- gramps/gen/db/base.py | 120 ++++++++++++++++++++++------------- gramps/gen/lib/handle.py | 30 ++++++++- gramps/gen/lib/person.py | 12 ++++ gramps/gen/lib/primaryobj.py | 90 ++++++++++++++++++-------- 4 files changed, 178 insertions(+), 74 deletions(-) diff --git a/gramps/gen/db/base.py b/gramps/gen/db/base.py index aedb8c832..8a8b8a684 100644 --- a/gramps/gen/db/base.py +++ b/gramps/gen/db/base.py @@ -31,6 +31,7 @@ from this class. # #------------------------------------------------------------------------- import re +import time #------------------------------------------------------------------------- # @@ -1887,23 +1888,67 @@ class DbWriteBase(DbReadBase): sort - use sort order (argument to DB.get_X_handles) start - position to start limit - count to get; -1 for all - filter - {field: (SQL string_operator, value), } + filter - [["AND", [(field, SQL string_operator, value), + (field, SQL string_operator, value), ...], + ["OR", [(field, SQL string_operator, value), + (field, SQL string_operator, value), ...]] handles all SQL except for NOT expression, eg NOT x = y """ class Result(list): """ - A list rows of just matching for this page, with total = all. + A list rows of just matching for this page, with total = all, + and time = time to select. """ total = 0 - def hash_name(name): + def hash_name(table, name): """ Used in filter to eval expressions involving selected data. """ + name = self._tables[table]["class_func"].get_field_alias(name) return (name .replace(".", "_D_") .replace("(", "_P_") .replace(")", "_P_")) + def compare(v, op, value): + """ + Compare values in a SQL-like way + """ + if op == "=": + matched = v == value + elif op == ">": + matched = v > value + elif op == ">=": + matched = v >= value + elif op == "<": + matched = v < value + elif op == "<=": + matched = v <= value + elif op == "IN": + matched = v in value + elif op == "NI": + matched = value in v + elif op == "IS": + matched = v is value + elif op == "IS NOT": + matched = v is not value + elif op == "IS NULL": + matched = v is None + elif op == "IS NOT NULL": + matched = v is not None + elif op == "BETWEEN": + matched = value[0] <= v <= value[1] + elif op in ["<>", "!="]: + matched = v != value + elif op == "LIKE": + if value and v: + value = value.replace("%", "(.*)").replace("_", ".") + matched = re.match("^" + value + "$", v) + else: + matched = False + else: + raise Exception("invalid select operator: '%s'" % op) + return True if matched else False # Fields is None or list, maybe containing "*": if fields is None: fields = ["*"] @@ -1920,55 +1965,43 @@ class DbWriteBase(DbReadBase): position = 0 selected = 0 result = Result() + start_time = time.time() if filter: for handle in data: # have to evaluate all, because there is a filter item = self._tables[table]["handle_func"](handle) row = {} env = {} - for field in filter.keys(): - # just the ones we need for filter - value = item.get_field(field) - env[hash_name(field)] = value - matched = True - for name, (op, value) in filter.items(): - v = eval(hash_name(name), env) - if op == "=": - matched = v == value - elif op == ">": - matched = v > value - elif op == ">=": - matched = v >= value - elif op == "<": - matched = v < value - elif op == "<=": - matched = v <= value - elif op == "IN": - matched = v in value - elif op == "IS": - matched = v is value - elif op == "IS NOT": - matched = v is not value - elif op == "IS NULL": - matched = v is None - elif op == "IS NOT NULL": - matched = v is not None - elif op == "BETWEEN": - matched = value[0] <= v <= value[1] - elif op in ["<>", "!="]: - matched = v != value - elif op == "LIKE": - value = value.replace("%", "(.*)").replace("_", ".") - matched = re.match(value, v) - else: - raise Exception("invalid select operator: '%s'" % op) + for (connector, exprs) in filter: + for (name, op, value) in exprs: + # just the ones we need for filter + value = item.get_field(name, self, ignore_errors=True) + env[hash_name(table, name)] = value + for (connector, exprs) in filter: + if connector == "AND": + matched = True + # all must match to be true + for (name, op, value) in exprs: + v = eval(hash_name(table, name), env) + matched = compare(v, op, value) + if not matched: + break + elif connector == "OR": + matched = False + # any must match to be true + for (name, op, value) in exprs: + v = eval(hash_name(table, name), env) + matched = compare(v, op, value) + if matched: + break if not matched: break + # else, keep going if matched: if selected < limit and start <= position: # now, we get all of the fields - for field in fields: - value = item.get_field(field) + for field in fields: + value = item.get_field(field, self, ignore_errors=True) row[field] = value selected += 1 result.append(row) @@ -1981,11 +2014,12 @@ class DbWriteBase(DbReadBase): break item = self._tables[table]["handle_func"](handle) row = {} - for field in fields: - value = item.get_field(field) + for field in fields: + value = item.get_field(field, self, ignore_errors=True) row[field] = value result.append(row) selected += 1 position += 1 result.total = self._tables[table]["count_func"]() + result.time = time.time() - start_time return result diff --git a/gramps/gen/lib/handle.py b/gramps/gen/lib/handle.py index 2fb003a3e..2e3b37ed5 100644 --- a/gramps/gen/lib/handle.py +++ b/gramps/gen/lib/handle.py @@ -22,11 +22,35 @@ class HandleClass(str): def __init__(self, handle): super(HandleClass, self).__init__() -def Handle(classname, handle): + def join(self, database, handle): + return database._tables[self.classname]["handle_func"](handle) + + @classmethod + def get_schema(cls): + from gramps.gen.lib import (Person, Family, Event, Place, Source, + MediaObject, Repository, Note, Citation) + tables = { + "Person": Person, + "Family": Family, + "Event": Event, + "Place": Place, + "Source": Source, + "MediaObject": MediaObject, + "Repository": Repository, + "Note": Note, + "Citation": Citation, + } + return tables[cls.classname].get_schema() + +def Handle(_classname, handle): if handle is None: return None - h = HandleClass(handle) - h.classname = classname + class MyHandleClass(HandleClass): + """ + Class created to have classname attribute. + """ + classname = _classname + h = MyHandleClass(handle) return h def __from_struct(struct): diff --git a/gramps/gen/lib/person.py b/gramps/gen/lib/person.py index c157b6a9e..c6c871ae3 100644 --- a/gramps/gen/lib/person.py +++ b/gramps/gen/lib/person.py @@ -207,6 +207,18 @@ class Person(CitationBase, NoteBase, AttributeBase, MediaBase, for pr in self.person_ref_list] # 20 } + @classmethod + def field_aliases(cls): + """ + Return dictionary of alias to full field names + for this object class. + """ + return { + "given": "primary_name.first_name", + "surname": "primary_name.surname_list.0.surname", + "surnames": "primary_name.surname_list.surname", + } + @classmethod def get_labels(cls, _): return { diff --git a/gramps/gen/lib/primaryobj.py b/gramps/gen/lib/primaryobj.py index 622c5275d..20c03e182 100644 --- a/gramps/gen/lib/primaryobj.py +++ b/gramps/gen/lib/primaryobj.py @@ -127,6 +127,21 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase): """ return {} + @classmethod + def field_aliases(cls): + """ + Return dictionary of alias to full field names + for this object class. + """ + return {} + + @classmethod + def get_field_alias(cls, alias): + """ + Return full field name for an alias, if one. + """ + return cls.field_aliases().get(alias, alias) + @classmethod def get_schema(cls): """ @@ -155,6 +170,7 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase): Get the associated label given a field name of this object. No index positions allowed on lists. """ + field = cls.get_field_alias(field) chain = field.split(".") ftype = cls._follow_schema_path(chain) return ftype @@ -172,55 +188,73 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase): elif part in schema.keys(): path = schema[part] else: - raise Exception("No such %s in %s" % (part, schema)) + raise Exception("No such '%s' in %s" % (part, list(schema.keys()))) if isinstance(path, (list, tuple)): path = path[0] return path - def get_field(self, field): + def get_field(self, field, db=None, ignore_errors=False): """ Get the value of a field. """ + field = self.__class__.get_field_alias(field) chain = field.split(".") - path = self._follow_field_path(chain) + path = self._follow_field_path(chain, db, ignore_errors) return path - def _follow_field_path(self, chain, ignore_errors=False): + def _follow_field_path(self, chain, db=None, ignore_errors=False): """ Follow a list of items. Return endpoint. + With the db argument, can do joins across tables. + self - current object """ - path = self + from .handle import HandleClass + current = self + path_to = [] + parent = self for part in chain: - class_ = None - if hasattr(path, part): # attribute - path = getattr(path, part) + path_to.append(part) + if hasattr(current, part): # attribute + current = getattr(current, part) elif part.isdigit(): # index into list - path = path[int(part)] - elif part.endswith(")"): # callable - # parse - function, sargs = part.split("(", 1) - sargs = sargs[:-1] # remove right-parent - # eval arguments - args = [] - for sarg in sargs.split(","): - if sarg: - args.append(eval(sarg.strip())) - # call - path = getattr(path, function)(*args) - elif ignore_errors: - return - else: - raise Exception("%s is not a valid field of %s; use %s" % - (part, path, dir(path))) - return path + current = current[int(part)] + continue + elif isinstance(current, (list, tuple)): + current = [getattr(attr, part) for attr in current] + else: # part not found on this self + # current is a handle + # part is something on joined object + ptype = parent.__class__.get_field_type(".".join(path_to[:-1])) + if isinstance(ptype, HandleClass): + if db: + # start over here: + try: + parent = ptype.join(db, current) + current = getattr(parent, part) + path_to = [] + continue + except: + if ignore_errors: + return + else: + raise + else: + raise Exception("Can't join without database") + if ignore_errors: + return + else: + raise Exception("%s is not a valid field of %s; use %s" % + (part, current, dir(current))) + return current - def set_field(self, field, value, ignore_errors=False): + def set_field(self, field, value, db=None, ignore_errors=False): """ Set the value of a basic field (str, int, float, or bool). value can be a string or actual value. """ + field = self.__class__.get_field_alias(field) chain = field.split(".") - path = self._follow_field_path(chain[:-1], ignore_errors) + path = self._follow_field_path(chain[:-1], db, ignore_errors) ftype = self.get_field_type(field) # ftype is str, bool, float, or int value = (value in ['True', True]) if ftype is bool else value