Remove schema field functionality

This commit is contained in:
Nick Hall 2017-02-16 23:33:55 +00:00
parent 6f0119288b
commit 62403b5789
4 changed files with 1 additions and 301 deletions

View File

@ -1419,14 +1419,6 @@ class DbReadBase:
"""
raise NotImplementedError
def _hash_name(self, table, name):
"""
Used in SQL functions to eval expressions involving selected
data.
"""
name = self.get_table_func(table, "class_func").get_field_alias(name)
return name.replace(".", "__")
class DbWriteBase(DbReadBase):
"""

View File

@ -155,21 +155,6 @@ class TableObject(BaseObject):
"""
return {}
@classmethod
def field_aliases(cls):
"""
Return dictionary of alias to full field names
for this object class.
"""
return {}
@classmethod
def get_field_alias(cls, alias):
"""
Return full field name for an alias, if one.
"""
return cls.field_aliases().get(alias, alias)
@classmethod
def get_schema(cls):
"""
@ -177,21 +162,6 @@ class TableObject(BaseObject):
"""
return {}
@classmethod
def get_extra_secondary_fields(cls):
"""
Return a list of full field names and types for secondary
fields that are not directly listed in the schema.
"""
return []
@classmethod
def get_index_fields(cls):
"""
Return a list of full field names for indices.
"""
return []
@classmethod
def get_secondary_fields(cls):
"""
@ -210,193 +180,3 @@ class TableObject(BaseObject):
schema_type,
value.get("maxLength")))
return result
@classmethod
def get_label(cls, field, _):
"""
Get the associated label given a field name of this object.
No index positions allowed on lists.
"""
chain = field.split(".")
path = cls._follow_schema_path(chain[:-1])
labels = path.get_labels(_)
if chain[-1] in labels:
return labels[chain[-1]]
else:
raise Exception("%s has no such label on %s: '%s'" %
(cls, path, field))
@classmethod
def get_field_type(cls, field):
"""
Get the associated label given a field name of this object.
No index positions allowed on lists.
"""
field = cls.get_field_alias(field)
chain = field.split(".")
ftype = cls._follow_schema_path(chain)
return ftype
@classmethod
def _follow_schema_path(cls, chain):
"""
Follow a list of schema items. Return endpoint.
"""
path = cls
for part in chain:
schema = path.get_schema()
if part.isdigit():
pass # skip over
elif part in schema.keys():
path = schema[part]
else:
raise Exception("No such field. Valid fields are: %s" % list(schema.keys()))
if isinstance(path, (list, tuple)):
path = path[0]
return path
def get_field(self, field, db=None, ignore_errors=False):
"""
Get the value of a field.
"""
field = self.__class__.get_field_alias(field)
chain = field.split(".")
path = self._follow_field_path(chain, db, ignore_errors)
return path
def _follow_field_path(self, chain, db=None, ignore_errors=False):
"""
Follow a list of items. Return endpoint(s) only.
With the db argument, can do joins across tables.
self - current object
returns - None, endpoint, of recursive list of endpoints
"""
from .handle import HandleClass
# start with [self, self, chain, path_to=[]]
# results = []
# expand when you reach multiple answers [obj, chain_left, []]
# if you get to an endpoint, put results
# go until nothing left to expand
todo = [(self, self, [], chain)]
results = []
while todo:
parent, current, path_to, chain = todo.pop()
#print("expand:", parent.__class__.__name__,
# current.__class__.__name__,
# path_to,
# chain)
keep_going = True
p = 0
while p < len(chain) and keep_going:
#print("while:", path_to, chain[p:])
part = chain[p]
if hasattr(current, part): # attribute
current = getattr(current, part)
path_to.append(part)
# need to consider current+part if current is list:
elif isinstance(current, (list, tuple)):
if part.isdigit():
# followed by index, so continue here
if int(part) < len(current):
current = current[int(part)]
path_to.append(part)
elif ignore_errors:
current = None
keeping_going = False
else:
raise Exception("invalid index position")
else: # else branch! in middle, split paths
for i in range(len(current)):
#print("split list:", self.__class__.__name__,
# current.__class__.__name__,
# path_to[:],
# [str(i)] + chain[p:])
todo.append([self, current, path_to[:], [str(i)] + chain[p:]])
current = None
keep_going = False
else: # part not found on this self
# current is a handle
# part is something on joined object
if parent:
ptype = parent.__class__.get_field_type(".".join(path_to))
if isinstance(ptype, HandleClass):
if db:
# start over here:
obj = None
if current:
try:
obj = ptype.join(db, current)
except HandleError:
if ignore_errors:
obj = None
else:
raise
if part == "self":
current = obj
path_to = []
#print("split self:", obj.__class__.__name__,
# current.__class__.__name__,
# path_to,
# chain[p + 1:])
todo.append([obj, current, path_to, chain[p + 1:]])
elif obj:
current = getattr(obj, part)
#print("split :", obj.__class__.__name__,
# current.__class__.__name__,
# [part],
# chain[p + 1:])
todo.append([obj, current, [part], chain[p + 1:]])
current = None
keep_going = False
else:
raise Exception("Can't join without database")
elif part == "self":
pass
elif ignore_errors:
pass
else:
raise Exception("%s is not a valid field of %s; use %s" %
(part, current, dir(current)))
current = None
keep_going = False
p += 1
if keep_going:
results.append(current)
if len(results) == 1:
return results[0]
elif len(results) == 0:
return None
else:
return results
def set_field(self, field, value, db=None, ignore_errors=False):
"""
Set the value of a basic field (str, int, float, or bool).
value can be a string or actual value.
Returns number of items changed.
"""
field = self.__class__.get_field_alias(field)
chain = field.split(".")
path = self._follow_field_path(chain[:-1], db, ignore_errors)
ftype = self.get_field_type(field)
# ftype is str, bool, float, or int
value = (value in ['True', True]) if ftype is bool else value
return self._set_fields(path, chain[-1], value, ftype)
def _set_fields(self, path, attr, value, ftype):
"""
Helper function to handle recursive lists of items.
"""
from .handle import HandleClass
if isinstance(path, (list, tuple)):
count = 0
for item in path:
count += self._set_fields(item, attr, value, ftype)
elif isinstance(ftype, HandleClass):
setattr(path, attr, value)
count = 1
else:
setattr(path, attr, ftype(value))
count = 1
return count

View File

@ -1,69 +0,0 @@
#
# Gramps - a GTK+/GNOME based genealogy program
#
# Copyright (C) 2016 Gramps Development Team
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
""" Tests for using database fields """
import unittest
from gramps.gen.db import DbTxn
from gramps.gen.db.utils import make_database
from ..import (Person, Surname, Name, NameType, Family, FamilyRelType,
Event, EventType, Source, Place, PlaceName, Citation, Date,
Repository, RepositoryType, Media, Note, NoteType,
StyledText, StyledTextTag, StyledTextTagType, Tag,
ChildRef, ChildRefType, Attribute, MediaRef, AttributeType,
Url, UrlType, Address, EventRef, EventRoleType, RepoRef,
FamilyRelType, LdsOrd, MediaRef, PersonRef, PlaceType,
SrcAttribute, SrcAttributeType)
class FieldBaseTest(unittest.TestCase):
def setUp(self):
db = make_database("inmemorydb")
db.load(None)
with DbTxn("Test", db) as trans:
# Add some people:
person1 = Person()
person1.primary_name = Name()
person1.primary_name.surname_list.append(Surname())
person1.primary_name.surname_list[0].surname = "Smith"
person1.gramps_id = "I0001"
db.add_person(person1, trans) # person gets a handle
# Add some families:
family1 = Family()
family1.father_handle = person1.handle
family1.gramps_id = "F0001"
db.add_family(family1, trans)
self.db = db
def test_field_access01(self):
person = self.db.get_person_from_gramps_id("I0001")
self.assertEqual(person.get_field("primary_name.surname_list.0.surname"),
"Smith")
def test_field_join01(self):
family = self.db.get_family_from_gramps_id("F0001")
self.assertEqual(family.get_field("father_handle.primary_name.surname_list.0.surname", self.db),
"Smith")
if __name__ == "__main__":
unittest.main()

View File

@ -920,7 +920,6 @@ class DBAPI(DbGeneric):
table_name = table.lower()
for field, schema_type, max_length in self.get_table_func(
table, "class_func").get_secondary_fields():
field = self._hash_name(table, field)
sql_type = self._sql_type(schema_type, max_length)
try:
# test to see if it exists:
@ -947,10 +946,8 @@ class DBAPI(DbGeneric):
sets = []
values = []
for field in fields:
value = obj.get_field(field, self, ignore_errors=True)
field = self._hash_name(obj.__class__.__name__, field)
sets.append("%s = ?" % field)
values.append(value)
values.append(getattr(obj, field))
# Derived fields
if table == 'Person':