Source code for qumin.clustering.node

# -*- coding: utf-8 -*-
# !/usr/bin/env python3

# matplotlib is an optional dependency: plotting is disabled when it is missing.
try:
    from matplotlib import pyplot

    plt = pyplot
    matplotlib_loaded = True
except Exception:
    # Deliberately broad (backend initialisation can fail with errors other
    # than ImportError), but no longer a bare ``except:`` so that
    # KeyboardInterrupt and SystemExit are not swallowed.
    matplotlib_loaded = False

# networkx is an optional dependency: it is only needed for graph export and
# for the graphviz-based layouts.
try:
    import networkx as nx

    nx_loaded = True
except Exception:
    # Best-effort import guard; unlike a bare ``except:`` this lets
    # KeyboardInterrupt and SystemExit propagate.
    nx_loaded = False

import logging
import re
from collections import defaultdict

import numpy as np

log = logging.getLogger()


class Node(object):
    """Represent an inflection class tree or lattice.

    Attributes:
        labels (list): labels of all the leaves under this node.
        children (list): direct children of this node.
        attributes (dict): attributes for this node.
            Currently, the following attributes are expected:

            size (int): size of the group represented by this node.
            DL (float): Description length for this node.
            color (str): color of the splines from this node to its children,
                in a format usable by pyplot.
                Currently, red ("r") is used when the node didn't decrease
                Description length, blue ("b") otherwise.
            macroclass (bool): Is the node in a macroclass ?
            macroclass_root (bool): Is the node the root of a macroclass ?
        istree (bool): whether the structure under this node was a tree
            (no node with two parents) when the node was constructed.

    The attributes "_x" and "_rank" are reserved,
    and will be overwritten by the draw function.
    """

    def __init__(self, labels, children=None, **kwargs):
        """Node constructor.

        Arguments:
            labels (Iterable): labels of all the leaves under this node.
            children (list): direct children of this node.
            **kwargs: any other keyword argument will be added
                as node attributes. Note that certain algorithm expect
                the Node to have (int) "size", (str) "color",
                (bool) "macroclass", or (float) "DL" attributes.

        Note:
            The attributes "_x" and "_rank" are reserved,
            and will be overwritten by the draw function.
        """
        self.labels = sorted(labels)
        self.children = children if children else []
        self.attributes = kwargs
        self.istree = self._test_if_tree()

    def _test_if_tree(self):
        """Test if this is a tree by checking in-degree.

        Returns:
            bool: False as soon as some node is reached as the child of a
            second parent, True otherwise.
        """
        parents = {}
        for node in self:
            for child in node.children:
                if child in parents:
                    return False
                parents[child] = node
        return True

    def __str__(self):
        """Return a one-line description with the labels and attributes."""
        attrs = " - ".join(
            "{}={}".format(key, self.attributes[key]) for key in self.attributes)
        return "< Node object - " + ", ".join(self.labels) + " - " + attrs + ">"

    def macroclasses(self, parent_is_macroclass=False):
        """Find all the macroclasses nodes in this tree.

        Arguments:
            parent_is_macroclass (bool): whether the parent of this node
                is itself flagged as "macroclass".

        Returns:
            dict: maps the first label of each topmost macroclass node
            to the list of all its labels.

        Raises:
            KeyError: if a visited node has no "macroclass" attribute.
        """
        self_is_macroclass = self.attributes["macroclass"]

        if not parent_is_macroclass and self_is_macroclass:
            labels = self.labels
            return {labels[0]: labels}
        elif self.children:
            macroclasses_under = {}
            for child in self.children:
                child_macroclasses = child.macroclasses(
                    parent_is_macroclass=self_is_macroclass)
                macroclasses_under.update(child_macroclasses)
            return macroclasses_under
        return {}

    def _recursive_xy(self, ticks, node_spacing, max_x, y_factor=1):
        """Compute "_x" and "_y" for this node and everything below it.

        Leaves must already carry an "_x" attribute. Inner nodes are centered
        over their children, then moved if another node at the same height
        sits closer than half a spacing step.

        Arguments:
            ticks (dict): x positions already used, by y value (mutated).
            node_spacing (int): horizontal distance between adjacent leaves.
            max_x (number): right bound used for fallback positions.
            y_factor (number): base of the height increment between levels.

        Returns:
            tuple: the (x, y) position of this node.
        """
        # Nodes of a lattice can be reached several times: position them once.
        if self.attributes.get("_y", None) is None:
            half_step = node_spacing // 2
            if len(self.children) > 0:
                xs, ys = zip(
                    *[child._recursive_xy(ticks, node_spacing, max_x, y_factor)
                      for child in self.children])
                self.height = max([child.height for child in self.children]) + 1
                y = max(ys) + (y_factor ** (self.height - 1))
                xs = sorted(xs)
                x = xs[0] + ((xs[-1] - xs[0]) / 2)
                if y in ticks:
                    # If the preferred value is far enough, pick it
                    min_dist = min(abs(x - x2) for x2 in ticks[y])
                    if min_dist < half_step:
                        # We prefer candidates in the node span
                        candidates = np.arange(xs[0] - half_step,
                                               xs[-1] + half_step,
                                               half_step).tolist()
                        # Pick the candidate which is the furthest from
                        # existing points (first one wins in case of ties).
                        candidates = [(x1, min(abs(x1 - x2) for x2 in ticks[y]))
                                      for x1 in candidates]
                        x, min_dist = max(candidates, key=lambda cand: cand[1])

                        if min_dist < half_step:
                            # Fallback on more points, outside of the node span
                            candidates = np.arange(0 - half_step,
                                                   xs[0] - half_step,
                                                   half_step).tolist() \
                                         + np.arange(xs[-1] + half_step,
                                                     max_x + half_step,
                                                     half_step).tolist()
                            # NOTE(review): the distance to the preferred
                            # center was computed here but never used for
                            # the selection; only the distance to existing
                            # points decides.
                            candidates = [(x1, min(abs(x1 - x2) for x2 in ticks[y]))
                                          for x1 in candidates]
                            x, min_dist = max(candidates, key=lambda cand: cand[1])
                    ticks[y].append(x)
                else:
                    ticks[y] = [x]
                self.attributes["_x"] = x
            else:
                y = 1
                self.height = 1

            self.attributes["_y"] = y
        return self.attributes["_x"], self.attributes["_y"]

    def _erase_xy(self):
        """Remove the "_x" and "_y" attributes from this node and below."""
        if "_x" in self.attributes:
            del self.attributes["_x"]
        if "_y" in self.attributes:
            del self.attributes["_y"]
        for child in self.children:
            child._erase_xy()

    def _compute_xy(self, layout="qumin", pos=None, y_factor=1):
        """Annotate every node with "_x" and "_y" plotting coordinates.

        Arguments:
            layout (str): "qumin" (built-in layout, for trees) or "dot"
                (graphviz layout through networkx).
            pos (dict): precomputed positions by tuple of labels;
                only used by the networkx layouts.
            y_factor (number): passed to the "qumin" layout.

        Raises:
            NotImplementedError: if the layout is unknown.
            ImportError: if a networkx layout has to be computed
                and networkx could not be loaded.
        """
        # graphviz_layout is resolved lazily inside the lambda, so that the
        # "qumin" layout keeps working when networkx is not installed.
        nx_layouts = {
            "dot": lambda graph: nx.drawing.nx_agraph.graphviz_layout(graph, prog="dot"),
            # "spring":nx.drawing.spring_layout,
            # "kamada_kawai":nx.drawing.kamada_kawai_layout,
            # "radial": lambda x: graphviz_layout(x, prog="twopi"),
        }
        if "_y" in self.attributes and self.attributes["_y"] is not None:
            self._erase_xy()

        if layout == "qumin":  # For trees
            leaves_ordered = self._sort_leaves()
            x = 0
            step = 30
            for leaf in leaves_ordered:
                leaf.attributes["_x"] = x
                x += step
            self._recursive_xy({}, step, x, y_factor)
        elif layout in nx_layouts:
            infimum = Node(["infimum"])
            layout_f = nx_layouts[layout]

            # add infimum to improve node placement
            leaves = list(self.leaves())
            for leaf in leaves:
                leaf.children.append(infimum)
            try:
                if pos is None:
                    pos = layout_f(self.to_networkx())
            finally:
                # Always detach the infimum, even if the layout call failed.
                for leaf in leaves:
                    leaf.children = []

            for node in self:  # annotate x and y
                node.attributes["_x"], node.attributes["_y"] = pos[tuple(node.labels)]
        else:
            raise NotImplementedError("Layout method {} is not implemented. "
                                      "pick from: qumin, {}".format(layout, ", ".join(
                                          nx_layouts)))

    def _sort_leaves(self):
        """Sorts leaves by similarity for plotting.

        This is a greedy tsp implementation. The similarity of two leaves is
        the Jaccard index of their sets of ancestors.

        Returns:
            list: the leaves, ordered for plotting.
        """
        leaves = list(self.leaves())
        li = len(leaves)
        if li <= 1:
            # Nothing to order; the greedy loop below needs at least a pair.
            return leaves

        similarities = np.zeros((li, li))
        ancestors = defaultdict(set)
        for node in self:
            for label in node.labels:
                ancestors[(label,)].add(tuple(node.labels))

        for i, leaf in enumerate(leaves):
            for j, leaf2 in enumerate(leaves):
                if i != j:
                    a1 = ancestors[tuple(leaf.labels)]
                    a2 = ancestors[tuple(leaf2.labels)]
                    united = a1 | a2
                    jaccard = len(a1 & a2) / len(united) if united else 0
                    similarities[i, j] = similarities[j, i] = jaccard

        # paths[k] is the partial path which currently contains leaf k.
        paths = {i: [i] for i in range(li)}

        def shortest_path(sim):
            return np.unravel_index(np.argmax(sim, axis=None), sim.shape)

        i = None
        for _ in range(li - 1):
            i, j = shortest_path(similarities)
            i_start = paths[i][0]
            j_end = paths[j][-1]
            # i can't be followed, and j can't be preceded, by anything else.
            similarities[i, :] = float("-inf")
            similarities[:, j] = float("-inf")
            similarities[j_end, i_start] = float("-inf")  # don't loop
            res = paths[i] + paths[j]
            for k in res:
                paths[k] = res

        assert len(paths[i]) == li
        return [leaves[k] for k in paths[i]]

    def to_networkx(self):
        """Convert the structure under this node to a networkx DiGraph.

        Graph nodes are named by the tuple of labels of each Node and carry
        its attributes; edges go from parent to child.

        Raises:
            ImportError: if networkx could not be loaded.
        """
        if not nx_loaded:
            raise ImportError("Can't convert to networkx, because it couldn't be loaded.")
        G = nx.DiGraph()
        for node in self:
            name = tuple(node.labels)
            G.add_node(name, **node.attributes)
            for child in node.children:
                childname = tuple(child.labels)
                G.add_edge(name, childname)
        return G

    def __eq__(self, other):
        """Two nodes are equal if they have equal labels and sets of children."""
        if not isinstance(other, Node):
            # Let Python fall back on the other operand instead of crashing.
            return NotImplemented
        return (self.labels == other.labels) and \
               (set(self.children) == set(other.children))

    def __repr__(self):
        """Return the sorted "labels -> children labels" rules, one per line."""
        rules = [
            str(node.labels) + " -> " + " ".join(str(c.labels) for c in
                                                 sorted(node.children,
                                                        key=lambda x: x.labels))
            if node.children else str(node.labels)
            for node in self]
        return "\n".join(sorted(rules))

    def __hash__(self):
        return hash(repr(self))

    def __iter__(self):
        """Iterate depth-first (parents first) on this node and its descendants.

        Each node is yielded once, even when it has several parents. The
        traversal is completed before the first node is yielded.
        """
        # Visited nodes are tracked by identity in a local set, rather than
        # by temporarily writing a "visited" key in their attributes.
        seen = set()
        nodes = []
        stack = [self]
        while stack:
            node = stack.pop()
            if id(node) not in seen:
                seen.add(id(node))
                nodes.append(node)
                # Reversed, so that the first child is popped first.
                stack.extend(reversed(node.children))
        yield from nodes

    def leaves(self):
        """Return an iterator on the nodes under this node without children."""
        return filter(lambda x: not bool(x.children), self)

    def tree_string(self):
        """Return the inflection class tree as a string with parenthesis.

        Assumes size, DL and color attributes,
        with color = "r" if this is above a macroclass.

        Example:
            In the label, fields are separated by "#" as such::

                (<labels>#<size>#<DL>#<color> )

        Raises:
            NotImplementedError: if this node was not a tree at construction.
        """
        if not self.istree:
            raise NotImplementedError(
                "Tree string is only possible with trees, this looks like a graph.")
        labels = "&".join(self.labels)
        ignore = ["_rank", "_x", "_y", "macroclass", "macroclass_root",
                  "point_settings"]
        attributes = "#".join(
            "{}={}".format(key, str(self.attributes[key]).replace(" ", "_"))
            for key in sorted(self.attributes) if key not in ignore)

        children_str = [child.tree_string() for child in self.children]
        string = "(" + labels + "#" + attributes + " ".join([""] + children_str) + " )"
        return string

    def to_tikz(self, leavesfunc=None, nodefunc=None, layout="qumin", pos=None,
                ratio=(1, 1), width=20, color_attrs=None, y_factor=.8):
        """Return tikz code drawing the structure under this node.

        Each node is stored its index as a "tikz-label" attribute.

        Arguments:
            leavesfunc (Callable): maps a leaf to its label text.
                Defaults to a table of its "common" attribute under its label.
            nodefunc (Callable): maps an inner node to its label text.
                Defaults to a table of its "common" attribute.
            layout (str): layout keyword, either of "qumin" or "dot".
            pos (dict): precomputed positions, passed to the layout.
            ratio (tuple): (width, height) proportions of the drawing.
            width (number): target width; x coordinates are scaled to it.
            color_attrs (list): feature names; nodes with such a feature
                in their "common" attribute are filled in gray.
            y_factor (number): passed to the layout.

        Returns:
            str: the tikz commands, one per line.
        """
        color_attrs = color_attrs or []

        def attribute_table(node):
            table = []
            common = len(node.attributes.get("common", [])) > 0
            if not node.children:
                table.append(node.labels[0])
                if common:
                    table.append("\\\\")
            if common:
                table.append("\\begin{tabular}{r@{ }l}")
                for feature in node.attributes["common"]:
                    a, b = feature.split("=")
                    table.append("\\textsc{" + a + "} & " + b + " \\\\")
                table.append("\\end{tabular}")
            return "".join(table)

        def scale(m, max_obs, max_target):
            if not max_obs:
                # All observed values are 0 (e.g. a single node): nothing to
                # scale, and dividing would raise ZeroDivisionError.
                return 0.0
            return round((m / max_obs) * max_target, 2)

        nodefunc = nodefunc or attribute_table
        leavesfunc = leavesfunc or attribute_table

        self._compute_xy(pos=pos, layout=layout, y_factor=y_factor)

        node_template = "\\node ({name}) at ({x},{y}) [draw=none] {{}};"
        node_label = "\\node ({parent}-label) at ({parent}.south) [anchor={anchor_self},align=center{style}] {{{label}}};"
        edge_template = "\\draw ({a}-label.{anchor_a}) edge[-] ({b}.{anchor_b});"

        xys = list(zip(*[(n.attributes["_x"], n.attributes["_y"]) for n in self]))
        max_x = max(xys[0])
        max_y = max(xys[1])
        height = width * (ratio[1] / ratio[0])

        lines = []
        for i, node in enumerate(self):
            label_f = nodefunc if node.children else leavesfunc
            style = ",draw=none,fill=white,fill opacity=.8"
            if not node.attributes.get("common", False):
                style = ",draw=none,fill=none"
            if color_attrs and \
                    any(f.split("=")[0] in color_attrs
                        for f in node.attributes.get("common", [])):
                style = ",fill=gray!20"

            node_filled = node_template.format(
                name=str(i),
                x=scale(node.attributes["_x"], max_x, width),
                y=scale(node.attributes["_y"], max_y, height))
            node_text = node_label.format(parent=str(i),
                                          label=label_f(node),
                                          style=style,
                                          anchor_self="north",
                                          )
            lines.append(node_filled)
            lines.append(node_text)
            node.attributes["tikz-label"] = str(i)

        # Edges go on the background layer, under the node labels.
        lines.append("\\begin{pgfonlayer}{background}")
        for node in self:
            a = node.attributes["tikz-label"]
            for child in node.children:
                lines.append(edge_template.format(a=a,
                                                  b=child.attributes["tikz-label"],
                                                  anchor_a="south",
                                                  anchor_b="north"))
        lines.append("\\end{pgfonlayer}")
        return "\n".join(lines)

    def draw(self, horizontal=False, square=False,
             leavesfunc=lambda n: n.labels[0],
             nodefunc=None, label_rotation=None,
             annotateOnlyMacroclasses=False, point=None,
             edge_attributes=None, interactive=False, layout=False, pos=None,
             **kwargs):
        """Draw the tree as a dendrogram-style pyplot graph.

        Example::

                                    square=True        square=False

                                    │  ┌──┴──┐         │    ╱╲
            horizontal=False        │  │   ┌─┴─┐       │   ╱  ╲
                                    │  │   │   │       │  ╱   ╱╲
                                    │  │   │   │       │ ╱   ╱  ╲
                                    │__│___│___│       │╱___╱____╲

                                    │─────┐            │⟍
                                    │───┐ ├            │  ⟍
            horizontal=True         │   ├─┘            │⟍ ⟋
                                    │───┘              │⟋
                                    │____________      │____________

        Arguments:
            horizontal (bool):
                Should the tree be drawn with leaves on the y axis ?
                (Defaults to False: leaves on x axis).
            square (bool):
                Should the tree splines be squared with 90° angles ?
                (Defaults to False)
            leavesfunc (Callable):
                A function that will be applied to leaves
                before writing them down. Takes a Node, returns a str.
            nodefunc (Callable):
                A function that will be applied to nodes
                to annotate them. Takes a Node, returns a str.
                If None (the default), inner nodes are not annotated.
            label_rotation (number):
                Rotation of the labels. Defaults to 0 for horizontal trees,
                45 otherwise.
            annotateOnlyMacroclasses :
                For macroclass history trees:
                If `True` and nodefunc isn't `None`,
                only the macroclasses nodes are annotated.
            point (Callable):
                A function that maps a node to point attributes.
            edge_attributes (Callable):
                A function that maps a pair of nodes to edge attributes.
                By default, use the parent's color and "-" linestyle.
            interactive (bool):
                Whether this is destined to create an interactive plot.
            layout (str):
                layout keyword, either of "qumin" or "dot".
                Any false value (the default) stands for "qumin".
            pos (dict):
                A dictionary of node label to x,y positions.
                Compatible with networkx layout functions.
                Only used by the "dot" layout, which computes
                networkx's graphviz layout if it is absent.
            **kwargs: ignored.

        Returns:
            If matplotlib could not be loaded, the string description of
            this node. Otherwise a tuple of the list of drawn artists
            and the list of nodes.
        """
        # The historical default (False) is not a layout name and used to
        # make every default call fail: fall back on the built-in layout.
        self._compute_xy(pos=pos, layout=layout or "qumin")

        if not matplotlib_loaded:
            return str(self)

        def annotate(node):
            should_annotate = (not annotateOnlyMacroclasses) or \
                              node.attributes["macroclass_root"]
            if should_annotate and nodefunc is not None:
                return str(nodefunc(node))
            return ""

        def default_edge_attributes(node, child):
            # NOTE(review): this is only called on nodes with children, so
            # the linestyle is always "-"; the documented "--" for leaves
            # suggests that `child.children` was intended. To be confirmed.
            attributes = {"linestyle": "-" if node.children else "--",
                          "color": node.attributes.get("color", "#333333")}
            if edge_attributes is not None:
                attributes.update(edge_attributes(node, child))
            return attributes

        if horizontal:
            textoffset = (5, 0)
            lva = "center"
            lha = "right"
            va = "center"
            ha = "left"
            r = 0

            def coords(node):
                return node.attributes["_y"], node.attributes["_x"]
        else:
            lva = "top"
            lha = "center"
            va = "bottom"
            ha = "center"
            textoffset = (0, 5)
            r = 45

            def coords(node):
                return node.attributes["_x"], node.attributes["_y"]

        if label_rotation is not None:
            r = label_rotation

        ax = plt.gca()

        lines = []
        all_nodes = []
        for node in self:
            this_x, this_y = coords(node)

            # Plot the arcs
            for child in node.children:
                child_x, child_y = coords(child)
                attr = default_edge_attributes(node, child)
                if square:
                    if horizontal:
                        drawn = ax.plot((this_x, this_x, child_x),
                                        (this_y, child_y, child_y), **attr)
                    else:
                        drawn = ax.plot((this_x, child_x, child_x),
                                        (this_y, this_y, child_y), **attr)
                else:
                    drawn = ax.plot((this_x, child_x), (this_y, child_y), **attr)
                lines.extend(drawn)

            # Plot the point
            if point is not None:
                coll = ax.scatter((this_x,), (this_y,), **point(node))
                lines.append(coll)

            # Write annotations
            label_f = nodefunc if node.children else leavesfunc
            if node.labels and label_f is not None:
                plt.annotate(label_f(node), xy=(this_x, this_y),
                             xycoords='data', va=lva, ha=lha, rotation=r)
            else:
                # Without a labelling function (e.g. inner nodes with the
                # default nodefunc=None), annotate() yields an empty text
                # instead of calling None.
                text = annotate(node)
                if text is not None:
                    plt.annotate(text, xy=(this_x, this_y),
                                 va=va, ha=ha,
                                 textcoords='offset points',
                                 xytext=textoffset)
            all_nodes.append(node)

        # Scale axes
        ax.autoscale()

        if interactive:
            if horizontal:
                ax.margins(x=0.3, y=0.1)
            else:
                ax.margins(y=0.3, x=0.1)

        # tick_params expects booleans: the string 'off' is truthy.
        plt.tick_params(
            axis='both',  # changes apply to both axes
            which='both',  # both major and minor ticks
            top=False,  # ticks along the top edge
            bottom=False,  # ticks along the bottom edge
            right=False,  # ticks along the right edge
            left=False,  # ticks along the left edge
            labelbottom=False
        )
        plt.yticks([], [])
        plt.xticks([], [])
        return lines, all_nodes
def string_to_node(string, legacy_annotation_name=None):
    """Parse an inflection tree written as a string.

    Example:
        In the label, fields are separated by "#" as such::

            (<labels>#<size>#<DL>#<color> (... ) (... ) )

    A string without any "=" is read in legacy mode, where each node has
    exactly the four positional fields above. Otherwise, the fields after
    the labels are read as "<name>=<value>" attributes.

    Arguments:
        string (str): the tree, in the format written by `Node.tree_string`.
        legacy_annotation_name (str): name of the attribute in which the
            third legacy field is stored, and which is converted to float
            when possible. Defaults to "DL".

    Returns:
        Node: The root of the tree
    """
    legacy = "=" not in string

    def parse_node(line):
        fields = line[1:].split("#")
        labels = fields[0]
        # Split on the first "=" only, so that values may contain "=".
        attributes = dict(attr.split("=", 1) for attr in fields[1:])
        return labels, attributes

    def parse_node_legacy(item):
        return item[1:].split("#")

    stack = []

    # More robust than splitting on every space:
    # allows spaces inside lexemes & attributes.
    items = re.split(r" (?=[()])", string)

    annotation_name = legacy_annotation_name if legacy_annotation_name else "DL"

    for item in items:
        # startswith (rather than item[0]) tolerates empty tokens.
        if item.startswith("("):
            if legacy:
                # Backward compatibility mode
                labels, size, annotation, color = parse_node_legacy(item)
                if not color:
                    color = "c"
                if not annotation:
                    annotation = ""
                # Red nodes are above the macroclasses, so only the other
                # nodes are in a macroclass (as in the non-legacy branch;
                # this flag used to be inverted here).
                attributes = {"size": float(size), annotation_name: annotation,
                              "color": color,
                              "macroclass": color != "r"}
            else:
                labels, attributes = parse_node(item)
                if 'color' not in attributes:
                    attributes['color'] = "c"
                attributes['macroclass'] = attributes['color'] != 'r'

            labels = sorted(labels.split("&"))
            attributes['macroclass_root'] = False
            if annotation_name in attributes:
                try:
                    attributes[annotation_name] = float(attributes[annotation_name])
                except ValueError:
                    pass  # only convert numbers

            stack.append(Node(labels, **attributes))

        elif item.startswith(")"):
            if len(item) > 1:
                log.warning("Warning, bad format ! #{}#".format(item))
            if len(stack) > 1:
                child = stack.pop(-1)
                parent = stack[-1]
                # Macroclass are one level below the red :
                child.attributes['macroclass_root'] = \
                    parent.attributes['color'] == 'r' and \
                    child.attributes['color'] != 'r'
                parent.children.append(child)
            else:
                # The parenthesis which closes the root.
                return stack[0]

    if len(stack) > 1:
        log.warning("unmatched parenthesis or no root ! " + str(stack))
    log.info(stack[0])
    return stack[0]