QuerySet tag(): remove batch; allow tag removal; test

This commit is contained in:
Doug Blank 2016-05-06 08:14:54 -04:00
parent 1fa604645c
commit 61c2ed3240
2 changed files with 20 additions and 4 deletions

View File

@ -2306,9 +2306,9 @@ class QuerySet(object):
self.generator = generator() self.generator = generator()
return self return self
def tag(self, tag_text): def tag(self, tag_text, remove=False):
""" """
Tag the selected items with the tag name. Tag or untag the selected items with the tag name.
""" """
if self.generator and self.needs_to_run: if self.generator and self.needs_to_run:
raise Exception("Queries in invalid order") raise Exception("Queries in invalid order")
@ -2317,8 +2317,12 @@ class QuerySet(object):
else: else:
self.generator = self._generate() self.generator = self._generate()
tag = self.database.get_tag_from_name(tag_text) tag = self.database.get_tag_from_name(tag_text)
if (not tag and remove):
# no tag by this name, and want to remove it
# nothing to do
return
trans_class = self.database.get_transaction_class() trans_class = self.database.get_transaction_class()
with trans_class("Tag Selected Items", self.database, batch=True) as trans: with trans_class("Tag Selected Items", self.database, batch=False) as trans:
if tag is None: if tag is None:
tag = self.database.get_table_func("Tag","class_func")() tag = self.database.get_table_func("Tag","class_func")()
tag.set_name(tag_text) tag.set_name(tag_text)
@ -2327,5 +2331,9 @@ class QuerySet(object):
for item in self.generator: for item in self.generator:
if tag.handle not in item.tag_list: if tag.handle not in item.tag_list:
item.add_tag(tag.handle) item.add_tag(tag.handle)
commit_func(item, trans) elif remove:
item.remove_tag(tag.handle)
else:
continue
commit_func(item, trans)

View File

@ -191,6 +191,14 @@ class BSDDBTest(unittest.TestCase):
result = self.db.Person.where(lambda person: person.tag_list.name == "Test").count() result = self.db.Person.where(lambda person: person.tag_list.name == "Test").count()
self.assertTrue(result == 1, result) self.assertTrue(result == 1, result)
def test_tag_2(self):
self.db.Person.where(lambda person: person.gramps_id == "I0001").tag("Test")
result = self.db.Person.where(lambda person: person.tag_list.name == "Test").count()
self.assertTrue(result == 1, result)
self.db.Person.where(lambda person: person.gramps_id == "I0001").tag("Test", remove=True)
result = self.db.Person.where(lambda person: person.tag_list.name == "Test").count()
self.assertTrue(result == 0, result)
def test_filter_1(self): def test_filter_1(self):
from gramps.gen.filters.rules.person import (IsDescendantOf, from gramps.gen.filters.rules.person import (IsDescendantOf,
IsAncestorOf) IsAncestorOf)