Clean up implementation of document classes

pull/58/head
Michael X. Grey 2018-05-10 12:55:19 -07:00
parent f5a7fe77d1
commit be675f1b1c
1 changed files with 89 additions and 24 deletions

View File

@ -55,6 +55,7 @@ import collections
import xml.etree.cElementTree as etree import xml.etree.cElementTree as etree
import xml.etree.ElementTree.Element as Element import xml.etree.ElementTree.Element as Element
import xml.etree.ElementTree.SubElement as SubElement import xml.etree.ElementTree.SubElement as SubElement
import warnings
# Internal dependencies # Internal dependencies
from .parser import parse_path from .parser import parse_path
@ -79,7 +80,11 @@ def flatten_all_paths(
group_filter=lambda x: True, group_filter=lambda x: True,
path_filter=lambda x: True, path_filter=lambda x: True,
path_conversions=CONVERSIONS): path_conversions=CONVERSIONS):
"""Returns the paths inside a group (recursively), expressing the paths in the root coordinates. """Returns the paths inside a group (recursively), expressing the paths in the base coordinates.
Note that if the group being passed in is nested inside some parent group(s), we cannot take the parent group(s)
into account, because xml.etree.Element has no pointer to its parent. You should use Document.flatten_group(group)
to flatten a specific nested group into the root coordinates.
Args: Args:
group is an Element group is an Element
@ -128,6 +133,38 @@ def flatten_all_paths(
return paths return paths
def flatten_group(
group_to_flatten,
root,
recursive=True,
group_filter=lambda x: True,
path_filter=lambda x: True,
path_conversions=CONVERSIONS):
"""Flatten all the paths in a specific group.
The paths will be flattened into the 'root' frame. Note that root needs to be
an ancestor of the group that is being flattened. Otherwise, no paths will be returned."""
if not any(group_to_flatten is descendant for descendant in root.iter()):
warnings.warn('The requested group_to_flatten is not a descendant of root')
# We will shortcut here, because it is impossible for any paths to be returned anyhow.
return []
# We create a set of the unique IDs of each group object that we want to flatten.
# Any groups outside of this set will be skipped while we flatten the paths.
desired_groups = set()
if recursive:
for group in group_to_flatten.iter():
desired_groups.add(id(group))
else:
desired_groups.add(id(group_to_flatten))
def desired_group_filter(x):
return (id(x) in desired_groups) and group_filter(x)
return flatten_all_paths(root, desired_group_filter, path_filter, path_conversions)
class Document: class Document:
def __init__(self, filename): def __init__(self, filename):
"""A container for a DOM-style SVG document. """A container for a DOM-style SVG document.
@ -167,6 +204,21 @@ class Document:
path_conversions=CONVERSIONS): path_conversions=CONVERSIONS):
return flatten_all_paths(self.tree.getroot(), group_filter, path_filter, path_conversions) return flatten_all_paths(self.tree.getroot(), group_filter, path_filter, path_conversions)
def flatten_group(self,
group,
recursive=True,
group_filter=lambda x: True,
path_filter=lambda x: True,
path_conversions=CONVERSIONS):
if all(isinstance(s, basestring) for s in group):
# If we're given a list of strings, assume it represents a nested sequence
group = self.get_or_add_group(group)
elif not isinstance(group, Element):
raise TypeError('Must provide a list of strings that represent a nested group name, '
'or provide an xml.etree.Element object')
return flatten_group(group, self.tree.getroot(), recursive, group_filter, path_filter, path_conversions)
def get_elements_by_tag(self, tag): def get_elements_by_tag(self, tag):
"""Returns a generator of all elements with the given tag. """Returns a generator of all elements with the given tag.
@ -175,18 +227,15 @@ class Document:
""" """
return self.tree.iter(tag=self._prefix + tag) return self.tree.iter(tag=self._prefix + tag)
def convert_pathlike_elements_to_paths(self, conversions=CONVERSIONS):
raise NotImplementedError
def get_svg_attributes(self): def get_svg_attributes(self):
"""To help with backwards compatibility.""" """To help with backwards compatibility."""
return self.get_elements_by_tag('svg')[0].attrib return self.get_elements_by_tag('svg')[0].attrib
def get_path_attributes(self): def get_path_attributes(self):
"""To help with backwards compatibility.""" """To help with backwards compatibility."""
return [p.tree_element.attrib for p in self.paths] return [p.tree_element.attrib for p in self.tree.getroot().iter('path')]
def add_path(self, path, attribs={}, group=None): def add_path(self, path, attribs=None, group=None):
"""Add a new path to the SVG.""" """Add a new path to the SVG."""
# If we are not given a parent, assume that the path does not have a group # If we are not given a parent, assume that the path does not have a group
@ -200,18 +249,32 @@ class Document:
elif not isinstance(group, Element): elif not isinstance(group, Element):
raise TypeError('Must provide a list of strings or an xml.etree.Element object') raise TypeError('Must provide a list of strings or an xml.etree.Element object')
# TODO: If the user passes in an xml.etree.Element object, should we check to make sure that it actually else:
# belongs to this Document object? # Make sure that the group belongs to this Document object
if not self.contains_group(group):
warnings.warn('The requested group does not belong to this Document')
if isinstance(path, Path): if isinstance(path, Path):
path_svg = path.d() path_svg = path.d()
elif is_path_segment(path): elif is_path_segment(path):
path_svg = Path(path).d() path_svg = Path(path).d()
elif isinstance(path, basestring): elif isinstance(path, basestring):
# Assume this is a valid d-string TODO: Should we sanity check the input string? # Assume this is a valid d-string. TODO: Should we sanity check the input string?
path_svg = path path_svg = path
else:
raise TypeError('Must provide a Path, a path segment type, or a valid SVG path d-string')
return SubElement(group, 'path', {'d': path_svg}) if attribs is None:
attribs = {}
else:
attribs = attribs.copy()
attribs['d'] = path_svg
return SubElement(group, 'path', attribs)
def contains_group(self, group):
return any(group is owned for owned in self.tree.iter())
def get_or_add_group(self, nested_names): def get_or_add_group(self, nested_names):
"""Get a group from the tree, or add a new one with the given name structure. """Get a group from the tree, or add a new one with the given name structure.
@ -228,9 +291,10 @@ class Document:
while nested_names: while nested_names:
prev_group = group prev_group = group
next_name = nested_names.pop(0) next_name = nested_names.pop(0)
for elem in group.iter(): for elem in group.iterfind('g'):
if elem.get('id') == next_name: if elem.get('id') == next_name:
group = elem group = elem
break
if prev_group is group: if prev_group is group:
# The group we're looking for does not exist, so let's create the group structure # The group we're looking for does not exist, so let's create the group structure
@ -244,31 +308,32 @@ class Document:
return group return group
def add_group(self, group_attribs={}, parent=None): def add_group(self, group_attribs=None, parent=None):
"""Add an empty group element to the SVG.""" """Add an empty group element to the SVG."""
if parent is None: if parent is None:
parent = self.tree.getroot() parent = self.tree.getroot()
raise NotImplementedError elif not self.contains_group(parent):
warnings.warn('The requested group does not belong to this Document')
def update_tree(self): if group_attribs is None:
"""Rewrite d-string's for each path in the `tree` attribute.""" group_attribs = {}
raise NotImplementedError else:
group_attribs = group_attribs.copy()
def save(self, filename, update=True): return SubElement(parent, 'g', group_attribs)
"""Write to svg to a file."""
if update: def save(self, filename=None):
self.update_tree() if filename is None:
filename = self.original_filename
with open(filename, 'w') as output_svg: with open(filename, 'w') as output_svg:
output_svg.write(etree.tostring(self.tree.getroot())) output_svg.write(etree.tostring(self.tree.getroot()))
def display(self, filename=None, update=True): def display(self, filename=None):
"""Displays/opens the doc using the OS's default application.""" """Displays/opens the doc using the OS's default application."""
if update:
self.update_tree()
if filename is None: if filename is None:
raise NotImplementedError filename = self.original_filename
# write to a (by default temporary) file # write to a (by default temporary) file
with open(filename, 'w') as output_svg: with open(filename, 'w') as output_svg: