Class Diagrams

This is a simple viewer for class diagrams, customized for the book.

Prerequisites

Synopsis

To use the code provided in this chapter, write

>>> from fuzzingbook.ClassDiagram import <identifier>

and then make use of the following features.

The display_class_hierarchy() function shows the class hierarchy for the given class (or list of classes).

  • The keyword parameter public_methods, if given, is a list of "public" methods to be used by clients (default: all methods with docstrings).
  • The keyword parameter abstract_classes, if given, is a list of classes to be displayed as "abstract" (i.e., with the class name in italics).
>>> display_class_hierarchy(D_Class, abstract_classes=[A_Class])

[Figure: class diagram for D_Class. D_Class (foo()) inherits from B_Class (VAR, bar(), foo()) and C_Class (qux()); B_Class inherits from A_Class (foo(), quux(), second()). A legend distinguishes public_method(), private_method(), and overloaded_method(); hovering over a name shows its documentation.]

Getting a Class Hierarchy

import inspect
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

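Using mro(), we can access the class hierarchy. Redefining a class as class X(X), as this notebook does to extend a class across cells, puts two classes with the same name into the result; we keep only the first of each such run.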

def class_hierarchy(cls: Type) -> List[Type]:
    superclasses = cls.mro()
    hierarchy = []
    last_superclass_name = ""

    for superclass in superclasses:
        if superclass.__name__ != last_superclass_name:
            hierarchy.append(superclass)
            last_superclass_name = superclass.__name__

    return hierarchy

Here's an example:

class A_Class:
    """A Class which does A thing right.
    Comes with a longer docstring."""

    def foo(self) -> None:
        """The Adventures of the glorious Foo"""
        pass

    def quux(self) -> None:
        """A method that is not used."""
        pass
class A_Class(A_Class):
    # We define another function in a separate cell.

    def second(self) -> None:
        pass
class B_Class(A_Class):
    """A subclass inheriting some methods."""

    VAR = "A variable"

    def foo(self) -> None:
        """A WW2 foo fighter."""
        pass

    def bar(self, qux: Any = None, bartender: int = 42) -> None:
        """A qux walks into a bar.
        `bartender` is an optional attribute."""
        pass
SomeType = List[Optional[Union[str, int]]]
class C_Class:
    """A class injecting some method"""

    def qux(self, arg: SomeType) -> SomeType:
        return arg
class D_Class(B_Class, C_Class):
    """A subclass inheriting from multiple superclasses.
    Comes with a fairly long, but meaningless documentation."""

    def foo(self) -> None:
        B_Class.foo(self)
class D_Class(D_Class):
    pass  # An incremental addition that should not impact D's semantics
class_hierarchy(D_Class)
[__main__.D_Class,
 __main__.B_Class,
 __main__.A_Class,
 __main__.C_Class,
 object]
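
To see why the name check in class_hierarchy() is needed, here is a minimal, self-contained sketch. The names Base and dedup_by_name() are hypothetical and not part of this chapter; the sketch only assumes that a class is redefined as a subclass of itself, as A_Class and D_Class are above.

```python
from typing import List


class Base:
    def first(self) -> None:
        pass


class Base(Base):  # redefinition: extends the earlier Base under the same name
    def second(self) -> None:
        pass


def dedup_by_name(classes: List[type]) -> List[type]:
    """Drop consecutive classes that share a name, keeping the first one."""
    result: List[type] = []
    for cls in classes:
        if not result or result[-1].__name__ != cls.__name__:
            result.append(cls)
    return result


# The raw MRO lists both versions of Base; the filtered one lists it once.
raw_names = [cls.__name__ for cls in Base.mro()]
unique_names = [cls.__name__ for cls in dedup_by_name(Base.mro())]

print(raw_names)     # ['Base', 'Base', 'object']
print(unique_names)  # ['Base', 'object']
```

The class that is kept is the most recent definition, which is also the one that carries all methods of the earlier definitions through inheritance.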

Getting a Class Tree

We can use __bases__ to obtain the immediate base classes.

D_Class.__bases__
(__main__.D_Class,)

class_tree() returns a class tree, using the "lowest" (most specialized) class with the same name.

def class_tree(cls: Type, lowest: Optional[Type] = None) -> List[Tuple[Type, List]]:
    ret = []
    for base in cls.__bases__:
        if base.__name__ == cls.__name__:
            if not lowest:
                lowest = cls
            ret += class_tree(base, lowest)
        else:
            if lowest:
                cls = lowest
            ret.append((cls, class_tree(base)))

    return ret
class_tree(D_Class)
[(__main__.D_Class, [(__main__.B_Class, [(__main__.A_Class, [])])]),
 (__main__.D_Class, [(__main__.C_Class, [])])]
class_tree(D_Class)[0][0]
__main__.D_Class
assert class_tree(D_Class)[0][0] == D_Class
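
The nested list returned by class_tree() holds one (class, subtree) entry per direct base class. The following sketch shows one way to read such a structure by printing it as an indented outline. The helpers simple_class_tree() and outline() and the example classes are hypothetical; simple_class_tree() assumes all class names are distinct and therefore omits the same-name handling of class_tree().

```python
from typing import List, Tuple


class Animal:
    pass


class Walker(Animal):
    pass


class Swimmer:
    pass


class Duck(Walker, Swimmer):
    pass


def simple_class_tree(cls: type) -> List[Tuple[type, list]]:
    """One (cls, subtree) entry per direct base; `object` yields an empty list."""
    return [(cls, simple_class_tree(base)) for base in cls.__bases__]


def outline(tree: List[Tuple[type, list]], depth: int = 0) -> List[str]:
    """Render the tree as lines, indenting each level by two spaces."""
    lines: List[str] = []
    for cls, subtrees in tree:
        lines.append("  " * depth + cls.__name__)
        lines.extend(outline(subtrees, depth + 1))
    return lines


duck_outline = outline(simple_class_tree(Duck))
print("\n".join(duck_outline))
```

Duck appears twice in the outline, once per base class, just as D_Class appears twice in the output of class_tree(D_Class) above.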

class_set() flattens the tree into a set:

def class_set(classes: Union[Type, List[Type]]) -> Set[Type]:
    if not isinstance(classes, list):
        classes = [classes]

    ret = set()

    def traverse_tree(tree: List[Tuple[Type, List]]) -> None:
        for (cls, subtrees) in tree:
            ret.add(cls)
            # traverse_tree() already iterates over all entries of `subtrees`
            traverse_tree(subtrees)

    for cls in classes:
        traverse_tree(class_tree(cls))

    return ret
class_set(D_Class)
{__main__.A_Class, __main__.B_Class, __main__.C_Class, __main__.D_Class}
assert A_Class in class_set(D_Class)
assert B_Class in class_set(D_Class)
assert C_Class in class_set(D_Class)
assert D_Class in class_set(D_Class)
class_set([B_Class, C_Class])
{__main__.A_Class, __main__.B_Class, __main__.C_Class}

Getting Docs

A_Class.__doc__
A_Class.__bases__[0].__doc__
'A Class which does A thing right.\n    Comes with a longer docstring.'
A_Class.__bases__[0].__name__
'A_Class'
D_Class.foo
<function __main__.D_Class.foo(self) -> None>
D_Class.foo.__doc__
A_Class.foo.__doc__
'The Adventures of the glorious Foo'
def docstring(obj: Any) -> str:
    doc = inspect.getdoc(obj)
    return doc if doc else ""
docstring(A_Class)
'A Class which does A thing right.\nComes with a longer docstring.'
docstring(D_Class.foo)
'A WW2 foo fighter.'
def unknown() -> None:
    pass
docstring(unknown)
''
import html
import re
def escape(text: str) -> str:
    text = html.escape(text)
    assert '<' not in text
    assert '>' not in text
    text = text.replace('{', '&#x7b;')
    text = text.replace('|', '&#x7c;')
    text = text.replace('}', '&#x7d;')
    return text
escape("f(foo={})")
'f(foo=&#x7b;&#x7d;)'
def escape_doc(docstring: str) -> str:
    DOC_INDENT = 0
    docstring = "&#x0a;".join(
        ' ' * DOC_INDENT + escape(line).strip()
        for line in docstring.split('\n')
    )
    return docstring
print(escape_doc("'Hello\n    {You|Me}'"))
&#x27;Hello&#x0a;&#x7b;You&#x7c;Me&#x7d;&#x27;
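
The node labels built later in this chapter use `{`, `|`, and `}` as structural characters (see the `spec` strings in display_class_hierarchy()), which is presumably why escape() replaces them in addition to what html.escape() covers. The sketch below illustrates the two steps and checks that the result can be turned back into the original text. The names RECORD_SPECIALS and escape_for_label() are hypothetical.

```python
import html

# Characters that html.escape() leaves alone but that we also want encoded
RECORD_SPECIALS = {'{': '&#x7b;', '|': '&#x7c;', '}': '&#x7d;'}


def escape_for_label(text: str) -> str:
    """Escape HTML specials first, then the three structural characters."""
    text = html.escape(text)
    for char, entity in RECORD_SPECIALS.items():
        text = text.replace(char, entity)
    return text


original = "<{a|b}>"
html_only = html.escape(original)      # braces and bar survive
escaped = escape_for_label(original)   # everything is encoded
restored = html.unescape(escaped)      # entities decode back to the input

print(html_only)  # &lt;{a|b}&gt;
print(escaped)    # &lt;&#x7b;a&#x7c;b&#x7d;&gt;
print(restored)   # <{a|b}>
```

The order matters: running html.escape() after the replacements would encode the `&` of the inserted entities a second time.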

Getting Methods and Variables

inspect.getmembers(D_Class)
[('VAR', 'A variable'),
 ('__class__', type),
 ('__delattr__', <slot wrapper '__delattr__' of 'object' objects>),
 ('__dict__', mappingproxy({'__module__': '__main__', '__doc__': None})),
 ('__dir__', <method '__dir__' of 'object' objects>),
 ('__doc__', None),
 ('__eq__', <slot wrapper '__eq__' of 'object' objects>),
 ('__format__', <method '__format__' of 'object' objects>),
 ('__ge__', <slot wrapper '__ge__' of 'object' objects>),
 ('__getattribute__', <slot wrapper '__getattribute__' of 'object' objects>),
 ('__gt__', <slot wrapper '__gt__' of 'object' objects>),
 ('__hash__', <slot wrapper '__hash__' of 'object' objects>),
 ('__init__', <slot wrapper '__init__' of 'object' objects>),
 ('__init_subclass__', <function D_Class.__init_subclass__>),
 ('__le__', <slot wrapper '__le__' of 'object' objects>),
 ('__lt__', <slot wrapper '__lt__' of 'object' objects>),
 ('__module__', '__main__'),
 ('__ne__', <slot wrapper '__ne__' of 'object' objects>),
 ('__new__', <function object.__new__(*args, **kwargs)>),
 ('__reduce__', <method '__reduce__' of 'object' objects>),
 ('__reduce_ex__', <method '__reduce_ex__' of 'object' objects>),
 ('__repr__', <slot wrapper '__repr__' of 'object' objects>),
 ('__setattr__', <slot wrapper '__setattr__' of 'object' objects>),
 ('__sizeof__', <method '__sizeof__' of 'object' objects>),
 ('__str__', <slot wrapper '__str__' of 'object' objects>),
 ('__subclasshook__', <function D_Class.__subclasshook__>),
 ('__weakref__', <attribute '__weakref__' of 'A_Class' objects>),
 ('bar',
  <function __main__.B_Class.bar(self, qux: Any = None, bartender: int = 42) -> None>),
 ('foo', <function __main__.D_Class.foo(self) -> None>),
 ('quux', <function __main__.A_Class.quux(self) -> None>),
 ('qux',
  <function __main__.C_Class.qux(self, arg: List[Union[str, int, NoneType]]) -> List[Union[str, int, NoneType]]>),
 ('second', <function __main__.A_Class.second(self) -> None>)]
def class_items(cls: Type, pred: Callable) -> List[Tuple[str, Any]]:
    def _class_items(cls: Type) -> List:
        all_items = inspect.getmembers(cls, pred)
        for base in cls.__bases__:
            all_items += _class_items(base)

        return all_items

    unique_items = []
    items_seen = set()
    for (name, item) in _class_items(cls):
        if name not in items_seen:
            unique_items.append((name, item))
            items_seen.add(name)

    return unique_items
def class_methods(cls: Type) -> List[Tuple[str, Callable]]:
    return class_items(cls, inspect.isfunction)
def defined_in(name: str, cls: Type) -> bool:
    if not hasattr(cls, name):
        return False

    defining_classes = []

    def search_superclasses(name: str, cls: Type) -> None:
        if not hasattr(cls, name):
            return

        for base in cls.__bases__:
            if hasattr(base, name):
                defining_classes.append(base)
                search_superclasses(name, base)

    search_superclasses(name, cls)

    if any(cls.__name__ != c.__name__ for c in defining_classes):
        return False  # Already defined in superclass

    return True
assert not defined_in('VAR', A_Class)
assert defined_in('VAR', B_Class)
assert not defined_in('VAR', C_Class)
assert not defined_in('VAR', D_Class)
def class_vars(cls: Type) -> List[Any]:
    def is_var(item: Any) -> bool:
        return not callable(item)

    return [item for item in class_items(cls, is_var) 
            if not item[0].startswith('__') and defined_in(item[0], cls)]
class_methods(D_Class)
[('bar',
  <function __main__.B_Class.bar(self, qux: Any = None, bartender: int = 42) -> None>),
 ('foo', <function __main__.D_Class.foo(self) -> None>),
 ('quux', <function __main__.A_Class.quux(self) -> None>),
 ('qux',
  <function __main__.C_Class.qux(self, arg: List[Union[str, int, NoneType]]) -> List[Union[str, int, NoneType]]>),
 ('second', <function __main__.A_Class.second(self) -> None>)]
class_vars(B_Class)
[('VAR', 'A variable')]

We're only interested in

  • functions defined in that class
  • functions that come with a docstring
def public_class_methods(cls: Type) -> List[Tuple[str, Callable]]:
    return [(name, method) for (name, method) in class_methods(cls) 
            if method.__qualname__.startswith(cls.__name__)]
def doc_class_methods(cls: Type) -> List[Tuple[str, Callable]]:
    return [(name, method) for (name, method) in public_class_methods(cls) 
            if docstring(method)]  # docstring() returns '' (never None) if undocumented
public_class_methods(D_Class)
[('foo', <function __main__.D_Class.foo(self) -> None>)]
doc_class_methods(D_Class)
[('foo', <function __main__.D_Class.foo(self) -> None>)]
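
The "defined in that class" test above rests on the qualified name of each function, which records the class it was written in. Here is a minimal sketch of that idea with hypothetical classes Parser and FastParser. As a deliberate variation (an assumption, not the chapter's code), it compares against cls.__qualname__ followed by a dot rather than cls.__name__, so that a class name that is merely a prefix of another class name does not match.

```python
import inspect
from typing import List


class Parser:
    def parse(self) -> None:
        """Parse the input."""


class FastParser(Parser):
    def tune(self) -> None:
        """Tune for speed."""


def local_methods(cls: type) -> List[str]:
    """Names of functions whose qualified name places them inside `cls`."""
    prefix = cls.__qualname__ + "."
    return [name
            for name, function in inspect.getmembers(cls, inspect.isfunction)
            if function.__qualname__.startswith(prefix)]


print(local_methods(Parser))      # ['parse']
print(local_methods(FastParser))  # ['tune']; parse() is inherited, so it is skipped
```

Because the check is name-based, a class redefined as class X(X) still counts the methods of its earlier definitions as its own, which is the behavior the chapter relies on.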
def overloaded_class_methods(classes: Union[Type, List[Type]]) -> Set[str]:
    all_methods: Dict[str, Set[Type]] = {}  # method name -> classes defining it
    for cls in class_set(classes):
        for (name, method) in class_methods(cls):
            if method.__qualname__.startswith(cls.__name__):
                all_methods.setdefault(name, set())
                all_methods[name].add(cls)

    return set(name for name in all_methods if len(all_methods[name]) >= 2)
overloaded_class_methods(D_Class)
{'foo'}
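
The idea behind overloaded_class_methods() is to collect, for each method name, the set of classes that define it, and to report the names with two or more definers. The following sketch shows the same idea in a self-contained form. The function overridden_names() and the example classes are hypothetical, and the sketch inspects each class's own namespace via vars() instead of using the chapter's __qualname__ check, so it treats a class redefined as class X(X) as two separate classes.

```python
import inspect
from collections import defaultdict
from typing import Dict, Set


class Base:
    def run(self) -> None:
        pass

    def stop(self) -> None:
        pass


class Child(Base):
    def run(self) -> None:  # redefines Base.run
        pass


def overridden_names(cls: type) -> Set[str]:
    """Names of functions defined by two or more classes in the MRO of `cls`."""
    definers: Dict[str, Set[type]] = defaultdict(set)
    for klass in cls.__mro__:
        for name, member in vars(klass).items():
            if inspect.isfunction(member):
                definers[name].add(klass)
    return {name for name, classes in definers.items() if len(classes) >= 2}


print(overridden_names(Child))  # {'run'}
print(overridden_names(Base))   # set()
```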

Drawing Class Hierarchy with Method Names

from inspect import signature
import warnings
import os
def display_class_hierarchy(classes: Union[Type, List[Type]], *,
                            public_methods: Optional[List] = None,
                            abstract_classes: Optional[List] = None,
                            include_methods: bool = True,
                            include_class_vars: bool = True,
                            include_legend: bool = True,
                            local_defs_only: bool = True,
                            types: Dict[str, Any] = {},
                            project: str = 'fuzzingbook',
                            log: bool = False) -> Any:
    """Visualize a class hierarchy.
`classes` is a Python class (or a list of classes) to be visualized.
`public_methods`, if given, is a list of methods to be shown as "public" (bold).
  (Default: all methods with a docstring)
`abstract_classes`, if given, is a list of classes to be shown as "abstract" (italic).
  (Default: all classes with an abstract method)
`include_methods`: if set (default), include all methods
`include_class_vars`: if set (default), include class variables
`include_legend`: if set (default), include a legend
`local_defs_only`: if set (default), hide details of imported classes
`types`: type names with definitions, to be used in docs
`project`: 'fuzzingbook' (default) or 'debuggingbook'; selects fonts and colors
`log`: if set, print the classes, variables, and methods being drawn
    """
    from graphviz import Digraph

    if project == 'debuggingbook':
        CLASS_FONT = 'Raleway, Helvetica, Arial, sans-serif'
        CLASS_COLOR = '#6A0DAD'  # a shade of purple
    else:
        CLASS_FONT = 'Patua One, Helvetica, sans-serif'
        CLASS_COLOR = '#B03A2E'

    METHOD_FONT = "'Fira Mono', 'Source Code Pro', 'Courier', monospace"
    METHOD_COLOR = 'black'

    if isinstance(classes, list):
        starting_class = classes[0]
    else:
        starting_class = classes
        classes = [starting_class]

    title = starting_class.__name__ + " class hierarchy"

    dot = Digraph(comment=title)
    dot.attr('node', shape='record', fontname=CLASS_FONT)
    dot.attr('graph', rankdir='BT', tooltip=title)
    dot.attr('edge', arrowhead='empty')
    edges = set()
    overloaded_methods: Set[str] = set()

    drawn_classes = set()

    def method_string(method_name: str, public: bool, overloaded: bool,
                      fontsize: float = 10.0) -> str:
        method_string = f'<font face="{METHOD_FONT}" point-size="{str(fontsize)}">'

        if overloaded:
            name = f'<i>{method_name}()</i>'
        else:
            name = f'{method_name}()'

        if public:
            method_string += f'<b>{name}</b>'
        else:
            method_string += f'<font color="{METHOD_COLOR}">' \
                             f'{name}</font>'

        method_string += '</font>'
        return method_string

    def var_string(var_name: str, fontsize: int = 10) -> str:
        var_string = f'<font face="{METHOD_FONT}" point-size="{str(fontsize)}">'
        var_string += f'{var_name}'
        var_string += '</font>'
        return var_string

    def is_overloaded(method_name: str, f: Any) -> bool:
        # docstring() always returns a string, so no None check is needed
        return (method_name in overloaded_methods or
                "in subclasses" in docstring(f))

    def is_abstract(cls: Type) -> bool:
        if not abstract_classes:
            return inspect.isabstract(cls)

        return (cls in abstract_classes or
                any(c.__name__ == cls.__name__ for c in abstract_classes))

    def is_public(method_name: str, f: Any) -> bool:
        if public_methods:
            return (method_name in public_methods or
                    f in public_methods or
                    any(f.__qualname__ == m.__qualname__
                        for m in public_methods))

        return bool(docstring(f))

    def frame_module(frameinfo: Any) -> str:
        return os.path.splitext(os.path.basename(frameinfo.frame.f_code.co_filename))[0]

    def callers() -> List[str]:
        frames = inspect.getouterframes(inspect.currentframe())
        return [frame_module(frameinfo) for frameinfo in frames]

    def is_local_class(cls: Type) -> bool:
        return cls.__module__ == '__main__' or cls.__module__ in callers()

    def class_vars_string(cls: Type, url: str) -> str:
        cls_vars = class_vars(cls)
        if len(cls_vars) == 0:
            return ""

        vars_string = f'<table border="0" cellpadding="0" ' \
                      f'cellspacing="0" ' \
                      f'align="left" tooltip="{cls.__name__}" href="#">'

        for (name, var) in cls_vars:
            if log:
                print(f"    Drawing {name}")

            var_doc = escape(f"{name} = {repr(var)}")
            tooltip = f' tooltip="{var_doc}"'
            href = f' href="{url}"'
            vars_string += f'<tr><td align="left" border="0"' \
                           f'{tooltip}{href}>'

            vars_string += var_string(name)
            vars_string += '</td></tr>'

        vars_string += '</table>'
        return vars_string

    def class_methods_string(cls: Type, url: str) -> str:
        methods = public_class_methods(cls)
        # return "<br/>".join([name + "()" for (name, f) in methods])
        methods_string = f'<table border="0" cellpadding="0" ' \
                         f'cellspacing="0" ' \
                         f'align="left" tooltip="{cls.__name__}" href="#">'

        public_methods_only = local_defs_only and not is_local_class(cls)

        methods_seen = False
        for public in [True, False]:
            for (name, f) in methods:
                if public != is_public(name, f):
                    continue

                if public_methods_only and not public:
                    continue

                if log:
                    print(f"    Drawing {name}()")

                if is_public(name, f) and not docstring(f):
                    warnings.warn(f"{f.__qualname__}() is listed as public,"
                                  f" but has no docstring")

                overloaded = is_overloaded(name, f)

                sig = str(inspect.signature(f))
                # replace 'List[Union[...]]' by the actual type def
                for tp in types:
                    tp_def = str(types[tp]).replace('typing.', '')
                    sig = sig.replace(tp_def, tp)
                sig = sig.replace('__main__.', '')

                method_doc = escape(name + sig)
                if docstring(f):
                    method_doc += ":&#x0a;" + escape_doc(docstring(f))

                if log:
                    print(f"    Method doc: {method_doc}")

                # Tooltips are only shown if a href is present, too
                tooltip = f' tooltip="{method_doc}"'
                href = f' href="{url}"'
                methods_string += f'<tr><td align="left" border="0"' \
                                  f'{tooltip}{href}>'

                methods_string += method_string(name, public, overloaded)

                methods_string += '</td></tr>'
                methods_seen = True

        if not methods_seen:
            return ""

        methods_string += '</table>'
        return methods_string

    def display_class_node(cls: Type) -> None:
        name = cls.__name__

        if name in drawn_classes:
            return
        drawn_classes.add(name)

        if log:
            print(f"Drawing class {name}")

        if cls.__module__ == '__main__':
            url = '#'
        else:
            url = cls.__module__ + '.ipynb'

        if is_abstract(cls):
            formatted_class_name = f'<i>{cls.__name__}</i>'
        else:
            formatted_class_name = cls.__name__

        if include_methods or include_class_vars:
            vars = class_vars_string(cls, url)
            methods = class_methods_string(cls, url)
            spec = '<{<b><font color="' + CLASS_COLOR + '">' + \
                formatted_class_name + '</font></b>'
            if include_class_vars and vars:
                spec += '|' + vars
            if include_methods and methods:
                spec += '|' + methods
            spec += '}>'
        else:
            spec = '<' + formatted_class_name + '>'

        class_doc = escape('class ' + cls.__name__)
        if docstring(cls):
            class_doc += ':&#x0a;' + escape_doc(docstring(cls))
        else:
            warnings.warn(f"Class {cls.__name__} has no docstring")

        dot.node(name, spec, tooltip=class_doc, href=url)

    def display_class_trees(trees: List[Tuple[Type, List]]) -> None:
        for tree in trees:
            (cls, subtrees) = tree
            display_class_node(cls)

            for subtree in subtrees:
                (subcls, _) = subtree

                if (cls.__name__, subcls.__name__) not in edges:
                    dot.edge(cls.__name__, subcls.__name__)
                    edges.add((cls.__name__, subcls.__name__))

            display_class_trees(subtrees)

    def display_legend() -> None:
        fontsize = 8.0

        label = f'<b><font color="{CLASS_COLOR}">Legend</font></b><br align="left"/>' 

        for item in [
            method_string("public_method",
                          public=True, overloaded=False, fontsize=fontsize),
            method_string("private_method",
                          public=False, overloaded=False, fontsize=fontsize),
            method_string("overloaded_method",
                          public=False, overloaded=True, fontsize=fontsize)
        ]:
            label += '&bull;&nbsp;' + item + '<br align="left"/>'

        label += f'<font face="Helvetica" point-size="{str(fontsize + 1)}">' \
                 'Hover over names to see doc' \
                 '</font><br align="left"/>'

        dot.node('Legend', label=f'<{label}>', shape='plain', fontsize=str(fontsize + 2))

    for cls in classes:
        tree = class_tree(cls)
        overloaded_methods = overloaded_class_methods(cls)
        display_class_trees(tree)

    if include_legend:
        display_legend()

    return dot
display_class_hierarchy(D_Class, types={'SomeType': SomeType},
                        project='debuggingbook', log=True)
Drawing class D_Class
    Drawing foo()
    Method doc: foo(self) -&gt; None:&#x0a;A WW2 foo fighter.
Drawing class B_Class
    Drawing VAR
    Drawing bar()
    Method doc: bar(self, qux: Any = None, bartender: int = 42) -&gt; None:&#x0a;A qux walks into a bar.&#x0a;`bartender` is an optional attribute.
    Drawing foo()
    Method doc: foo(self) -&gt; None:&#x0a;A WW2 foo fighter.
Drawing class A_Class
    Drawing foo()
    Method doc: foo(self) -&gt; None:&#x0a;The Adventures of the glorious Foo
    Drawing quux()
    Method doc: quux(self) -&gt; None:&#x0a;A method that is not used.
    Drawing second()
    Method doc: second(self) -&gt; None
Drawing class C_Class
    Drawing qux()
    Method doc: qux(self, arg: SomeType) -&gt; SomeType
display_class_hierarchy(D_Class, types={'SomeType': SomeType},
                        project='fuzzingbook')
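
When abstract_classes is not given, the is_abstract() helper inside display_class_hierarchy() falls back to inspect.isabstract(). The sketch below shows what that check reports; the classes Shape and Circle are hypothetical examples and not part of this chapter.

```python
import inspect
from abc import ABC, abstractmethod


class Shape(ABC):
    @abstractmethod
    def area(self) -> float:
        """Return the area; to be defined in subclasses."""


class Circle(Shape):
    def area(self) -> float:
        return 3.14159


shape_is_abstract = inspect.isabstract(Shape)    # has an unimplemented abstract method
circle_is_abstract = inspect.isabstract(Circle)  # implements every abstract method

print(shape_is_abstract)   # True
print(circle_is_abstract)  # False
```

With these classes, Shape would be drawn with an italic name by default. Note also that is_overloaded() treats any method whose docstring contains "in subclasses" as overloaded, which would apply to Shape.area() here.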

Here is a variant with abstract classes and logging:

display_class_hierarchy([A_Class, B_Class],
                        abstract_classes=[A_Class],
                        public_methods=[
                            A_Class.quux,
                        ],
                        log=True)
Drawing class A_Class
    Drawing quux()
    Method doc: quux(self) -&gt; None:&#x0a;A method that is not used.
    Drawing foo()
    Method doc: foo(self) -&gt; None:&#x0a;The Adventures of the glorious Foo
    Drawing second()
    Method doc: second(self) -&gt; None
Drawing class B_Class
    Drawing VAR
    Drawing bar()
    Method doc: bar(self, qux: Any = None, bartender: int = 42) -&gt; None:&#x0a;A qux walks into a bar.&#x0a;`bartender` is an optional attribute.
    Drawing foo()
    Method doc: foo(self) -&gt; None:&#x0a;A WW2 foo fighter.
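
To inspect the overall shape of the graph without Graphviz installed, one can emit plain DOT text. The following is a simplified sketch under stated assumptions, not the chapter's implementation: the function hierarchy_to_dot() is hypothetical, it draws class names only (no methods, variables, or tooltips), it assumes distinct class names, and it skips object. It mirrors two choices of display_class_hierarchy(): edges point from subclass to base class with rankdir=BT, and each edge is recorded in a set so that it is drawn only once.

```python
from typing import List, Set, Tuple


class Top:
    pass


class Left(Top):
    pass


class Right(Top):
    pass


class Bottom(Left, Right):
    pass


def hierarchy_to_dot(cls: type) -> str:
    """Return DOT source with one edge per (subclass, base class) pair."""
    edges: Set[Tuple[str, str]] = set()
    lines: List[str] = ["digraph {", "  rankdir=BT;", "  edge [arrowhead=empty];"]

    def visit(current: type) -> None:
        for base in current.__bases__:
            if base is object:
                continue
            edge = (current.__name__, base.__name__)
            if edge not in edges:  # avoid drawing the same edge twice
                edges.add(edge)
                lines.append(f'  "{edge[0]}" -> "{edge[1]}";')
            visit(base)

    visit(cls)
    lines.append("}")
    return "\n".join(lines)


dot_source = hierarchy_to_dot(Bottom)
print(dot_source)
```

The resulting text can be pasted into any DOT viewer. A class without base classes other than object produces an empty graph in this sketch, since only edges are emitted.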

Exercises

Enjoy!

The content of this project is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. The source code that is part of the content, as well as the source code used to format and display that content, is licensed under the MIT License. Last change: 2023-11-11 18:25:46+01:00