Hide db._select(); Refactor db._tables; QuerySet; tests and bug fixes

db.select() has become db._select()

This commit refactors database._tables so that they can work properly
after a proxy or filter has been removed.

It adds abstract API called QuerySet which allows a variety of
selections and ordering of data.

Logging was added to sqlite to better see db access.

Additional fixes:

* clearing GenericDB._directory
* test_util_test off-by-one error
This commit is contained in:
Doug Blank 2016-02-28 15:40:16 -05:00
parent 0415ac8eab
commit ad3dcb8d13
20 changed files with 1645 additions and 518 deletions

View File

@ -31,7 +31,7 @@ script:
- mkdir -p ~/.gramps/grampsdb/ - mkdir -p ~/.gramps/grampsdb/
# --exclude=TestUser because of older version of mock # --exclude=TestUser because of older version of mock
# without configure_mock # without configure_mock
- nosetests3 --with-coverage --cover-package=gramps --exclude=TestcaseGenerator --exclude=vcard --exclude=merge_ref_test --exclude=user_test gramps - nosetests3 --nologcapture --with-coverage --cover-package=gramps --exclude=TestcaseGenerator --exclude=vcard --exclude=merge_ref_test --exclude=user_test gramps
after_success: after_success:
- codecov - codecov

View File

@ -3,7 +3,7 @@
"http://gramps-project.org/xml/1.7.1/grampsxml.dtd"> "http://gramps-project.org/xml/1.7.1/grampsxml.dtd">
<database xmlns="http://gramps-project.org/xml/1.7.1/"> <database xmlns="http://gramps-project.org/xml/1.7.1/">
<header> <header>
<created date="2015-08-01" version="5.0.0"/> <created date="2016-02-06" version="5.0.0"/>
<researcher> <researcher>
<resname>Alex Roitman</resname> <resname>Alex Roitman</resname>
<resaddr>1122 Boogie Boogie Ave</resaddr> <resaddr>1122 Boogie Boogie Ave</resaddr>
@ -901,7 +901,7 @@
<objref hlink="_HHNT6D73QPKC0KWK2Y"/> <objref hlink="_HHNT6D73QPKC0KWK2Y"/>
<childof hlink="_S7MT6D1JSGX9PZO27F"/> <childof hlink="_S7MT6D1JSGX9PZO27F"/>
</person> </person>
<person handle="_PSNT6D0DDHJOBCFJWX" change="1370206720" id="I0037"> <person handle="_PSNT6D0DDHJOBCFJWX" change="1370206720" id="I0037" priv="1">
<gender>M</gender> <gender>M</gender>
<name type="Birth Name"> <name type="Birth Name">
<first>Edwin Michael</first> <first>Edwin Michael</first>

View File

@ -32,6 +32,7 @@ from this class.
#------------------------------------------------------------------------- #-------------------------------------------------------------------------
import re import re
import time import time
from operator import itemgetter
#------------------------------------------------------------------------- #-------------------------------------------------------------------------
# #
@ -45,6 +46,35 @@ from ..lib.childref import ChildRef
from .txn import DbTxn from .txn import DbTxn
from .exceptions import DbTransactionCancel from .exceptions import DbTransactionCancel
def eval_order_by(order_by, obj, db):
"""
Given a list of [[field, DIRECTION], ...]
return the list of values of the fields
"""
values = []
for (field, direction) in order_by:
values.append(obj.get_field(field, db, ignore_errors=True))
return values
def sort_objects(objects, order_by, db):
"""
Python-based sorting.
"""
# first build sort order:
sorted_items = []
map_items = {}
for obj in objects:
# just use values and handle to keep small:
sorted_items.append((eval_order_by(order_by, obj, db), obj.handle))
map_items[obj.handle] = obj
# next we sort by fields and direction
pos = len(order_by) - 1
for (field, order) in reversed(order_by): # sort the lasts parts first
sorted_items.sort(key=itemgetter(pos), reverse=(order=="DESC"))
pos -= 1
for (order_by_values, handle) in sorted_items:
yield map_items[handle]
#------------------------------------------------------------------------- #-------------------------------------------------------------------------
# #
# Gramps libraries # Gramps libraries
@ -67,18 +97,12 @@ class DbReadBase(object):
""" """
self.basedb = self self.basedb = self
self.__feature = {} # {"feature": VALUE, ...} self.__feature = {} # {"feature": VALUE, ...}
self._tables = {
"Citation": {}, def get_table_func(self, table=None, func=None):
"Event": {}, """
"Family": {}, Base implementation of get_table_func.
"Media": {}, """
"Note": {}, return None
"Person": {},
"Place": {},
"Repository": {},
"Source": {},
"Tag": {},
}
def get_feature(self, feature): def get_feature(self, feature):
""" """
@ -1234,6 +1258,186 @@ class DbReadBase(object):
""" """
raise NotImplementedError raise NotImplementedError
def _select(self, table, fields=None, start=0, limit=-1,
where=None, order_by=None):
"""
Default implementation of a select for those databases
that don't support SQL. Returns a list of dicts, total,
and time.
table - Person, Family, etc.
fields - used by object.get_field()
start - position to start
limit - count to get; -1 for all
where - (field, SQL string_operator, value) |
["AND", [where, where, ...]] |
["OR", [where, where, ...]] |
["NOT", where]
order_by - [[fieldname, "ASC" | "DESC"], ...]
"""
def compare(v, op, value):
"""
Compare values in a SQL-like way
"""
if isinstance(v, (list, tuple)) and len(v) > 0: # join, or multi-values
# If any is true:
for item in v:
if compare(item, op, value):
return True
return False
if op == "=":
matched = v == value
elif op == ">":
matched = v > value
elif op == ">=":
matched = v >= value
elif op == "<":
matched = v < value
elif op == "<=":
matched = v <= value
elif op == "IN":
matched = v in value
elif op == "IS":
matched = v is value
elif op == "IS NOT":
matched = v is not value
elif op == "IS NULL":
matched = v is None
elif op == "IS NOT NULL":
matched = v is not None
elif op == "BETWEEN":
matched = value[0] <= v <= value[1]
elif op in ["<>", "!="]:
matched = v != value
elif op == "LIKE":
if value and v:
value = value.replace("%", "(.*)").replace("_", ".")
## FIXME: allow a case-insensitive version
matched = re.match("^" + value + "$", v, re.MULTILINE)
else:
matched = False
else:
raise Exception("invalid select operator: '%s'" % op)
return True if matched else False
def evaluate_values(condition, item, db, table, env):
"""
Evaluates the names in all conditions.
"""
if len(condition) == 2: # ["AND" [...]] | ["OR" [...]] | ["NOT" expr]
connector, exprs = condition
if connector in ["AND", "OR"]:
for expr in exprs:
evaluate_values(expr, item, db, table, env)
else: # "NOT"
evaluate_values(exprs, item, db, table, env)
elif len(condition) == 3: # (name, op, value)
(name, op, value) = condition
# just the ones we need for where
hname = self._hash_name(table, name)
if hname not in env:
value = item.get_field(name, db, ignore_errors=True)
env[hname] = value
def evaluate_truth(condition, item, db, table, env):
if len(condition) == 2: # ["AND"|"OR" [...]]
connector, exprs = condition
if connector == "AND": # all must be true
for expr in exprs:
if not evaluate_truth(expr, item, db, table, env):
return False
return True
elif connector == "OR": # any will return true
for expr in exprs:
if evaluate_truth(expr, item, db, table, env):
return True
return False
elif connector == "NOT": # return not of single value
return not evaluate_truth(exprs, item, db, table, env)
else:
raise Exception("No such connector: '%s'" % connector)
elif len(condition) == 3: # (name, op, value)
(name, op, value) = condition
v = env.get(self._hash_name(table, name))
return compare(v, op, value)
# Fields is None or list, maybe containing "*":
if fields is None:
pass # ok
elif not isinstance(fields, (list, tuple)):
raise Exception("fields must be a list/tuple of field names")
elif "*" in fields:
fields.remove("*")
fields.extend(self.get_table_func(table,"class_func").get_schema().keys())
get_count_only = (fields is not None and fields[0] == "count(1)")
position = 0
selected = 0
if get_count_only:
if where or limit != -1 or start != 0:
# no need to order for a count
data = self.get_table_func(table,"iter_func")()
else:
yield self.get_table_func(table,"count_func")()
else:
data = self.get_table_func(table, "iter_func")(order_by=order_by)
if where:
for item in data:
# Go through all fliters and evaluate the fields:
env = {}
evaluate_values(where, item, self, table, env)
matched = evaluate_truth(where, item, self, table, env)
if matched:
if ((selected < limit) or (limit == -1)) and start <= position:
selected += 1
if not get_count_only:
if fields:
row = {}
for field in fields:
value = item.get_field(field, self, ignore_errors=True)
row[field.replace("__", ".")] = value
yield row
else:
yield item
position += 1
if get_count_only:
yield selected
else: # no where
for item in data:
if position >= start:
if ((selected >= limit) and (limit != -1)):
break
selected += 1
if not get_count_only:
if fields:
row = {}
for field in fields:
value = item.get_field(field, self, ignore_errors=True)
row[field.replace("__", ".")] = value
yield row
else:
yield item
position += 1
if get_count_only:
yield selected
def _hash_name(self, table, name):
"""
Used in SQL functions to eval expressions involving selected
data.
"""
name = self.get_table_func(table,"class_func").get_field_alias(name)
return name.replace(".", "__")
Person = property(lambda self:QuerySet(self, "Person"))
Family = property(lambda self:QuerySet(self, "Family"))
Note = property(lambda self:QuerySet(self, "Note"))
Citation = property(lambda self:QuerySet(self, "Citation"))
Source = property(lambda self:QuerySet(self, "Source"))
Repository = property(lambda self:QuerySet(self, "Repository"))
Place = property(lambda self:QuerySet(self, "Place"))
Event = property(lambda self:QuerySet(self, "Event"))
Tag = property(lambda self:QuerySet(self, "Tag"))
class DbWriteBase(DbReadBase): class DbWriteBase(DbReadBase):
""" """
Gramps database object. This object is a base class for all Gramps database object. This object is a base class for all
@ -1877,180 +2081,279 @@ class DbWriteBase(DbReadBase):
else: else:
raise ValueError("invalid instance type: %s" % instance.__class__.__name__) raise ValueError("invalid instance type: %s" % instance.__class__.__name__)
def select(self, table, fields=None, start=0, limit=-1, def get_queryset_by_table_name(self, table_name):
where=None, order_by=None):
""" """
Default implementation of a select for those databases Get Person, Family queryset by name.
that don't support SQL. Returns a list of dicts, total, """
and time. return getattr(self, table_name)
table - Person, Family, etc. class Operator(object):
fields - used by object.get_field() """
start - position to start Base for QuerySet operators.
limit - count to get; -1 for all """
where - (field, SQL string_operator, value) | op = "OP"
["AND", [where, where, ...]] | def __init__(self, *expressions, **kwargs):
["OR", [where, where, ...]] | if self.op in ["AND", "OR"]:
["NOT", where] exprs = [expression.list for expression
order_by - [[fieldname, "ASC" | "DESC"], ...] in expressions]
""" for key in kwargs:
class Result(list): exprs.append(
""" _select_field_operator_value(key, "=", kwargs[key]))
A list rows of just matching for this page, with total = all, else: # "NOT"
time = time to select, expanded (unpickled), query (N/A). if expressions:
""" exprs = expressions.list
total = 0
time = 0.0
expanded = True
query = None
def compare(v, op, value):
"""
Compare values in a SQL-like way
"""
if isinstance(v, (list, tuple)): # join, or multi-values
# If any is true:
for item in v:
if compare(item, op, value):
return True
return False
if op == "=":
matched = v == value
elif op == ">":
matched = v > value
elif op == ">=":
matched = v >= value
elif op == "<":
matched = v < value
elif op == "<=":
matched = v <= value
elif op == "IN":
matched = v in value
elif op == "IS":
matched = v is value
elif op == "IS NOT":
matched = v is not value
elif op == "IS NULL":
matched = v is None
elif op == "IS NOT NULL":
matched = v is not None
elif op == "BETWEEN":
matched = value[0] <= v <= value[1]
elif op in ["<>", "!="]:
matched = v != value
elif op == "LIKE":
if value and v:
value = value.replace("%", "(.*)").replace("_", ".")
## FIXME: allow a case-insensitive version
matched = re.match("^" + value + "$", v, re.MULTILINE)
else:
matched = False
else: else:
raise Exception("invalid select operator: '%s'" % op) key, value = list(kwargs.items())[0]
return True if matched else False exprs = _select_field_operator_value(key, "=", value)
self.list = [self.op, exprs]
def evaluate_values(condition, item, db, table, env): class AND(Operator):
""" op = "AND"
Evaluates the names in all conditions.
"""
if len(condition) == 2: # ["AND" [...]] | ["OR" [...]] | ["NOT" expr]
connector, exprs = condition
if connector in ["AND", "OR"]:
for expr in exprs:
evaluate_values(expr, item, db, table, env)
else: # "NOT"
evaluate_values(exprs, item, db, table, env)
elif len(condition) == 3: # (name, op, value)
(name, op, value) = condition
# just the ones we need for where
hname = self._hash_name(table, name)
if hname not in env:
value = item.get_field(name, db, ignore_errors=True)
env[hname] = value
def evaluate_truth(condition, item, db, table, env): class OR(Operator):
if len(condition) == 2: # ["AND"|"OR" [...]] """
connector, exprs = condition OR operator for QuerySet logical WHERE expressions.
if connector == "AND": # all must be true """
for expr in exprs: op = "OR"
if not evaluate_truth(expr, item, db, table, env):
return False class NOT(Operator):
return True """
elif connector == "OR": # any will return true NOT operator for QuerySet logical WHERE expressions.
for expr in exprs: """
if evaluate_truth(expr, item, db, table, env): op = "NOT"
return True
return False class QuerySet(object):
elif connector == "NOT": # return not of single value """
return not evaluate_truth(exprs, item, db, table, env) A container for selection criteria before being actually
applied to a database.
"""
def __init__(self, database, table):
self.database = database
self.table = table
self.generator = None
self.where_by = None
self.order_by = None
self.limit_by = -1
self.start = 0
self.needs_to_run = False
def limit(self, start=None, count=None):
"""
Put limits on the selection.
"""
if start is not None:
self.start = start
if count is not None:
self.limit_by = count
self.needs_to_run = True
return self
def order(self, *args):
"""
Put an ordering on the selection.
"""
for arg in args:
if self.order_by is None:
self.order_by = []
if arg.startswith("-"):
self.order_by.append((arg[1:], "DESC"))
else:
self.order_by.append((arg, "ASC"))
self.needs_to_run = True
return self
def _add_where_clause(self, *args, **kwargs):
"""
Add a condition to the where clause.
"""
# First, handle AND, OR, NOT args:
and_expr = []
for arg in args:
expr = arg.list
and_expr.append(expr)
# Next, handle kwargs:
for keyword in kwargs:
and_expr.append(
_select_field_operator_value(
keyword, "=", kwargs[keyword]))
if and_expr:
if self.where_by:
self.where_by = ["AND", [self.where_by] + and_expr]
elif len(and_expr) == 1:
self.where_by = and_expr[0]
else:
self.where_by = ["AND", and_expr]
self.needs_to_run = True
return self
def count(self):
"""
Run query with just where, start, limit to get count.
"""
if self.generator and self.needs_to_run:
raise Exception("Queries in invalid order")
elif self.generator:
return len(list(self.generator))
else:
generator = self.database._select(self.table,
["count(1)"],
where=self.where_by,
start=self.start,
limit=self.limit_by)
return next(generator)
def _generate(self, args=None):
"""
Create a generator from current options.
"""
generator = self.database._select(self.table,
args,
order_by=self.order_by,
where=self.where_by,
start=self.start,
limit=self.limit_by)
# Reset all criteria
self.where_by = None
self.order_by = None
self.limit_by = -1
self.start = 0
self.needs_to_run = False
return generator
def select(self, *args):
"""
Actually touch the database.
"""
if len(args) == 0:
args = None
if self.generator and self.needs_to_run:
## problem
raise Exception("Queries in invalid order")
elif self.generator:
if args: # there is a generator, with args
for i in self.generator:
yield [i.get_field(arg) for arg in args]
else: # generator, no args
for i in self.generator:
yield i
else: # need to run or not
self.generator = self._generate(args)
for i in self.generator:
yield i
def proxy(self, proxy_name, *args, **kwargs):
"""
Apply a named proxy to the db.
"""
from gramps.gen.proxy import (LivingProxyDb, PrivateProxyDb,
ReferencedBySelectionProxyDb)
if proxy_name == "living":
proxy_class = LivingProxyDb
elif proxy_name == "private":
proxy_class = PrivateProxyDb
elif proxy_name == "referenced":
proxy_class = ReferencedBySelectionProxyDb
else:
raise Exception("No such proxy name: '%s'" % proxy_name)
self.database = proxy_class(self.database, *args, **kwargs)
return self
def filter(self, *args, **kwargs):
"""
Apply a filter to the database.
"""
from gramps.gen.proxy import FilterProxyDb
from gramps.gen.filters import GenericFilter
for i in range(len(args)):
arg = args[i]
if isinstance(arg, GenericFilter):
self.database = FilterProxyDb(self.database, arg, *args[i+1:])
if arg.where_by:
self._add_where_clause(arg.where_by)
elif isinstance(arg, Operator):
self._add_where_clause(arg)
elif callable(arg):
if self.generator and self.needs_to_run:
## error
raise Exception("Queries in invalid order")
elif self.generator:
pass # ok
else: else:
raise Exception("No such connector: '%s'" % connector) self.generator = self._generate()
elif len(condition) == 3: # (name, op, value) self.generator = filter(arg, self.generator)
(name, op, value) = condition else:
v = env.get(self._hash_name(table, name)) pass # ignore, may have been arg from previous Filter
return compare(v, op, value) if kwargs:
self._add_where_clause(**kwargs)
return self
# Fields is None or list, maybe containing "*": def map(self, f):
if fields is None: """
fields = ["*"] Apply the function f to the selected items and return results.
elif not isinstance(fields, (list, tuple)): """
raise Exception("fields must be a list/tuple of field names") if self.generator and self.needs_to_run:
if "*" in fields: raise Exception("Queries in invalid order")
fields.remove("*") elif self.generator:
fields.extend(self._tables[table]["class_func"].get_schema().keys()) pass # ok
data = self._tables[table]["iter_func"](order_by=order_by) else:
position = 0 self.generator = self._generate()
selected = 0 previous_generator = self.generator
result = Result() def generator():
start_time = time.time() for item in previous_generator:
if where: yield f(item)
for item in data: self.generator = generator()
# have to evaluate all, because there is a where return self
row = {}
env = {}
# Go through all fliters and evaluate the fields:
evaluate_values(where, item, self, table, env)
matched = evaluate_truth(where, item, self, table, env)
if matched:
if ((selected < limit) or (limit == -1)) and start <= position:
# now, we get all of the fields
for field in fields:
value = item.get_field(field, self, ignore_errors=True)
row[field.replace("__", ".")] = value
selected += 1
result.append(row)
position += 1
result.total = position
else: # no where
for item in data:
if position >= start:
if ((selected >= limit) and (limit != -1)):
break
row = {}
for field in fields:
value = item.get_field(field, self, ignore_errors=True)
row[field.replace("__", ".")] = value
result.append(row)
selected += 1
position += 1
result.total = self._tables[table]["count_func"]()
result.time = time.time() - start_time
return result
def _hash_name(self, table, name): def tag(self, tag_text):
""" """
Used in SQL functions to eval expressions involving selected Tag the selected items with the tag name.
data.
""" """
name = self._tables[table]["class_func"].get_field_alias(name) if self.generator and self.needs_to_run:
return name.replace(".", "__") raise Exception("Queries in invalid order")
elif self.generator:
pass # ok
else:
self.generator = self._generate()
tag = self.database.get_tag_from_name(tag_text)
trans_class = self.database.get_transaction_class()
with trans_class("Tag Selected Items", self.database, batch=True) as trans:
if tag is None:
tag = self.database.get_table_func("Tag","class_func")()
tag.set_name(tag_text)
self.database.add_tag(tag, trans)
commit_func = self.database.get_table_func(self.table,"commit_func")
for item in self.generator:
if tag.handle not in item.tag_list:
item.add_tag(tag.handle)
commit_func(item, trans)
def eval_order_by(self, order_by, obj): def _to_dot_format(field):
""" """
Given a list of [[field, DIRECTION], ...] Convert a field keyword arg into a proper
return the list of values of the fields dotted field name.
""" """
values = [] return field.replace("__", ".")
for (field, direction) in order_by:
values.append(obj.get_field(field, self, ignore_errors=True))
return values
def _select_field_operator_value(field, op, value):
"""
Convert a field keyword arg into proper
field, op, and value.
"""
alias = {
"LT": "<",
"GT": ">",
"LTE": "<=",
"GTE": ">=",
"IS_NOT": "IS NOT",
"IS_NULL": "IS NULL",
"IS_NOT_NULL": "IS NOT NULL",
"NE": "<>",
}
for operator in ["LIKE", "IN"] + list(alias.keys()):
operator = "__" + operator
if field.endswith(operator):
op = field[-len(operator) + 2:]
field = field[:-len(operator)]
op = alias.get(op, op)
field = _to_dot_format(field)
return (field, op, value)

View File

@ -46,20 +46,15 @@ from gramps.gen.const import GRAMPS_LOCALE as glocale
_ = glocale.translation.gettext _ = glocale.translation.gettext
from gramps.gen.db import (DbReadBase, DbWriteBase, DbTxn, DbUndo, from gramps.gen.db import (DbReadBase, DbWriteBase, DbTxn, DbUndo,
KEY_TO_NAME_MAP, KEY_TO_CLASS_MAP, KEY_TO_NAME_MAP, KEY_TO_CLASS_MAP,
CLASS_TO_KEY_MAP, TXNADD, TXNUPD, TXNDEL) CLASS_TO_KEY_MAP, TXNADD, TXNUPD, TXNDEL,
PERSON_KEY, FAMILY_KEY, CITATION_KEY,
SOURCE_KEY, EVENT_KEY, MEDIA_KEY,
PLACE_KEY, REPOSITORY_KEY, NOTE_KEY,
TAG_KEY, eval_order_by)
from gramps.gen.db.base import QuerySet
from gramps.gen.utils.callback import Callback from gramps.gen.utils.callback import Callback
from gramps.gen.updatecallback import UpdateCallback from gramps.gen.updatecallback import UpdateCallback
from gramps.gen.db.dbconst import * from gramps.gen.db.dbconst import *
from gramps.gen.db import (PERSON_KEY,
FAMILY_KEY,
CITATION_KEY,
SOURCE_KEY,
EVENT_KEY,
MEDIA_KEY,
PLACE_KEY,
REPOSITORY_KEY,
NOTE_KEY,
TAG_KEY)
from gramps.gen.utils.id import create_id from gramps.gen.utils.id import create_id
from gramps.gen.lib.researcher import Researcher from gramps.gen.lib.researcher import Researcher
@ -227,7 +222,7 @@ class Table(object):
if funcs: if funcs:
self.funcs = funcs self.funcs = funcs
else: else:
self.funcs = db._tables[table_name] self.funcs = db.get_table_func(table_name)
def cursor(self): def cursor(self):
""" """
@ -439,7 +434,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
DbReadBase.__init__(self) DbReadBase.__init__(self)
DbWriteBase.__init__(self) DbWriteBase.__init__(self)
Callback.__init__(self) Callback.__init__(self)
self._tables['Person'].update( self.__tables = {
'Person':
{ {
"handle_func": self.get_person_from_handle, "handle_func": self.get_person_from_handle,
"gramps_id_func": self.get_person_from_gramps_id, "gramps_id_func": self.get_person_from_gramps_id,
@ -456,8 +452,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
"raw_func": self._get_raw_person_data, "raw_func": self._get_raw_person_data,
"raw_id_func": self._get_raw_person_from_id_data, "raw_id_func": self._get_raw_person_from_id_data,
"del_func": self.remove_person, "del_func": self.remove_person,
}) },
self._tables['Family'].update( 'Family':
{ {
"handle_func": self.get_family_from_handle, "handle_func": self.get_family_from_handle,
"gramps_id_func": self.get_family_from_gramps_id, "gramps_id_func": self.get_family_from_gramps_id,
@ -474,8 +470,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
"raw_func": self._get_raw_family_data, "raw_func": self._get_raw_family_data,
"raw_id_func": self._get_raw_family_from_id_data, "raw_id_func": self._get_raw_family_from_id_data,
"del_func": self.remove_family, "del_func": self.remove_family,
}) },
self._tables['Source'].update( 'Source':
{ {
"handle_func": self.get_source_from_handle, "handle_func": self.get_source_from_handle,
"gramps_id_func": self.get_source_from_gramps_id, "gramps_id_func": self.get_source_from_gramps_id,
@ -492,8 +488,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
"raw_func": self._get_raw_source_data, "raw_func": self._get_raw_source_data,
"raw_id_func": self._get_raw_source_from_id_data, "raw_id_func": self._get_raw_source_from_id_data,
"del_func": self.remove_source, "del_func": self.remove_source,
}) },
self._tables['Citation'].update( 'Citation':
{ {
"handle_func": self.get_citation_from_handle, "handle_func": self.get_citation_from_handle,
"gramps_id_func": self.get_citation_from_gramps_id, "gramps_id_func": self.get_citation_from_gramps_id,
@ -510,8 +506,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
"raw_func": self._get_raw_citation_data, "raw_func": self._get_raw_citation_data,
"raw_id_func": self._get_raw_citation_from_id_data, "raw_id_func": self._get_raw_citation_from_id_data,
"del_func": self.remove_citation, "del_func": self.remove_citation,
}) },
self._tables['Event'].update( 'Event':
{ {
"handle_func": self.get_event_from_handle, "handle_func": self.get_event_from_handle,
"gramps_id_func": self.get_event_from_gramps_id, "gramps_id_func": self.get_event_from_gramps_id,
@ -528,8 +524,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
"raw_func": self._get_raw_event_data, "raw_func": self._get_raw_event_data,
"raw_id_func": self._get_raw_event_from_id_data, "raw_id_func": self._get_raw_event_from_id_data,
"del_func": self.remove_event, "del_func": self.remove_event,
}) },
self._tables['Media'].update( 'Media':
{ {
"handle_func": self.get_media_from_handle, "handle_func": self.get_media_from_handle,
"gramps_id_func": self.get_media_from_gramps_id, "gramps_id_func": self.get_media_from_gramps_id,
@ -546,8 +542,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
"raw_func": self._get_raw_media_data, "raw_func": self._get_raw_media_data,
"raw_id_func": self._get_raw_media_from_id_data, "raw_id_func": self._get_raw_media_from_id_data,
"del_func": self.remove_media, "del_func": self.remove_media,
}) },
self._tables['Place'].update( 'Place':
{ {
"handle_func": self.get_place_from_handle, "handle_func": self.get_place_from_handle,
"gramps_id_func": self.get_place_from_gramps_id, "gramps_id_func": self.get_place_from_gramps_id,
@ -564,8 +560,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
"raw_func": self._get_raw_place_data, "raw_func": self._get_raw_place_data,
"raw_id_func": self._get_raw_place_from_id_data, "raw_id_func": self._get_raw_place_from_id_data,
"del_func": self.remove_place, "del_func": self.remove_place,
}) },
self._tables['Repository'].update( 'Repository':
{ {
"handle_func": self.get_repository_from_handle, "handle_func": self.get_repository_from_handle,
"gramps_id_func": self.get_repository_from_gramps_id, "gramps_id_func": self.get_repository_from_gramps_id,
@ -582,8 +578,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
"raw_func": self._get_raw_repository_data, "raw_func": self._get_raw_repository_data,
"raw_id_func": self._get_raw_repository_from_id_data, "raw_id_func": self._get_raw_repository_from_id_data,
"del_func": self.remove_repository, "del_func": self.remove_repository,
}) },
self._tables['Note'].update( 'Note':
{ {
"handle_func": self.get_note_from_handle, "handle_func": self.get_note_from_handle,
"gramps_id_func": self.get_note_from_gramps_id, "gramps_id_func": self.get_note_from_gramps_id,
@ -600,8 +596,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
"raw_func": self._get_raw_note_data, "raw_func": self._get_raw_note_data,
"raw_id_func": self._get_raw_note_from_id_data, "raw_id_func": self._get_raw_note_from_id_data,
"del_func": self.remove_note, "del_func": self.remove_note,
}) },
self._tables['Tag'].update( 'Tag':
{ {
"handle_func": self.get_tag_from_handle, "handle_func": self.get_tag_from_handle,
"gramps_id_func": None, "gramps_id_func": None,
@ -615,7 +611,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
"count_func": self.get_number_of_tags, "count_func": self.get_number_of_tags,
"raw_func": self._get_raw_tag_data, "raw_func": self._get_raw_tag_data,
"del_func": self.remove_tag, "del_func": self.remove_tag,
}) }
}
self.set_save_path(directory) self.set_save_path(directory)
# skip GEDCOM cross-ref check for now: # skip GEDCOM cross-ref check for now:
self.set_feature("skip-check-xref", True) self.set_feature("skip-check-xref", True)
@ -725,6 +722,19 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
if directory: if directory:
self.load(directory) self.load(directory)
def get_table_func(self, table=None, func=None):
"""
Private implementation of get_table_func.
"""
if table is None:
return self.__tables.keys()
elif func is None:
return self.__tables[table] # dict of functions
elif func in self.__tables[table].keys():
return self.__tables[table][func]
else:
return super().get_table_func(table, func)
def load(self, directory, callback=None, mode=None, def load(self, directory, callback=None, mode=None,
force_schema_upgrade=False, force_schema_upgrade=False,
force_bsddb_upgrade=False, force_bsddb_upgrade=False,
@ -801,12 +811,12 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
def get_table_names(self): def get_table_names(self):
"""Return a list of valid table names.""" """Return a list of valid table names."""
return list(self._tables.keys()) return list(self.get_table_func())
def get_table_metadata(self, table_name): def get_table_metadata(self, table_name):
"""Return the metadata for a valid table name.""" """Return the metadata for a valid table name."""
if table_name in self._tables: if table_name in self.get_table_func():
return self._tables[table_name] return self.get_table_func(table_name)
return None return None
def transaction_begin(self, transaction): def transaction_begin(self, transaction):
@ -1154,7 +1164,7 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
Iterate over items in a class, possibly ordered by Iterate over items in a class, possibly ordered by
a list of field names and direction ("ASC" or "DESC"). a list of field names and direction ("ASC" or "DESC").
""" """
cursor = self._tables[class_.__name__]["cursor_func"] cursor = self.get_table_func(class_.__name__,"cursor_func")
if order_by is None: if order_by is None:
for data in cursor(): for data in cursor():
yield class_.create(data[1]) yield class_.create(data[1])
@ -1164,7 +1174,7 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
for data in cursor(): for data in cursor():
obj = class_.create(data[1]) obj = class_.create(data[1])
# just use values and handle to keep small: # just use values and handle to keep small:
sorted_items.append((self.eval_order_by(order_by, obj), obj.handle)) sorted_items.append((eval_order_by(order_by, obj, self), obj.handle))
# next we sort by fields and direction # next we sort by fields and direction
def getitem(item, pos): def getitem(item, pos):
sort_items = item[0] sort_items = item[0]
@ -1173,17 +1183,17 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
elif sort_items[pos] is None: elif sort_items[pos] is None:
return "" return ""
else: else:
# FIXME: should do something clever/recurive to # FIXME: should do something clever/recurive to
# sort these meaningfully, and return a string: # sort these meaningfully, and return a string:
return str(sort_items[pos]) return str(sort_items[pos])
pos = len(order_by) - 1 pos = len(order_by) - 1
for (field, order) in reversed(order_by): # sort the lasts parts first for (field, order) in reversed(order_by): # sort the lasts parts first
sorted_items.sort(key=lambda item: getitem(item, pos), sorted_items.sort(key=lambda item: getitem(item, pos),
reverse=(order=="DESC")) reverse=(order=="DESC"))
pos -= 1 pos -= 1
# now we will look them up again: # now we will look them up again:
for (order_by_values, handle) in sorted_items: for (order_by_values, handle) in sorted_items:
yield self._tables[class_.__name__]["handle_func"](handle) yield self.get_table_func(class_.__name__,"handle_func")(handle)
def iter_people(self, order_by=None): def iter_people(self, order_by=None):
return self.iter_items(order_by, Person) return self.iter_items(order_by, Person)
@ -1553,9 +1563,9 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
A (possibily) implementation-specific method to get data from A (possibily) implementation-specific method to get data from
db into this database. db into this database.
""" """
for key in db._tables.keys(): for key in db.get_table_func():
cursor = db._tables[key]["cursor_func"] cursor = db.get_table_func(key,"cursor_func")
class_ = db._tables[key]["class_func"] class_ = db.get_table_func(key,"class_func")
for (handle, data) in cursor(): for (handle, data) in cursor():
map = getattr(self, "%s_map" % key.lower()) map = getattr(self, "%s_map" % key.lower())
if isinstance(handle, bytes): if isinstance(handle, bytes):
@ -1578,8 +1588,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
>>> self.get_from_name_and_handle("Person", "a7ad62365bc652387008") >>> self.get_from_name_and_handle("Person", "a7ad62365bc652387008")
>>> self.get_from_name_and_handle("Media", "c3434653675bcd736f23") >>> self.get_from_name_and_handle("Media", "c3434653675bcd736f23")
""" """
if table_name in self._tables: if table_name in self.get_table_func():
return self._tables[table_name]["handle_func"](handle) return self.get_table_func(table_name,"handle_func")(handle)
return None return None
def get_from_name_and_gramps_id(self, table_name, gramps_id): def get_from_name_and_gramps_id(self, table_name, gramps_id):
@ -1593,8 +1603,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
>>> self.get_from_name_and_gramps_id("Family", "F056") >>> self.get_from_name_and_gramps_id("Family", "F056")
>>> self.get_from_name_and_gramps_id("Media", "M00012") >>> self.get_from_name_and_gramps_id("Media", "M00012")
""" """
if table_name in self._tables: if table_name in self.get_table_func():
return self._tables[table_name]["gramps_id_func"](gramps_id) return self.get_table_func(table_name,"gramps_id_func")(gramps_id)
return None return None
def remove_source(self, handle, transaction): def remove_source(self, handle, transaction):
@ -1673,8 +1683,8 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
""" """
Return true if there are no [primary] records in the database Return true if there are no [primary] records in the database
""" """
for table in self._tables: for table in self.get_table_func():
if len(self._tables[table]["handles_func"]()) > 0: if len(self.get_table_func(table,"handles_func")()) > 0:
return False return False
return True return True
@ -1719,8 +1729,9 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
self.set_metadata('place_types', self.place_types) self.set_metadata('place_types', self.place_types)
# Save misc items: # Save misc items:
self.save_surname_list() if self.has_changed:
self.save_gender_stats(self.genderStats) self.save_surname_list()
self.save_gender_stats(self.genderStats)
# Indexes: # Indexes:
self.set_metadata('cmap_index', self.cmap_index) self.set_metadata('cmap_index', self.cmap_index)
@ -1735,6 +1746,7 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
self.close_backend() self.close_backend()
self.db_is_open = False self.db_is_open = False
self._directory = None
def get_bookmarks(self): def get_bookmarks(self):
return self.bookmarks return self.bookmarks
@ -2074,3 +2086,10 @@ class DbGeneric(DbWriteBase, DbReadBase, UpdateCallback, Callback):
def set_default_person_handle(self, handle): def set_default_person_handle(self, handle):
self.set_metadata("default-person-handle", handle) self.set_metadata("default-person-handle", handle)
self.emit('home-person-changed') self.emit('home-person-changed')
def add_table_funcs(self, table, funcs):
"""
Add a new table and funcs to the database.
"""
self.__tables[table] = funcs
setattr(DbGeneric, table, property(lambda self: QuerySet(self, table)))

View File

@ -23,7 +23,7 @@ class HandleClass(str):
super(HandleClass, self).__init__() super(HandleClass, self).__init__()
def join(self, database, handle): def join(self, database, handle):
return database._tables[self.classname]["handle_func"](handle) return database.get_table_func(self.classname,"handle_func")(handle)
@classmethod @classmethod
def get_schema(cls): def get_schema(cls):

View File

@ -297,20 +297,20 @@ class Struct(object):
name, handle = struct["_class"], struct["handle"] name, handle = struct["_class"], struct["handle"]
old_obj = self.db.get_from_name_and_handle(name, handle) old_obj = self.db.get_from_name_and_handle(name, handle)
if old_obj: if old_obj:
commit_func = self.db._tables[name]["commit_func"] commit_func = self.db.get_table_func(name,"commit_func")
commit_func(new_obj, trans) commit_func(new_obj, trans)
else: else:
add_func = self.db._tables[name]["add_func"] add_func = self.db.get_table_func(name,"add_func")
add_func(new_obj, trans) add_func(new_obj, trans)
else: else:
new_obj = Struct.instance_from_struct(struct) new_obj = Struct.instance_from_struct(struct)
name, handle = struct["_class"], struct["handle"] name, handle = struct["_class"], struct["handle"]
old_obj = self.db.get_from_name_and_handle(name, handle) old_obj = self.db.get_from_name_and_handle(name, handle)
if old_obj: if old_obj:
commit_func = self.db._tables[name]["commit_func"] commit_func = self.db.get_table_func(name,"commit_func")
commit_func(new_obj, trans) commit_func(new_obj, trans)
else: else:
add_func = self.db._tables[name]["add_func"] add_func = self.db.get_table_func(name,"add_func")
add_func(new_obj, trans) add_func(new_obj, trans)
def from_struct(self): def from_struct(self):

View File

@ -109,9 +109,9 @@ def generate_case(obj):
#setattr(DatabaseCheck, name, test2) #setattr(DatabaseCheck, name, test2)
db = import_as_dict("example/gramps/example.gramps", User()) db = import_as_dict("example/gramps/example.gramps", User())
for table in db._tables.keys(): for table in db.get_table_func():
for handle in db._tables[table]["handles_func"](): for handle in db.get_table_func(table,"handles_func")():
obj = db._tables[table]["handle_func"](handle) obj = db.get_table_func(table,"handle_func")(handle)
generate_case(obj) generate_case(obj)
class StructTest(unittest.TestCase): class StructTest(unittest.TestCase):

View File

@ -190,14 +190,14 @@ def diff_dbs(db1, db2, user=None):
for item in ['Person', 'Family', 'Source', 'Citation', 'Event', 'Media', for item in ['Person', 'Family', 'Source', 'Citation', 'Event', 'Media',
'Place', 'Repository', 'Note', 'Tag']: 'Place', 'Repository', 'Note', 'Tag']:
step() step()
handles1 = sorted([handle for handle in db1._tables[item]["handles_func"]()]) handles1 = sorted([handle for handle in db1.get_table_func(item,"handles_func")()])
handles2 = sorted([handle for handle in db2._tables[item]["handles_func"]()]) handles2 = sorted([handle for handle in db2.get_table_func(item,"handles_func")()])
p1 = 0 p1 = 0
p2 = 0 p2 = 0
while p1 < len(handles1) and p2 < len(handles2): while p1 < len(handles1) and p2 < len(handles2):
if handles1[p1] == handles2[p2]: # in both if handles1[p1] == handles2[p2]: # in both
item1 = db1._tables[item]["handle_func"](handles1[p1]) item1 = db1.get_table_func(item,"handle_func")(handles1[p1])
item2 = db2._tables[item]["handle_func"](handles2[p2]) item2 = db2.get_tables_func(item,"handle_func")(handles2[p2])
diff = diff_items(item, item1.to_struct(), item2.to_struct()) diff = diff_items(item, item1.to_struct(), item2.to_struct())
if diff: if diff:
diffs += [(item, item1, item2)] diffs += [(item, item1, item2)]
@ -205,19 +205,19 @@ def diff_dbs(db1, db2, user=None):
p1 += 1 p1 += 1
p2 += 1 p2 += 1
elif handles1[p1] < handles2[p2]: # p1 is mssing in p2 elif handles1[p1] < handles2[p2]: # p1 is mssing in p2
item1 = db1._tables[item]["handle_func"](handles1[p1]) item1 = db1.get_table_func(item,"handle_func")(handles1[p1])
missing_from_new += [(item, item1)] missing_from_new += [(item, item1)]
p1 += 1 p1 += 1
elif handles1[p1] > handles2[p2]: # p2 is mssing in p1 elif handles1[p1] > handles2[p2]: # p2 is mssing in p1
item2 = db2._tables[item]["handle_func"](handles2[p2]) item2 = db2.get_table_func(item,"handle_func")(handles2[p2])
missing_from_old += [(item, item2)] missing_from_old += [(item, item2)]
p2 += 1 p2 += 1
while p1 < len(handles1): while p1 < len(handles1):
item1 = db1._tables[item]["handle_func"](handles1[p1]) item1 = db1.get_table_func(item,"handle_func")(handles1[p1])
missing_from_new += [(item, item1)] missing_from_new += [(item, item1)]
p1 += 1 p1 += 1
while p2 < len(handles2): while p2 < len(handles2):
item2 = db2._tables[item]["handle_func"](handles2[p2]) item2 = db2.get_table_func(item,"handle_func")(handles2[p2])
missing_from_old += [(item, item2)] missing_from_old += [(item, item2)]
p2 += 1 p2 += 1
return diffs, missing_from_old, missing_from_new return diffs, missing_from_old, missing_from_new

View File

@ -123,3 +123,4 @@ class EnumeratedListOption(Option):
else: else:
logging.warning(_("Value '%(val)s' not found for option '%(opt)s'") % logging.warning(_("Value '%(val)s' not found for option '%(opt)s'") %
{'val' : str(value), 'opt' : self.get_label()}) {'val' : str(value), 'opt' : self.get_label()})
logging.warning(_("Valid values: ") + str(self.__items))

View File

@ -30,7 +30,10 @@ Proxy class for the Gramps databases. Apply filter
# Gramps libraries # Gramps libraries
# #
#------------------------------------------------------------------------- #-------------------------------------------------------------------------
from gramps.gen.db.base import sort_objects
from .proxybase import ProxyDbBase from .proxybase import ProxyDbBase
from ..lib import (Date, Person, Name, Surname, NameOriginType, Family, Source,
Citation, Event, Media, Place, Repository, Note, Tag)
class FilterProxyDb(ProxyDbBase): class FilterProxyDb(ProxyDbBase):
""" """
@ -70,6 +73,121 @@ class FilterProxyDb(ProxyDbBase):
if person: if person:
self.flist.update(person.get_family_handle_list()) self.flist.update(person.get_family_handle_list())
self.flist.update(person.get_parent_family_handle_list()) self.flist.update(person.get_parent_family_handle_list())
self.__tables = {
'Person':
{
"handle_func": self.get_person_from_handle,
"gramps_id_func": self.get_person_from_gramps_id,
"class_func": Person,
"cursor_func": self.get_person_cursor,
"handles_func": self.get_person_handles,
"iter_func": self.iter_people,
"count_func": self.get_number_of_people,
},
'Family':
{
"handle_func": self.get_family_from_handle,
"gramps_id_func": self.get_family_from_gramps_id,
"class_func": Family,
"cursor_func": self.get_family_cursor,
"handles_func": self.get_family_handles,
"iter_func": self.iter_families,
"count_func": self.get_number_of_families,
},
'Source':
{
"handle_func": self.get_source_from_handle,
"gramps_id_func": self.get_source_from_gramps_id,
"class_func": Source,
"cursor_func": self.get_source_cursor,
"handles_func": self.get_source_handles,
"iter_func": self.iter_sources,
"count_func": self.get_number_of_sources,
},
'Citation':
{
"handle_func": self.get_citation_from_handle,
"gramps_id_func": self.get_citation_from_gramps_id,
"class_func": Citation,
"cursor_func": self.get_citation_cursor,
"handles_func": self.get_citation_handles,
"iter_func": self.iter_citations,
"count_func": self.get_number_of_citations,
},
'Event':
{
"handle_func": self.get_event_from_handle,
"gramps_id_func": self.get_event_from_gramps_id,
"class_func": Event,
"cursor_func": self.get_event_cursor,
"handles_func": self.get_event_handles,
"iter_func": self.iter_events,
"count_func": self.get_number_of_events,
},
'Media':
{
"handle_func": self.get_media_from_handle,
"gramps_id_func": self.get_media_from_gramps_id,
"class_func": Media,
"cursor_func": self.get_media_cursor,
"handles_func": self.get_media_handles,
"iter_func": self.iter_media,
"count_func": self.get_number_of_media,
},
'Place':
{
"handle_func": self.get_place_from_handle,
"gramps_id_func": self.get_place_from_gramps_id,
"class_func": Place,
"cursor_func": self.get_place_cursor,
"handles_func": self.get_place_handles,
"iter_func": self.iter_places,
"count_func": self.get_number_of_places,
},
'Repository':
{
"handle_func": self.get_repository_from_handle,
"gramps_id_func": self.get_repository_from_gramps_id,
"class_func": Repository,
"cursor_func": self.get_repository_cursor,
"handles_func": self.get_repository_handles,
"iter_func": self.iter_repositories,
"count_func": self.get_number_of_repositories,
},
'Note':
{
"handle_func": self.get_note_from_handle,
"gramps_id_func": self.get_note_from_gramps_id,
"class_func": Note,
"cursor_func": self.get_note_cursor,
"handles_func": self.get_note_handles,
"iter_func": self.iter_notes,
"count_func": self.get_number_of_notes,
},
'Tag':
{
"handle_func": self.get_tag_from_handle,
"gramps_id_func": None,
"class_func": Tag,
"cursor_func": self.get_tag_cursor,
"handles_func": self.get_tag_handles,
"iter_func": self.iter_tags,
"count_func": self.get_number_of_tags,
}
}
def get_table_func(self, table=None, func=None):
"""
Private implementation of get_table_func.
"""
if table is None:
return self.__tables.keys()
elif func is None:
return self.__tables[table].keys()
elif func in self.__tables[table].keys():
return self.__tables[table][func]
else:
return super().get_table_func(table, func)
def get_person_from_handle(self, handle): def get_person_from_handle(self, handle):
""" """
@ -398,11 +516,14 @@ class FilterProxyDb(ProxyDbBase):
""" """
return self.plist return self.plist
def iter_people(self): def iter_people(self, order_by=None):
""" """
Return an iterator over objects for Persons in the database Return an iterator over objects for Persons in the database
""" """
return map(self.get_person_from_handle, self.plist) if order_by:
return sort_objects(map(self.get_person_from_handle, self.plist), order_by, self)
else:
return map(self.get_person_from_handle, self.plist)
def get_event_handles(self): def get_event_handles(self):
""" """
@ -418,11 +539,14 @@ class FilterProxyDb(ProxyDbBase):
""" """
return self.elist return self.elist
def iter_events(self): def iter_events(self, order_by=None):
""" """
Return an iterator over objects for Events in the database Return an iterator over objects for Events in the database
""" """
return map(self.get_event_from_handle, self.elist) if order_by:
return sort_objects(map(self.get_event_from_handle, self.elist), order_by, self)
else:
return map(self.get_event_from_handle, self.elist)
def get_family_handles(self): def get_family_handles(self):
""" """
@ -438,11 +562,14 @@ class FilterProxyDb(ProxyDbBase):
""" """
return self.flist return self.flist
def iter_families(self): def iter_families(self, order_by=None):
""" """
Return an iterator over objects for Families in the database Return an iterator over objects for Families in the database
""" """
return map(self.get_family_from_handle, self.flist) if order_by:
return sort_objects(map(self.get_family_from_handle, self.flist), order_by, self)
else:
return map(self.get_family_from_handle, self.flist)
def get_note_handles(self): def get_note_handles(self):
""" """
@ -458,11 +585,14 @@ class FilterProxyDb(ProxyDbBase):
""" """
return self.nlist return self.nlist
def iter_notes(self): def iter_notes(self, order_by=None):
""" """
Return an iterator over objects for Notes in the database Return an iterator over objects for Notes in the database
""" """
return map(self.get_note_from_handle, self.nlist) if order_by:
return sort_objects(map(self.get_note_from_handle, self.nlist), order_by, self)
else:
return map(self.get_note_from_handle, self.nlist)
def get_default_person(self): def get_default_person(self):
"""returns the default Person of the database""" """returns the default Person of the database"""

View File

@ -32,6 +32,7 @@ from ..lib import (Date, Person, Name, Surname, NameOriginType, Family, Source,
Citation, Event, Media, Place, Repository, Note, Tag) Citation, Event, Media, Place, Repository, Note, Tag)
from ..utils.alive import probably_alive from ..utils.alive import probably_alive
from ..config import config from ..config import config
from gramps.gen.db.base import sort_objects
#------------------------------------------------------------------------- #-------------------------------------------------------------------------
# #
@ -77,6 +78,121 @@ class LivingProxyDb(ProxyDbBase):
else: else:
self.current_date = None self.current_date = None
self.years_after_death = years_after_death self.years_after_death = years_after_death
self.__tables = {
'Person':
{
"handle_func": self.get_person_from_handle,
"gramps_id_func": self.get_person_from_gramps_id,
"class_func": Person,
"cursor_func": self.get_person_cursor,
"handles_func": self.get_person_handles,
"iter_func": self.iter_people,
"count_func": self.get_number_of_people,
},
'Family':
{
"handle_func": self.get_family_from_handle,
"gramps_id_func": self.get_family_from_gramps_id,
"class_func": Family,
"cursor_func": self.get_family_cursor,
"handles_func": self.get_family_handles,
"iter_func": self.iter_families,
"count_func": self.get_number_of_families,
},
'Source':
{
"handle_func": self.get_source_from_handle,
"gramps_id_func": self.get_source_from_gramps_id,
"class_func": Source,
"cursor_func": self.get_source_cursor,
"handles_func": self.get_source_handles,
"iter_func": self.iter_sources,
"count_func": self.get_number_of_sources,
},
'Citation':
{
"handle_func": self.get_citation_from_handle,
"gramps_id_func": self.get_citation_from_gramps_id,
"class_func": Citation,
"cursor_func": self.get_citation_cursor,
"handles_func": self.get_citation_handles,
"iter_func": self.iter_citations,
"count_func": self.get_number_of_citations,
},
'Event':
{
"handle_func": self.get_event_from_handle,
"gramps_id_func": self.get_event_from_gramps_id,
"class_func": Event,
"cursor_func": self.get_event_cursor,
"handles_func": self.get_event_handles,
"iter_func": self.iter_events,
"count_func": self.get_number_of_events,
},
'Media':
{
"handle_func": self.get_media_from_handle,
"gramps_id_func": self.get_media_from_gramps_id,
"class_func": Media,
"cursor_func": self.get_media_cursor,
"handles_func": self.get_media_handles,
"iter_func": self.iter_media,
"count_func": self.get_number_of_media,
},
'Place':
{
"handle_func": self.get_place_from_handle,
"gramps_id_func": self.get_place_from_gramps_id,
"class_func": Place,
"cursor_func": self.get_place_cursor,
"handles_func": self.get_place_handles,
"iter_func": self.iter_places,
"count_func": self.get_number_of_places,
},
'Repository':
{
"handle_func": self.get_repository_from_handle,
"gramps_id_func": self.get_repository_from_gramps_id,
"class_func": Repository,
"cursor_func": self.get_repository_cursor,
"handles_func": self.get_repository_handles,
"iter_func": self.iter_repositories,
"count_func": self.get_number_of_repositories,
},
'Note':
{
"handle_func": self.get_note_from_handle,
"gramps_id_func": self.get_note_from_gramps_id,
"class_func": Note,
"cursor_func": self.get_note_cursor,
"handles_func": self.get_note_handles,
"iter_func": self.iter_notes,
"count_func": self.get_number_of_notes,
},
'Tag':
{
"handle_func": self.get_tag_from_handle,
"gramps_id_func": None,
"class_func": Tag,
"cursor_func": self.get_tag_cursor,
"handles_func": self.get_tag_handles,
"iter_func": self.iter_tags,
"count_func": self.get_number_of_tags,
}
}
def get_table_func(self, table=None, func=None):
"""
Private implementation of get_table_func.
"""
if table is None:
return self.__tables.keys()
elif func is None:
return self.__tables[table].keys()
elif func in self.__tables[table].keys():
return self.__tables[table][func]
else:
return super().get_table_func(table, func)
def get_person_from_handle(self, handle): def get_person_from_handle(self, handle):
""" """
@ -100,18 +216,32 @@ class LivingProxyDb(ProxyDbBase):
family = self.__remove_living_from_family(family) family = self.__remove_living_from_family(family)
return family return family
def iter_people(self): def iter_people(self, order_by=None):
""" """
Protected version of iter_people Protected version of iter_people
""" """
for person in filter(None, self.db.iter_people()): if order_by:
if self.__is_living(person): retval = []
if self.mode == self.MODE_EXCLUDE_ALL: for person in filter(None, self.db.iter_people()):
continue if self.__is_living(person):
if self.mode == self.MODE_EXCLUDE_ALL:
continue
else:
retval.append(self.__restrict_person(person))
else: else:
yield self.__restrict_person(person) retval.append(person)
else: retval = sort_objects(retval, order_by, self)
yield person for item in retval:
yield item
else:
for person in filter(None, self.db.iter_people()):
if self.__is_living(person):
if self.mode == self.MODE_EXCLUDE_ALL:
continue
else:
yield self.__restrict_person(person)
else:
yield person
def get_person_from_gramps_id(self, val): def get_person_from_gramps_id(self, val):
""" """

View File

@ -56,6 +56,121 @@ class PrivateProxyDb(ProxyDbBase):
Create a new PrivateProxyDb instance. Create a new PrivateProxyDb instance.
""" """
ProxyDbBase.__init__(self, db) ProxyDbBase.__init__(self, db)
self.__tables = {
'Person':
{
"handle_func": self.get_person_from_handle,
"gramps_id_func": self.get_person_from_gramps_id,
"class_func": Person,
"cursor_func": self.get_person_cursor,
"handles_func": self.get_person_handles,
"iter_func": self.iter_people,
"count_func": self.get_number_of_people,
},
'Family':
{
"handle_func": self.get_family_from_handle,
"gramps_id_func": self.get_family_from_gramps_id,
"class_func": Family,
"cursor_func": self.get_family_cursor,
"handles_func": self.get_family_handles,
"iter_func": self.iter_families,
"count_func": self.get_number_of_families,
},
'Source':
{
"handle_func": self.get_source_from_handle,
"gramps_id_func": self.get_source_from_gramps_id,
"class_func": Source,
"cursor_func": self.get_source_cursor,
"handles_func": self.get_source_handles,
"iter_func": self.iter_sources,
"count_func": self.get_number_of_sources,
},
'Citation':
{
"handle_func": self.get_citation_from_handle,
"gramps_id_func": self.get_citation_from_gramps_id,
"class_func": Citation,
"cursor_func": self.get_citation_cursor,
"handles_func": self.get_citation_handles,
"iter_func": self.iter_citations,
"count_func": self.get_number_of_citations,
},
'Event':
{
"handle_func": self.get_event_from_handle,
"gramps_id_func": self.get_event_from_gramps_id,
"class_func": Event,
"cursor_func": self.get_event_cursor,
"handles_func": self.get_event_handles,
"iter_func": self.iter_events,
"count_func": self.get_number_of_events,
},
'Media':
{
"handle_func": self.get_media_from_handle,
"gramps_id_func": self.get_media_from_gramps_id,
"class_func": Media,
"cursor_func": self.get_media_cursor,
"handles_func": self.get_media_handles,
"iter_func": self.iter_media,
"count_func": self.get_number_of_media,
},
'Place':
{
"handle_func": self.get_place_from_handle,
"gramps_id_func": self.get_place_from_gramps_id,
"class_func": Place,
"cursor_func": self.get_place_cursor,
"handles_func": self.get_place_handles,
"iter_func": self.iter_places,
"count_func": self.get_number_of_places,
},
'Repository':
{
"handle_func": self.get_repository_from_handle,
"gramps_id_func": self.get_repository_from_gramps_id,
"class_func": Repository,
"cursor_func": self.get_repository_cursor,
"handles_func": self.get_repository_handles,
"iter_func": self.iter_repositories,
"count_func": self.get_number_of_repositories,
},
'Note':
{
"handle_func": self.get_note_from_handle,
"gramps_id_func": self.get_note_from_gramps_id,
"class_func": Note,
"cursor_func": self.get_note_cursor,
"handles_func": self.get_note_handles,
"iter_func": self.iter_notes,
"count_func": self.get_number_of_notes,
},
'Tag':
{
"handle_func": self.get_tag_from_handle,
"gramps_id_func": None,
"class_func": Tag,
"cursor_func": self.get_tag_cursor,
"handles_func": self.get_tag_handles,
"iter_func": self.iter_tags,
"count_func": self.get_number_of_tags,
}
}
def get_table_func(self, table=None, func=None):
"""
Private implementation of get_table_func.
"""
if table is None:
return self.__tables.keys()
elif func is None:
return self.__tables[table].keys()
elif func in self.__tables[table].keys():
return self.__tables[table][func]
else:
return super().get_table_func(table, func)
def get_person_from_handle(self, handle): def get_person_from_handle(self, handle):
""" """
@ -285,7 +400,7 @@ class PrivateProxyDb(ProxyDbBase):
""" """
Predicate returning True if object is to be included, else False Predicate returning True if object is to be included, else False
""" """
obj = self.get_unfiltered_object(handle) obj = self.get_unfiltered_media(handle)
return obj and not obj.get_privacy() return obj and not obj.get_privacy()
def include_repository(self, handle): def include_repository(self, handle):

View File

@ -35,7 +35,11 @@ import types
# Gramps libraries # Gramps libraries
# #
#------------------------------------------------------------------------- #-------------------------------------------------------------------------
from ..db.base import DbReadBase, DbWriteBase from ..db.base import DbReadBase, DbWriteBase, sort_objects
from ..lib import (MediaRef, Attribute, Address, EventRef,
Person, Name, Source, RepoRef, Media, Place, Event,
Family, ChildRef, Repository, LdsOrd, Surname, Citation,
SrcAttribute, Note, Tag)
class ProxyCursor(object): class ProxyCursor(object):
""" """
@ -120,6 +124,122 @@ class ProxyDbBase(DbReadBase):
self.note_map = ProxyMap(self, self.get_raw_note_data, self.note_map = ProxyMap(self, self.get_raw_note_data,
self.get_note_handles) self.get_note_handles)
self.__tables = {
'Person':
{
"handle_func": self.get_person_from_handle,
"gramps_id_func": self.get_person_from_gramps_id,
"class_func": Person,
"cursor_func": self.get_person_cursor,
"handles_func": self.get_person_handles,
"iter_func": self.iter_people,
"count_func": self.get_number_of_people,
},
'Family':
{
"handle_func": self.get_family_from_handle,
"gramps_id_func": self.get_family_from_gramps_id,
"class_func": Family,
"cursor_func": self.get_family_cursor,
"handles_func": self.get_family_handles,
"iter_func": self.iter_families,
"count_func": self.get_number_of_families,
},
'Source':
{
"handle_func": self.get_source_from_handle,
"gramps_id_func": self.get_source_from_gramps_id,
"class_func": Source,
"cursor_func": self.get_source_cursor,
"handles_func": self.get_source_handles,
"iter_func": self.iter_sources,
"count_func": self.get_number_of_sources,
},
'Citation':
{
"handle_func": self.get_citation_from_handle,
"gramps_id_func": self.get_citation_from_gramps_id,
"class_func": Citation,
"cursor_func": self.get_citation_cursor,
"handles_func": self.get_citation_handles,
"iter_func": self.iter_citations,
"count_func": self.get_number_of_citations,
},
'Event':
{
"handle_func": self.get_event_from_handle,
"gramps_id_func": self.get_event_from_gramps_id,
"class_func": Event,
"cursor_func": self.get_event_cursor,
"handles_func": self.get_event_handles,
"iter_func": self.iter_events,
"count_func": self.get_number_of_events,
},
'Media':
{
"handle_func": self.get_media_from_handle,
"gramps_id_func": self.get_media_from_gramps_id,
"class_func": Media,
"cursor_func": self.get_media_cursor,
"handles_func": self.get_media_handles,
"iter_func": self.iter_media,
"count_func": self.get_number_of_media,
},
'Place':
{
"handle_func": self.get_place_from_handle,
"gramps_id_func": self.get_place_from_gramps_id,
"class_func": Place,
"cursor_func": self.get_place_cursor,
"handles_func": self.get_place_handles,
"iter_func": self.iter_places,
"count_func": self.get_number_of_places,
},
'Repository':
{
"handle_func": self.get_repository_from_handle,
"gramps_id_func": self.get_repository_from_gramps_id,
"class_func": Repository,
"cursor_func": self.get_repository_cursor,
"handles_func": self.get_repository_handles,
"iter_func": self.iter_repositories,
"count_func": self.get_number_of_repositories,
},
'Note':
{
"handle_func": self.get_note_from_handle,
"gramps_id_func": self.get_note_from_gramps_id,
"class_func": Note,
"cursor_func": self.get_note_cursor,
"handles_func": self.get_note_handles,
"iter_func": self.iter_notes,
"count_func": self.get_number_of_notes,
},
'Tag':
{
"handle_func": self.get_tag_from_handle,
"gramps_id_func": None,
"class_func": Tag,
"cursor_func": self.get_tag_cursor,
"handles_func": self.get_tag_handles,
"iter_func": self.iter_tags,
"count_func": self.get_number_of_tags,
}
}
def get_table_func(self, table=None, func=None):
"""
Private implementation of get_table_func.
"""
if table is None:
return self.__tables.keys()
elif func is None:
return self.__tables[table].keys()
elif func in self.__tables[table].keys():
return self.__tables[table][func]
else:
return super().get_table_func(table, func)
def is_open(self): def is_open(self):
""" """
Return 1 if the database has been opened. Return 1 if the database has been opened.
@ -374,73 +494,76 @@ class ProxyDbBase(DbReadBase):
""" """
return filter(self.include_tag, self.db.iter_tag_handles()) return filter(self.include_tag, self.db.iter_tag_handles())
@staticmethod def __iter_object(self, selector, method, order_by=None):
def __iter_object(selector, method):
""" Helper function to return an iterator over an object class """ """ Helper function to return an iterator over an object class """
return filter(lambda obj: ((selector is None) or selector(obj.handle)), retval = filter(lambda obj: ((selector is None) or selector(obj.handle)),
method()) method())
if order_by:
return sort_objects([item for item in retval], order_by, self)
else:
return retval
def iter_people(self): def iter_people(self, order_by=None):
""" """
Return an iterator over Person objects in the database Return an iterator over Person objects in the database
""" """
return self.__iter_object(self.include_person, self.db.iter_people) return self.__iter_object(self.include_person, self.db.iter_people, order_by)
def iter_families(self): def iter_families(self, order_by=None):
""" """
Return an iterator over Family objects in the database Return an iterator over Family objects in the database
""" """
return self.__iter_object(self.include_family, self.db.iter_families) return self.__iter_object(self.include_family, self.db.iter_families, order_by)
def iter_events(self): def iter_events(self, order_by=None):
""" """
Return an iterator over Event objects in the database Return an iterator over Event objects in the database
""" """
return self.__iter_object(self.include_event, self.db.iter_events) return self.__iter_object(self.include_event, self.db.iter_events, order_by)
def iter_places(self): def iter_places(self, order_by=None):
""" """
Return an iterator over Place objects in the database Return an iterator over Place objects in the database
""" """
return self.__iter_object(self.include_place, self.db.iter_places) return self.__iter_object(self.include_place, self.db.iter_places, order_by)
def iter_sources(self): def iter_sources(self, order_by=None):
""" """
Return an iterator over Source objects in the database Return an iterator over Source objects in the database
""" """
return self.__iter_object(self.include_source, self.db.iter_sources) return self.__iter_object(self.include_source, self.db.iter_sources, order_by)
def iter_citations(self): def iter_citations(self, order_by=None):
""" """
Return an iterator over Citation objects in the database Return an iterator over Citation objects in the database
""" """
return self.__iter_object(self.include_citation, self.db.iter_citations) return self.__iter_object(self.include_citation, self.db.iter_citations, order_by)
def iter_media(self): def iter_media(self, order_by=None):
""" """
Return an iterator over Media objects in the database Return an iterator over Media objects in the database
""" """
return self.__iter_object(self.include_media, return self.__iter_object(self.include_media,
self.db.iter_media) self.db.iter_media, order_by)
def iter_repositories(self): def iter_repositories(self, order_by=None):
""" """
Return an iterator over Repositories objects in the database Return an iterator over Repositories objects in the database
""" """
return self.__iter_object(self.include_repository, return self.__iter_object(self.include_repository,
self.db.iter_repositories) self.db.iter_repositories, order_by)
def iter_notes(self): def iter_notes(self, order_by=None):
""" """
Return an iterator over Note objects in the database Return an iterator over Note objects in the database
""" """
return self.__iter_object(self.include_note, self.db.iter_notes) return self.__iter_object(self.include_note, self.db.iter_notes, order_by)
def iter_tags(self): def iter_tags(self, order_by=None):
""" """
Return an iterator over Tag objects in the database Return an iterator over Tag objects in the database
""" """
return self.__iter_object(self.include_tag, self.db.iter_tags) return self.__iter_object(self.include_tag, self.db.iter_tags, order_by)
@staticmethod @staticmethod
def gfilter(predicate, obj): def gfilter(predicate, obj):
@ -468,9 +591,12 @@ class ProxyDbBase(DbReadBase):
return attr return attr
# if a write-method: # if a write-method:
if (name in DbWriteBase.__dict__ and if ((name in DbWriteBase.__dict__ and
not name.startswith("__") and not name.startswith("__") and
type(DbWriteBase.__dict__[name]) is types.FunctionType): type(DbWriteBase.__dict__[name]) is types.FunctionType) or
(name in DbWriteBase.__dict__ and
not name.startswith("__") and
type(DbWriteBase.__dict__[name]) is types.FunctionType)):
raise AttributeError raise AttributeError
# Default behaviour: lookup attribute in parent object # Default behaviour: lookup attribute in parent object
return getattr(self.db, name) return getattr(self.db, name)

View File

@ -71,7 +71,7 @@ class ReferencedBySelectionProxyDb(ProxyDbBase):
# get rid of orphaned people: # get rid of orphaned people:
# first, get all of the links from people: # first, get all of the links from people:
for person in self.db.iter_people(): for person in self.db.iter_people():
self.queue_object("Person", person, False) self.queue_object("Person", person.handle, False)
# save those people: # save those people:
self.restricted_to["Person"] = self.referenced["Person"] self.restricted_to["Person"] = self.referenced["Person"]
# reset, and just follow those people # reset, and just follow those people
@ -83,6 +83,152 @@ class ReferencedBySelectionProxyDb(ProxyDbBase):
obj_type, handle, reference = self.queue.pop() obj_type, handle, reference = self.queue.pop()
self.process_object(obj_type, handle, reference) self.process_object(obj_type, handle, reference)
self.__tables = {
'Person':
{
"handle_func": self.get_person_from_handle,
"gramps_id_func": self.get_person_from_gramps_id,
"class_func": Person,
"cursor_func": self.get_person_cursor,
"handles_func": self.get_person_handles,
"add_func": self.add_person,
"commit_func": self.commit_person,
"iter_func": self.iter_people,
"count_func": self.get_number_of_people,
"del_func": self.remove_person,
},
'Family':
{
"handle_func": self.get_family_from_handle,
"gramps_id_func": self.get_family_from_gramps_id,
"class_func": Family,
"cursor_func": self.get_family_cursor,
"handles_func": self.get_family_handles,
"add_func": self.add_family,
"commit_func": self.commit_family,
"iter_func": self.iter_families,
"count_func": self.get_number_of_families,
"del_func": self.remove_family,
},
'Source':
{
"handle_func": self.get_source_from_handle,
"gramps_id_func": self.get_source_from_gramps_id,
"class_func": Source,
"cursor_func": self.get_source_cursor,
"handles_func": self.get_source_handles,
"add_func": self.add_source,
"commit_func": self.commit_source,
"iter_func": self.iter_sources,
"count_func": self.get_number_of_sources,
"del_func": self.remove_source,
},
'Citation':
{
"handle_func": self.get_citation_from_handle,
"gramps_id_func": self.get_citation_from_gramps_id,
"class_func": Citation,
"cursor_func": self.get_citation_cursor,
"handles_func": self.get_citation_handles,
"add_func": self.add_citation,
"commit_func": self.commit_citation,
"iter_func": self.iter_citations,
"count_func": self.get_number_of_citations,
"del_func": self.remove_citation,
},
'Event':
{
"handle_func": self.get_event_from_handle,
"gramps_id_func": self.get_event_from_gramps_id,
"class_func": Event,
"cursor_func": self.get_event_cursor,
"handles_func": self.get_event_handles,
"add_func": self.add_event,
"commit_func": self.commit_event,
"iter_func": self.iter_events,
"count_func": self.get_number_of_events,
"del_func": self.remove_event,
},
'Media':
{
"handle_func": self.get_media_from_handle,
"gramps_id_func": self.get_media_from_gramps_id,
"class_func": Media,
"cursor_func": self.get_media_cursor,
"handles_func": self.get_media_handles,
"add_func": self.add_media,
"commit_func": self.commit_media,
"iter_func": self.iter_media,
"count_func": self.get_number_of_media,
"del_func": self.remove_media,
},
'Place':
{
"handle_func": self.get_place_from_handle,
"gramps_id_func": self.get_place_from_gramps_id,
"class_func": Place,
"cursor_func": self.get_place_cursor,
"handles_func": self.get_place_handles,
"add_func": self.add_place,
"commit_func": self.commit_place,
"iter_func": self.iter_places,
"count_func": self.get_number_of_places,
"del_func": self.remove_place,
},
'Repository':
{
"handle_func": self.get_repository_from_handle,
"gramps_id_func": self.get_repository_from_gramps_id,
"class_func": Repository,
"cursor_func": self.get_repository_cursor,
"handles_func": self.get_repository_handles,
"add_func": self.add_repository,
"commit_func": self.commit_repository,
"iter_func": self.iter_repositories,
"count_func": self.get_number_of_repositories,
"del_func": self.remove_repository,
},
'Note':
{
"handle_func": self.get_note_from_handle,
"gramps_id_func": self.get_note_from_gramps_id,
"class_func": Note,
"cursor_func": self.get_note_cursor,
"handles_func": self.get_note_handles,
"add_func": self.add_note,
"commit_func": self.commit_note,
"iter_func": self.iter_notes,
"count_func": self.get_number_of_notes,
"del_func": self.remove_note,
},
'Tag':
{
"handle_func": self.get_tag_from_handle,
"gramps_id_func": None,
"class_func": Tag,
"cursor_func": self.get_tag_cursor,
"handles_func": self.get_tag_handles,
"add_func": self.add_tag,
"commit_func": self.commit_tag,
"iter_func": self.iter_tags,
"count_func": self.get_number_of_tags,
"del_func": self.remove_tag,
}
}
def get_table_func(self, table=None, func=None):
"""
Private implementation of get_table_func.
"""
if table is None:
return self.__tables.keys()
elif func is None:
return self.__tables[table].keys()
elif func in self.__tables[table].keys():
return self.__tables[table][func]
else:
return super().get_table_func(table, func)
def queue_object(self, obj_type, handle, reference=True): def queue_object(self, obj_type, handle, reference=True):
self.queue.append((obj_type, handle, reference)) self.queue.append((obj_type, handle, reference))

View File

@ -70,7 +70,7 @@ from gramps.gen.lib.nameorigintype import NameOriginType
from gramps.gen.utils.callback import Callback from gramps.gen.utils.callback import Callback
from . import BsddbBaseCursor from . import BsddbBaseCursor
from gramps.gen.db.base import DbReadBase from gramps.gen.db.base import DbReadBase, eval_order_by
from gramps.gen.utils.id import create_id from gramps.gen.utils.id import create_id
from gramps.gen.errors import DbError, HandleError from gramps.gen.errors import DbError, HandleError
from gramps.gen.constfunc import get_env_var from gramps.gen.constfunc import get_env_var
@ -290,96 +290,6 @@ class DbBsddbRead(DbReadBase, Callback):
""" """
DbReadBase.__init__(self) DbReadBase.__init__(self)
Callback.__init__(self) Callback.__init__(self)
self._tables['Person'].update(
{
"handle_func": self.get_person_from_handle,
"gramps_id_func": self.get_person_from_gramps_id,
"class_func": Person,
"cursor_func": self.get_person_cursor,
"handles_func": self.get_person_handles,
"iter_func": self.iter_people,
})
self._tables['Family'].update(
{
"handle_func": self.get_family_from_handle,
"gramps_id_func": self.get_family_from_gramps_id,
"class_func": Family,
"cursor_func": self.get_family_cursor,
"handles_func": self.get_family_handles,
"iter_func": self.iter_families,
})
self._tables['Source'].update(
{
"handle_func": self.get_source_from_handle,
"gramps_id_func": self.get_source_from_gramps_id,
"class_func": Source,
"cursor_func": self.get_source_cursor,
"handles_func": self.get_source_handles,
"iter_func": self.iter_sources,
})
self._tables['Citation'].update(
{
"handle_func": self.get_citation_from_handle,
"gramps_id_func": self.get_citation_from_gramps_id,
"class_func": Citation,
"cursor_func": self.get_citation_cursor,
"handles_func": self.get_citation_handles,
"iter_func": self.iter_citations,
})
self._tables['Event'].update(
{
"handle_func": self.get_event_from_handle,
"gramps_id_func": self.get_event_from_gramps_id,
"class_func": Event,
"cursor_func": self.get_event_cursor,
"handles_func": self.get_event_handles,
"iter_func": self.iter_events,
})
self._tables['Media'].update(
{
"handle_func": self.get_media_from_handle,
"gramps_id_func": self.get_media_from_gramps_id,
"class_func": Media,
"cursor_func": self.get_media_cursor,
"handles_func": self.get_media_handles,
"iter_func": self.iter_media,
})
self._tables['Place'].update(
{
"handle_func": self.get_place_from_handle,
"gramps_id_func": self.get_place_from_gramps_id,
"class_func": Place,
"cursor_func": self.get_place_cursor,
"handles_func": self.get_place_handles,
"iter_func": self.iter_places,
})
self._tables['Repository'].update(
{
"handle_func": self.get_repository_from_handle,
"gramps_id_func": self.get_repository_from_gramps_id,
"class_func": Repository,
"cursor_func": self.get_repository_cursor,
"handles_func": self.get_repository_handles,
"iter_func": self.iter_repositories,
})
self._tables['Note'].update(
{
"handle_func": self.get_note_from_handle,
"gramps_id_func": self.get_note_from_gramps_id,
"class_func": Note,
"cursor_func": self.get_note_cursor,
"handles_func": self.get_note_handles,
"iter_func": self.iter_notes,
})
self._tables['Tag'].update(
{
"handle_func": self.get_tag_from_handle,
"gramps_id_func": None,
"class_func": Tag,
"cursor_func": self.get_tag_cursor,
"handles_func": self.get_tag_handles,
"iter_func": self.iter_tags,
})
self.set_person_id_prefix('I%04d') self.set_person_id_prefix('I%04d')
self.set_media_id_prefix('O%04d') self.set_media_id_prefix('O%04d')
@ -474,6 +384,112 @@ class DbBsddbRead(DbReadBase, Callback):
self.txn = None self.txn = None
self.has_changed = False self.has_changed = False
self.__tables = {
'Person':
{
"handle_func": self.get_person_from_handle,
"gramps_id_func": self.get_person_from_gramps_id,
"class_func": Person,
"cursor_func": self.get_person_cursor,
"handles_func": self.get_person_handles,
"iter_func": self.iter_people,
},
'Family':
{
"handle_func": self.get_family_from_handle,
"gramps_id_func": self.get_family_from_gramps_id,
"class_func": Family,
"cursor_func": self.get_family_cursor,
"handles_func": self.get_family_handles,
"iter_func": self.iter_families,
},
'Source':
{
"handle_func": self.get_source_from_handle,
"gramps_id_func": self.get_source_from_gramps_id,
"class_func": Source,
"cursor_func": self.get_source_cursor,
"handles_func": self.get_source_handles,
"iter_func": self.iter_sources,
},
'Citation':
{
"handle_func": self.get_citation_from_handle,
"gramps_id_func": self.get_citation_from_gramps_id,
"class_func": Citation,
"cursor_func": self.get_citation_cursor,
"handles_func": self.get_citation_handles,
"iter_func": self.iter_citations,
},
'Event':
{
"handle_func": self.get_event_from_handle,
"gramps_id_func": self.get_event_from_gramps_id,
"class_func": Event,
"cursor_func": self.get_event_cursor,
"handles_func": self.get_event_handles,
"iter_func": self.iter_events,
},
'Media':
{
"handle_func": self.get_media_from_handle,
"gramps_id_func": self.get_media_from_gramps_id,
"class_func": Media,
"cursor_func": self.get_media_cursor,
"handles_func": self.get_media_handles,
"iter_func": self.iter_media,
},
'Place':
{
"handle_func": self.get_place_from_handle,
"gramps_id_func": self.get_place_from_gramps_id,
"class_func": Place,
"cursor_func": self.get_place_cursor,
"handles_func": self.get_place_handles,
"iter_func": self.iter_places,
},
'Repository':
{
"handle_func": self.get_repository_from_handle,
"gramps_id_func": self.get_repository_from_gramps_id,
"class_func": Repository,
"cursor_func": self.get_repository_cursor,
"handles_func": self.get_repository_handles,
"iter_func": self.iter_repositories,
},
'Note':
{
"handle_func": self.get_note_from_handle,
"gramps_id_func": self.get_note_from_gramps_id,
"class_func": Note,
"cursor_func": self.get_note_cursor,
"handles_func": self.get_note_handles,
"iter_func": self.iter_notes,
},
'Tag':
{
"handle_func": self.get_tag_from_handle,
"gramps_id_func": None,
"class_func": Tag,
"cursor_func": self.get_tag_cursor,
"handles_func": self.get_tag_handles,
"iter_func": self.iter_tags,
}
}
def get_table_func(self, table=None, func=None):
"""
Private implementation of get_table_func.
"""
if table is None:
return self.__tables.keys()
elif func is None:
return self.__tables[table].keys()
elif func in self.__tables[table].keys():
return self.__tables[table][func]
else:
return super().get_table_func(table, func)
def set_prefixes(self, person, media, family, source, citation, place, def set_prefixes(self, person, media, family, source, citation, place,
event, repository, note): event, repository, note):
self.set_person_id_prefix(person) self.set_person_id_prefix(person)
@ -493,12 +509,12 @@ class DbBsddbRead(DbReadBase, Callback):
def get_table_names(self): def get_table_names(self):
"""Return a list of valid table names.""" """Return a list of valid table names."""
return list(self._tables.keys()) return list(self.get_table_func())
def get_table_metadata(self, table_name): def get_table_metadata(self, table_name):
"""Return the metadata for a valid table name.""" """Return the metadata for a valid table name."""
if table_name in self._tables: if table_name in self.get_table_func():
return self._tables[table_name] return self.get_table_func(table_name)
return None return None
def get_cursor(self, table, *args, **kwargs): def get_cursor(self, table, *args, **kwargs):
@ -552,12 +568,6 @@ class DbBsddbRead(DbReadBase, Callback):
self.basedb = None self.basedb = None
#remove links to functions #remove links to functions
self.disconnect_all() self.disconnect_all()
for key in self._tables:
for subkey in self._tables[key]:
self._tables[key][subkey] = None
del self._tables[key][subkey]
self._tables[key] = None
del self._tables
## self.bookmarks = None ## self.bookmarks = None
## self.family_bookmarks = None ## self.family_bookmarks = None
## self.event_bookmarks = None ## self.event_bookmarks = None
@ -706,8 +716,8 @@ class DbBsddbRead(DbReadBase, Callback):
>>> self.get_from_name_and_handle("Person", "a7ad62365bc652387008") >>> self.get_from_name_and_handle("Person", "a7ad62365bc652387008")
>>> self.get_from_name_and_handle("Media", "c3434653675bcd736f23") >>> self.get_from_name_and_handle("Media", "c3434653675bcd736f23")
""" """
if table_name in self._tables: if table_name in self.get_table_func():
return self._tables[table_name]["handle_func"](handle) return self.get_table_func(table_name,"handle_func")(handle)
return None return None
def get_from_name_and_gramps_id(self, table_name, gramps_id): def get_from_name_and_gramps_id(self, table_name, gramps_id):
@ -721,8 +731,8 @@ class DbBsddbRead(DbReadBase, Callback):
>>> self.get_from_name_and_gramps_id("Family", "F056") >>> self.get_from_name_and_gramps_id("Family", "F056")
>>> self.get_from_name_and_gramps_id("Media", "M00012") >>> self.get_from_name_and_gramps_id("Media", "M00012")
""" """
if table_name in self._tables: if table_name in self.get_table_func():
return self._tables[table_name]["gramps_id_func"](gramps_id) return self.get_table_func(table_name,"gramps_id_func")(gramps_id)
return None return None
def get_person_from_handle(self, handle): def get_person_from_handle(self, handle):
@ -1233,7 +1243,7 @@ class DbBsddbRead(DbReadBase, Callback):
obj = obj_() obj = obj_()
obj.unserialize(data) obj.unserialize(data)
# just use values and handle to keep small: # just use values and handle to keep small:
sorted_items.append((self.eval_order_by(order_by, obj), obj.handle)) sorted_items.append((eval_order_by(order_by, obj, self), obj.handle))
# next we sort by fields and direction # next we sort by fields and direction
pos = len(order_by) - 1 pos = len(order_by) - 1
for (field, order) in reversed(order_by): # sort the lasts parts first for (field, order) in reversed(order_by): # sort the lasts parts first
@ -1241,7 +1251,7 @@ class DbBsddbRead(DbReadBase, Callback):
pos -= 1 pos -= 1
# now we will look them up again: # now we will look them up again:
for (order_by_values, handle) in sorted_items: for (order_by_values, handle) in sorted_items:
yield self._tables[obj_.__name__]["handle_func"](handle) yield self.get_table_func(obj_.__name__,"handle_func")(handle)
return g return g
# Use closure to define iterators for each primary object type # Use closure to define iterators for each primary object type

View File

@ -240,7 +240,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
DbBsddbRead.__init__(self) DbBsddbRead.__init__(self)
DbWriteBase.__init__(self) DbWriteBase.__init__(self)
#UpdateCallback.__init__(self) #UpdateCallback.__init__(self)
self._tables['Person'].update( self.__tables = {
'Person':
{ {
"handle_func": self.get_person_from_handle, "handle_func": self.get_person_from_handle,
"gramps_id_func": self.get_person_from_gramps_id, "gramps_id_func": self.get_person_from_gramps_id,
@ -252,8 +253,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
"count_func": self.get_number_of_people, "count_func": self.get_number_of_people,
"del_func": self.remove_person, "del_func": self.remove_person,
"iter_func": self.iter_people, "iter_func": self.iter_people,
}) },
self._tables['Family'].update( 'Family':
{ {
"handle_func": self.get_family_from_handle, "handle_func": self.get_family_from_handle,
"gramps_id_func": self.get_family_from_gramps_id, "gramps_id_func": self.get_family_from_gramps_id,
@ -265,8 +266,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
"count_func": self.get_number_of_families, "count_func": self.get_number_of_families,
"del_func": self.remove_family, "del_func": self.remove_family,
"iter_func": self.iter_families, "iter_func": self.iter_families,
}) },
self._tables['Source'].update( 'Source':
{ {
"handle_func": self.get_source_from_handle, "handle_func": self.get_source_from_handle,
"gramps_id_func": self.get_source_from_gramps_id, "gramps_id_func": self.get_source_from_gramps_id,
@ -278,8 +279,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
"count_func": self.get_number_of_sources, "count_func": self.get_number_of_sources,
"del_func": self.remove_source, "del_func": self.remove_source,
"iter_func": self.iter_sources, "iter_func": self.iter_sources,
}) },
self._tables['Citation'].update( 'Citation':
{ {
"handle_func": self.get_citation_from_handle, "handle_func": self.get_citation_from_handle,
"gramps_id_func": self.get_citation_from_gramps_id, "gramps_id_func": self.get_citation_from_gramps_id,
@ -291,8 +292,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
"count_func": self.get_number_of_citations, "count_func": self.get_number_of_citations,
"del_func": self.remove_citation, "del_func": self.remove_citation,
"iter_func": self.iter_citations, "iter_func": self.iter_citations,
}) },
self._tables['Event'].update( 'Event':
{ {
"handle_func": self.get_event_from_handle, "handle_func": self.get_event_from_handle,
"gramps_id_func": self.get_event_from_gramps_id, "gramps_id_func": self.get_event_from_gramps_id,
@ -304,8 +305,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
"count_func": self.get_number_of_events, "count_func": self.get_number_of_events,
"del_func": self.remove_event, "del_func": self.remove_event,
"iter_func": self.iter_events, "iter_func": self.iter_events,
}) },
self._tables['Media'].update( 'Media':
{ {
"handle_func": self.get_media_from_handle, "handle_func": self.get_media_from_handle,
"gramps_id_func": self.get_media_from_gramps_id, "gramps_id_func": self.get_media_from_gramps_id,
@ -317,8 +318,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
"count_func": self.get_number_of_media, "count_func": self.get_number_of_media,
"del_func": self.remove_media, "del_func": self.remove_media,
"iter_func": self.iter_media, "iter_func": self.iter_media,
}) },
self._tables['Place'].update( 'Place':
{ {
"handle_func": self.get_place_from_handle, "handle_func": self.get_place_from_handle,
"gramps_id_func": self.get_place_from_gramps_id, "gramps_id_func": self.get_place_from_gramps_id,
@ -330,8 +331,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
"count_func": self.get_number_of_places, "count_func": self.get_number_of_places,
"del_func": self.remove_place, "del_func": self.remove_place,
"iter_func": self.iter_places, "iter_func": self.iter_places,
}) },
self._tables['Repository'].update( 'Repository':
{ {
"handle_func": self.get_repository_from_handle, "handle_func": self.get_repository_from_handle,
"gramps_id_func": self.get_repository_from_gramps_id, "gramps_id_func": self.get_repository_from_gramps_id,
@ -343,8 +344,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
"count_func": self.get_number_of_repositories, "count_func": self.get_number_of_repositories,
"del_func": self.remove_repository, "del_func": self.remove_repository,
"iter_func": self.iter_repositories, "iter_func": self.iter_repositories,
}) },
self._tables['Note'].update( 'Note':
{ {
"handle_func": self.get_note_from_handle, "handle_func": self.get_note_from_handle,
"gramps_id_func": self.get_note_from_gramps_id, "gramps_id_func": self.get_note_from_gramps_id,
@ -356,8 +357,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
"count_func": self.get_number_of_notes, "count_func": self.get_number_of_notes,
"del_func": self.remove_note, "del_func": self.remove_note,
"iter_func": self.iter_notes, "iter_func": self.iter_notes,
}) },
self._tables['Tag'].update( 'Tag':
{ {
"handle_func": self.get_tag_from_handle, "handle_func": self.get_tag_from_handle,
"gramps_id_func": None, "gramps_id_func": None,
@ -369,7 +370,8 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
"count_func": self.get_number_of_tags, "count_func": self.get_number_of_tags,
"del_func": self.remove_tag, "del_func": self.remove_tag,
"iter_func": self.iter_tags, "iter_func": self.iter_tags,
}) }
}
self.secondary_connected = False self.secondary_connected = False
self.has_changed = False self.has_changed = False
@ -378,6 +380,19 @@ class DbBsddb(DbBsddbRead, DbWriteBase, UpdateCallback):
self.update_python_version = False self.update_python_version = False
self.update_pickle_version = False self.update_pickle_version = False
def get_table_func(self, table=None, func=None):
"""
Private implementation of get_table_func.
"""
if table is None:
return self.__tables.keys()
elif func is None:
return self.__tables[table].keys()
elif func in self.__tables[table].keys():
return self.__tables[table][func]
else:
return super().get_table_func(table, func)
def catch_db_error(func): def catch_db_error(func):
""" """
Decorator function for catching database errors. If *func* throws Decorator function for catching database errors. If *func* throws

View File

@ -31,6 +31,7 @@ import dbapi_support
import time import time
import pickle import pickle
from operator import itemgetter
import logging import logging
LOG = logging.getLogger(".dbapi") LOG = logging.getLogger(".dbapi")
@ -1048,10 +1049,9 @@ class DBAPI(DbGeneric):
self.dbapi.execute(query) self.dbapi.execute(query)
rows = self.dbapi.fetchall() rows = self.dbapi.fetchall()
for row in rows: for row in rows:
obj = obj_() obj = self.get_table_func(class_.__name__,"class_func").create(pickle.loads(row[0]))
obj.unserialize(row[0])
# just use values and handle to keep small: # just use values and handle to keep small:
sorted_items.append((self.eval_order_by(order_by, obj), obj.handle)) sorted_items.append((eval_order_by(order_by, obj, self), obj.handle))
# next we sort by fields and direction # next we sort by fields and direction
pos = len(order_by) - 1 pos = len(order_by) - 1
for (field, order) in reversed(order_by): # sort the lasts parts first for (field, order) in reversed(order_by): # sort the lasts parts first
@ -1059,7 +1059,7 @@ class DBAPI(DbGeneric):
pos -= 1 pos -= 1
# now we will look them up again: # now we will look them up again:
for (order_by_values, handle) in sorted_items: for (order_by_values, handle) in sorted_items:
yield self._tables[obj_.__name__]["handle_func"](handle) yield self.get_table_func(class_.__name__,"handle_func")(handle)
def iter_items(self, order_by, class_): def iter_items(self, order_by, class_):
# check if order_by fields are secondary # check if order_by fields are secondary
@ -1068,7 +1068,7 @@ class DBAPI(DbGeneric):
if order_by: if order_by:
secondary_fields = class_.get_secondary_fields() secondary_fields = class_.get_secondary_fields()
if not self.check_order_by_fields(class_.__name__, order_by, secondary_fields): if not self.check_order_by_fields(class_.__name__, order_by, secondary_fields):
for item in super().iter_items(order_by, class_): for item in self.iter_items_order_by_python(order_by, class_):
yield item yield item
return return
## Continue with dbapi select ## Continue with dbapi select
@ -1578,13 +1578,13 @@ class DBAPI(DbGeneric):
Add secondary fields, update, and create indexes. Add secondary fields, update, and create indexes.
""" """
LOG.info("Rebuilding secondary fields...") LOG.info("Rebuilding secondary fields...")
for table in self._tables.keys(): for table in self.get_table_func():
if not hasattr(self._tables[table]["class_func"], "get_secondary_fields"): if not hasattr(self.get_table_func(table,"class_func"), "get_secondary_fields"):
continue continue
# do a select on all; if it works, then it is ok; else, check them all # do a select on all; if it works, then it is ok; else, check them all
try: try:
fields = [self._hash_name(table, field) for (field, ptype) in fields = [self._hash_name(table, field) for (field, ptype) in
self._tables[table]["class_func"].get_secondary_fields()] self.get_table_func(table,"class_func").get_secondary_fields()]
if fields: if fields:
self.dbapi.execute("select %s from %s limit 1;" % (", ".join(fields), table)) self.dbapi.execute("select %s from %s limit 1;" % (", ".join(fields), table))
# if no error, continue # if no error, continue
@ -1594,7 +1594,7 @@ class DBAPI(DbGeneric):
pass # got to add missing ones, so continue pass # got to add missing ones, so continue
LOG.info("Table %s needs rebuilding..." % table) LOG.info("Table %s needs rebuilding..." % table)
altered = False altered = False
for field_pair in self._tables[table]["class_func"].get_secondary_fields(): for field_pair in self.get_table_func(table,"class_func").get_secondary_fields():
field, python_type = field_pair field, python_type = field_pair
field = self._hash_name(table, field) field = self._hash_name(table, field)
sql_type = self._sql_type(python_type) sql_type = self._sql_type(python_type)
@ -1617,8 +1617,8 @@ class DBAPI(DbGeneric):
""" """
Create the indexes for the secondary fields. Create the indexes for the secondary fields.
""" """
for table in self._tables.keys(): for table in self.get_table_func():
if not hasattr(self._tables[table]["class_func"], "get_index_fields"): if not hasattr(self.get_table_func(table,"class_func"), "get_index_fields"):
continue continue
self.create_secondary_indexes_table(table) self.create_secondary_indexes_table(table)
@ -1626,7 +1626,7 @@ class DBAPI(DbGeneric):
""" """
Create secondary indexes for just this table. Create secondary indexes for just this table.
""" """
for fields in self._tables[table]["class_func"].get_index_fields(): for fields in self.get_table_func(table,"class_func").get_index_fields():
for field in fields: for field in fields:
field = self._hash_name(table, field) field = self._hash_name(table, field)
self.dbapi.try_execute("CREATE INDEX %s_%s ON %s(%s);" % (table, field, table, field)) self.dbapi.try_execute("CREATE INDEX %s_%s ON %s(%s);" % (table, field, table, field))
@ -1636,7 +1636,7 @@ class DBAPI(DbGeneric):
Go through all items in all tables, and update their secondary Go through all items in all tables, and update their secondary
field values. field values.
""" """
for table in self._tables.keys(): for table in self.get_table_func():
self.update_secondary_values_table(table) self.update_secondary_values_table(table)
def update_secondary_values_table(self, table): def update_secondary_values_table(self, table):
@ -1646,9 +1646,9 @@ class DBAPI(DbGeneric):
table - "Person", "Place", "Media", etc. table - "Person", "Place", "Media", etc.
Commits changes. Commits changes.
""" """
if not hasattr(self._tables[table]["class_func"], "get_secondary_fields"): if not hasattr(self.get_table_func(table,"class_func"), "get_secondary_fields"):
return return
for item in self._tables[table]["iter_func"](): for item in self.get_table_func(table,"iter_func")():
self.update_secondary_values(item) self.update_secondary_values(item)
self.dbapi.commit() self.dbapi.commit()
@ -1658,7 +1658,7 @@ class DBAPI(DbGeneric):
Does not commit. Does not commit.
""" """
table = item.__class__.__name__ table = item.__class__.__name__
fields = self._tables[table]["class_func"].get_secondary_fields() fields = self.get_table_func(table,"class_func").get_secondary_fields()
fields = [field for (field, direction) in fields] fields = [field for (field, direction) in fields]
sets = [] sets = []
values = [] values = []
@ -1671,7 +1671,7 @@ class DBAPI(DbGeneric):
self.dbapi.execute("UPDATE %s SET %s where handle = ?;" % (table, ", ".join(sets)), self.dbapi.execute("UPDATE %s SET %s where handle = ?;" % (table, ", ".join(sets)),
values + [item.handle]) values + [item.handle])
def sql_repr(self, value): def _sql_repr(self, value):
""" """
Given a Python value, turn it into a SQL value. Given a Python value, turn it into a SQL value.
""" """
@ -1684,7 +1684,7 @@ class DBAPI(DbGeneric):
else: else:
return repr(value) return repr(value)
def build_where_clause_recursive(self, table, where): def _build_where_clause_recursive(self, table, where):
""" """
where - (field, op, value) where - (field, op, value)
- ["NOT", where] - ["NOT", where]
@ -1695,26 +1695,26 @@ class DBAPI(DbGeneric):
return "" return ""
elif len(where) == 3: elif len(where) == 3:
field, op, value = where field, op, value = where
return "(%s %s %s)" % (self._hash_name(table, field), op, self.sql_repr(value)) return "(%s %s %s)" % (self._hash_name(table, field), op, self._sql_repr(value))
elif where[0] in ["AND", "OR"]: elif where[0] in ["AND", "OR"]:
parts = [self.build_where_clause_recursive(table, part) parts = [self._build_where_clause_recursive(table, part)
for part in where[1]] for part in where[1]]
return "(%s)" % ((" %s " % where[0]).join(parts)) return "(%s)" % ((" %s " % where[0]).join(parts))
else: else:
return "(NOT %s)" % self.build_where_clause_recursive(table, where[1]) return "(NOT %s)" % self._build_where_clause_recursive(table, where[1])
def build_where_clause(self, table, where): def _build_where_clause(self, table, where):
""" """
where - a list in where format where - a list in where format
return - "WHERE conditions..." return - "WHERE conditions..."
""" """
parts = self.build_where_clause_recursive(table, where) parts = self._build_where_clause_recursive(table, where)
if parts: if parts:
return ("WHERE " + parts) return ("WHERE " + parts)
else: else:
return "" return ""
def build_order_clause(self, table, order_by): def _build_order_clause(self, table, order_by):
""" """
order_by - [(field, "ASC" | "DESC"), ...] order_by - [(field, "ASC" | "DESC"), ...]
""" """
@ -1725,7 +1725,7 @@ class DBAPI(DbGeneric):
else: else:
return "" return ""
def build_select_fields(self, table, select_fields, secondary_fields): def _build_select_fields(self, table, select_fields, secondary_fields):
""" """
fields - [field, ...] fields - [field, ...]
return: "field, field, field" return: "field, field, field"
@ -1736,7 +1736,7 @@ class DBAPI(DbGeneric):
else: else:
return ["blob_data"] # nope, we'll have to expand blob to get all fields return ["blob_data"] # nope, we'll have to expand blob to get all fields
def check_order_by_fields(self, table, order_by, secondary_fields): def _check_order_by_fields(self, table, order_by, secondary_fields):
""" """
Check to make sure all order_by fields are defined. If not, then Check to make sure all order_by fields are defined. If not, then
we need to do the Python-based order. we need to do the Python-based order.
@ -1749,7 +1749,7 @@ class DBAPI(DbGeneric):
return False return False
return True return True
def check_where_fields(self, table, where, secondary_fields): def _check_where_fields(self, table, where, secondary_fields):
""" """
Check to make sure all where fields are defined. If not, then Check to make sure all where fields are defined. If not, then
we need to do the Python-based select. we need to do the Python-based select.
@ -1762,19 +1762,19 @@ class DBAPI(DbGeneric):
connector, exprs = where connector, exprs = where
if connector in ["AND", "OR"]: if connector in ["AND", "OR"]:
for expr in exprs: for expr in exprs:
value = self.check_where_fields(table, expr, secondary_fields) value = self._check_where_fields(table, expr, secondary_fields)
if value == False: if value == False:
return False return False
return True return True
else: # "NOT" else: # "NOT"
return self.check_where_fields(table, exprs, secondary_fields) return self._check_where_fields(table, exprs, secondary_fields)
elif len(where) == 3: # (name, op, value) elif len(where) == 3: # (name, op, value)
(name, op, value) = where (name, op, value) = where
# just the ones we need for where # just the ones we need for where
return (self._hash_name(table, name) in secondary_fields) return (self._hash_name(table, name) in secondary_fields)
def select(self, table, fields=None, start=0, limit=-1, def _select(self, table, fields=None, start=0, limit=-1,
where=None, order_by=None): where=None, order_by=None):
""" """
Default implementation of a select for those databases Default implementation of a select for those databases
that don't support SQL. Returns a list of dicts, total, that don't support SQL. Returns a list of dicts, total,
@ -1790,65 +1790,65 @@ class DBAPI(DbGeneric):
["NOT", where] ["NOT", where]
order_by - [[fieldname, "ASC" | "DESC"], ...] order_by - [[fieldname, "ASC" | "DESC"], ...]
""" """
hashed_fields = [self._hash_name(table, field) for field in fields]
secondary_fields = ([self._hash_name(table, field) for (field, ptype) in secondary_fields = ([self._hash_name(table, field) for (field, ptype) in
self._tables[table]["class_func"].get_secondary_fields()] + self.get_table_func(table,"class_func").get_secondary_fields()] +
["handle"]) # handle is a sql field, but not listed in secondaries ["handle"]) # handle is a sql field, but not listed in secondaries
if ((not self.check_where_fields(table, where, secondary_fields)) or # If no fields, then we need objects:
(not self.check_order_by_fields(table, order_by, secondary_fields))): # Check to see if where matches SQL fields:
return super().select(table, fields, start, limit, where, order_by) if ((not self._check_where_fields(table, where, secondary_fields)) or
fields = hashed_fields (not self._check_order_by_fields(table, order_by, secondary_fields))):
start_time = time.time() # If not, then need to do select via Python:
where_clause = self.build_where_clause(table, where) generator = super()._select(table, fields, start, limit, where, order_by)
order_clause = self.build_order_clause(table, order_by) for item in generator:
select_fields = self.build_select_fields(table, fields, secondary_fields) yield item
# Get the total count: return
if limit != -1 or start != 0: # get the total that would match: # Otherwise, we are SQL
self.dbapi.execute("select count(1) from %s %s;" % (table, where_clause)) if fields is None:
total = self.dbapi.fetchone()[0] fields = ["blob_data"]
get_count_only = False
if fields[0] == "count(1)":
hashed_fields = ["count(1)"]
fields = ["count(1)"]
select_fields = ["count(1)"]
get_count_only = True
else: else:
total = None # need to get later hashed_fields = [self._hash_name(table, field) for field in fields]
class Result(list): fields = hashed_fields
""" select_fields = self._build_select_fields(table, fields, secondary_fields)
A list rows of just matching for this page, with total = all, where_clause = self._build_where_clause(table, where)
time = time to select, query = the SQL query, and expanded order_clause = self._build_order_clause(table, order_by)
if unpickled. if get_count_only:
""" select_fields = ["1"]
total = 0
time = 0.0
query = None
expanded = False
result = Result()
if start: if start:
query = "SELECT %s FROM %s %s %s LIMIT %s, %s;" % ( query = "SELECT %s FROM %s %s %s LIMIT %s, %s" % (
", ".join(select_fields), table, where_clause, order_clause, start, limit ", ".join(select_fields), table, where_clause, order_clause, start, limit
) )
else: else:
query = "SELECT %s FROM %s %s %s LIMIT %s;" % ( query = "SELECT %s FROM %s %s %s LIMIT %s" % (
", ".join(select_fields), table, where_clause, order_clause, limit ", ".join(select_fields), table, where_clause, order_clause, limit
) )
if get_count_only:
self.dbapi.execute("SELECT count(1) from (%s);" % query)
rows = self.dbapi.fetchall()
yield rows[0][0]
return
self.dbapi.execute(query) self.dbapi.execute(query)
rows = self.dbapi.fetchall() rows = self.dbapi.fetchall()
if total is None:
total = len(rows)
expanded = False
for row in rows: for row in rows:
obj = None # don't build it if you don't need it if fields[0] != "blob_data":
data = {} obj = None # don't build it if you don't need it
for field in fields: data = {}
if field in select_fields: for field in fields:
data[field.replace("__", ".")] = row[select_fields.index(field)] if field in select_fields:
else: data[field.replace("__", ".")] = row[select_fields.index(field)]
if obj is None: # we need it! create it and cache it: else:
obj = self._tables[table]["class_func"].create(pickle.loads(row[0])) if obj is None: # we need it! create it and cache it:
expanded = True obj = self.get_table_func(table,"class_func").create(pickle.loads(row[0]))
# get the field, even if we need to do a join: # get the field, even if we need to do a join:
# FIXME: possible optimize: do a join in select for this if needed: # FIXME: possible optimize: do a join in select for this if needed:
field = field.replace("__", ".") field = field.replace("__", ".")
data[field] = obj.get_field(field, self, ignore_errors=True) data[field] = obj.get_field(field, self, ignore_errors=True)
result.append(data) yield data
result.total = total else:
result.time = time.time() - start_time obj = self.get_table_func(table,"class_func").create(pickle.loads(row[0]))
result.query = query yield obj
result.expanded = expanded
return result

View File

@ -1,14 +1,17 @@
import os import os
import sqlite3 import sqlite3
import logging
sqlite3.paramstyle = 'qmark' sqlite3.paramstyle = 'qmark'
class Sqlite(object): class Sqlite(object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.log = logging.getLogger(".sqlite")
self.connection = sqlite3.connect(*args, **kwargs) self.connection = sqlite3.connect(*args, **kwargs)
self.queries = {} self.queries = {}
def execute(self, *args, **kwargs): def execute(self, *args, **kwargs):
self.log.debug(args)
self.cursor = self.connection.execute(*args, **kwargs) self.cursor = self.connection.execute(*args, **kwargs)
def fetchone(self): def fetchone(self):

View File

@ -79,42 +79,170 @@ class BSDDBTest(unittest.TestCase):
self.assertTrue(all([isinstance(r, EventRef) for r in result]), result) self.assertTrue(all([isinstance(r, EventRef) for r in result]), result)
def test_select_1(self): def test_select_1(self):
result = self.db.select("Person", ["gramps_id"]) result = list(self.db._select("Person", ["gramps_id"]))
self.assertTrue(len(result) == 60, len(result)) self.assertTrue(len(result) == 60, len(result))
def test_select_2(self): def test_select_2(self):
result = self.db.select("Person", ["gramps_id"], result = list(self.db._select("Person", ["gramps_id"],
where=("gramps_id", "LIKE", "I000%")) where=("gramps_id", "LIKE", "I000%")))
self.assertTrue(len(result) == 10, len(result)) self.assertTrue(len(result) == 10, len(result))
def test_select_3(self): def test_select_3(self):
result = self.db.select("Family", ["mother_handle.gramps_id"], result = list(self.db._select("Family", ["mother_handle.gramps_id"],
where=("mother_handle.gramps_id", "LIKE", "I003%")) where=("mother_handle.gramps_id", "LIKE", "I003%")))
self.assertTrue(len(result) == 6, result) self.assertTrue(len(result) == 6, result)
def test_select_4(self): def test_select_4(self):
result = self.db.select("Family", ["mother_handle.event_ref_list.ref.gramps_id"]) result = list(self.db._select("Family",
["mother_handle.event_ref_list.ref.gramps_id"]))
self.assertTrue(len(result) == 23, len(result)) self.assertTrue(len(result) == 23, len(result))
def test_select_4(self):
result = self.db.select("Family", ["mother_handle.event_ref_list.ref.gramps_id"],
where=("mother_handle.event_ref_list.ref.gramps_id", "=", 'E0156'))
self.assertTrue(len(result) == 1, len(result))
def test_select_5(self): def test_select_5(self):
result = self.db.select("Family", ["mother_handle.event_ref_list.ref.self.gramps_id"]) result = list(self.db._select("Family",
["mother_handle.event_ref_list.ref.self.gramps_id"]))
self.assertTrue(len(result) == 23, len(result)) self.assertTrue(len(result) == 23, len(result))
def test_select_6(self): def test_select_6(self):
result = self.db.select("Family", ["mother_handle.event_ref_list.0"]) result = list(self.db._select("Family", ["mother_handle.event_ref_list.0"]))
self.assertTrue(all([isinstance(r["mother_handle.event_ref_list.0"], (EventRef, type(None))) for r in result]), self.assertTrue(all([isinstance(r["mother_handle.event_ref_list.0"],
(EventRef, type(None))) for r in result]),
[r["mother_handle.event_ref_list.0"] for r in result]) [r["mother_handle.event_ref_list.0"] for r in result])
def test_select_7(self): def test_select_7(self):
result = self.db.select("Family", ["mother_handle.event_ref_list.0"], result = list(self.db._select("Family", ["mother_handle.event_ref_list.0"],
where=("mother_handle.event_ref_list.0", "!=", None)) where=("mother_handle.event_ref_list.0", "!=", None)))
self.assertTrue(len(result) == 21, len(result)) self.assertTrue(len(result) == 21, len(result))
def test_select_8(self):
result = list(self.db._select("Family", ["mother_handle.event_ref_list.ref.gramps_id"],
where=("mother_handle.event_ref_list.ref.gramps_id", "=", 'E0156')))
self.assertTrue(len(result) == 1, len(result))
def test_queryset_1(self):
result = list(self.db.Person.select())
self.assertTrue(len(result) == 60, len(result))
def test_queryset_2(self):
result = list(self.db.Person.filter(gramps_id__LIKE="I000%").select())
self.assertTrue(len(result) == 10, len(result))
def test_queryset_3(self):
result = list(self.db.Family
.filter(mother_handle__gramps_id__LIKE="I003%")
.select())
self.assertTrue(len(result) == 6, result)
def test_queryset_4(self):
result = list(self.db.Family.select())
self.assertTrue(len(result) == 23, len(result))
def test_queryset_4(self):
result = list(self.db.Family
.filter(mother_handle__event_ref_list__ref__gramps_id='E0156')
.select())
self.assertTrue(len(result) == 1, len(result))
def test_queryset_5(self):
result = list(self.db.Family
.select("mother_handle.event_ref_list.ref.self.gramps_id"))
self.assertTrue(len(result) == 23, len(result))
def test_queryset_6(self):
result = list(self.db.Family.select("mother_handle.event_ref_list.0"))
self.assertTrue(all([isinstance(r["mother_handle.event_ref_list.0"],
(EventRef, type(None))) for r in result]),
[r["mother_handle.event_ref_list.0"] for r in result])
def test_queryset_7(self):
from gramps.gen.db import NOT
result = list(self.db.Family
.filter(NOT(mother_handle__event_ref_list__0=None))
.select())
self.assertTrue(len(result) == 21, len(result))
def test_order_1(self):
result = list(self.db.Person.order("gramps_id").select())
self.assertTrue(len(result) == 60, len(result))
def test_order_2(self):
result = list(self.db.Person.order("-gramps_id").select())
self.assertTrue(len(result) == 60, len(result))
def test_proxy_1(self):
result = list(self.db.Person.proxy("living", False).select())
self.assertTrue(len(result) == 31, len(result))
def test_proxy_2(self):
result = list(self.db.Person.proxy("living", True).select())
self.assertTrue(len(result) == 60, len(result))
def test_proxy_3(self):
result = len(list(self.db.Person
.proxy("private")
.order("-gramps_id")
.select("gramps_id")))
self.assertTrue(result == 59, result)
def test_map_1(self):
result = sum(list(self.db.Person.map(lambda p: 1).select()))
self.assertTrue(result == 60, result)
def test_tag_1(self):
self.db.Person.filter(gramps_id="I0001").tag("Test")
result = self.db.Person.filter(tag_list__name="Test").count()
self.assertTrue(result == 1, result)
# def test_filter_1(self):
# from gramps.gen.filters.rules.person import (IsDescendantOf,
# IsAncestorOf)
# from gramps.gen.filters import GenericFilter
# filter = GenericFilter()
# filter.set_logical_op("or")
# filter.add_rule(IsDescendantOf([self.db.get_default_person().gramps_id,
# True]))
# filter.add_rule(IsAncestorOf([self.db.get_default_person().gramps_id,
# True]))
# result = self.db.Person.filter(filter).count()
# self.assertTrue(result == 15, result)
def test_filter_2(self):
result = self.db.Person.filter(lambda p: p.private).count()
self.assertTrue(result == 1, result)
def test_filter_3(self):
result = self.db.Person.filter(lambda p: not p.private).count()
self.assertTrue(result == 59, result)
def test_limit_1(self):
result = self.db.Person.limit(count=50).count()
self.assertTrue(result == 50, result)
def test_limit_2(self):
result = self.db.Person.limit(start=50, count=50).count()
self.assertTrue(result == 10, result)
def test_ordering_1(self):
worked = None
try:
result = list(self.db.Person
.filter(lambda p: p.private)
.order("private")
.select())
worked = True
except:
worked = False
self.assertTrue(not worked, "should have failed")
def test_ordering_2(self):
worked = None
try:
result = list(self.db.Person.order("private")
.filter(lambda p: p.private)
.select())
worked = True
except:
worked = False
self.assertTrue(worked, "should have worked")
class DBAPITest(BSDDBTest): class DBAPITest(BSDDBTest):
dbwrap = DBAPI() dbwrap = DBAPI()

View File

@ -202,8 +202,9 @@ class Test4(U.TestCase):
logging.error(emsg) logging.error(emsg)
ll = tl.logfile_getlines() ll = tl.logfile_getlines()
nl = len(ll) nl = len(ll)
self.assertEquals(nl,3, print(repr(ll))
tu.msg(nl,3, "pass %d: expected line count" % i)) self.assertEquals(nl,2,
tu.msg(nl,2, "pass %d: expected line count" % i))
#del tl #del tl