Entry 45: Visualizing Decision Trees

A major benefit of tree-based models is how easy they are to visualize. This visualization aspect is also vital to discussing how trees work.

The notebook with my code for this entry can be found on my GitHub page in the Entry 45 notebook.

The Problem

In order to take full advantage of the benefits of tree-based models, a way to visualize them is needed.

The Options

There are three main options for visualizing a Decision Tree:

  • export_graphviz
  • pydotplus
  • tree.plot_tree

The option you want depends on what you want to do with the visualization. If you need to save the image, both export_graphviz and pydotplus will work. To simply view the tree in a Jupyter Notebook, all three options will work.

export_graphviz

Based on the number of tutorials, export_graphviz appears to be the most popular option. Or it could just be the option that’s been around the longest and is thus better known.

Example:

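from sklearn.tree import export_graphviz

# class_names must be strings in the same (ascending) order as tree_clf.classes_;
# y_train.unique() returns labels in order of appearance, which can mislabel nodes.
export_graphviz(tree_clf, out_file='images/graphviz_example.dot',
                feature_names=X_train.columns.tolist(),
                class_names=[str(c) for c in tree_clf.classes_],
                rounded=True,
                filled=True)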

The out_file parameter accepts the path and name of the output file and saves the tree description as a .dot file. This is a Graphviz text file, not yet an image. The file can then be converted to a .png using the Graphviz dot command-line tool, which has to be installed separately. Fortunately, shell commands can be run in Jupyter Notebooks by prefixing them with !, so you don’t have to switch back and forth.

!dot -Tpng image_name.dot -o image_name.png
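If you would rather not rely on the ! shell escape, the same conversion can be launched from Python. The following is a minimal sketch, assuming the Graphviz dot executable is installed and on the PATH. The helper names build_dot_command and dot_to_png are hypothetical and not part of any library.

```python
import shutil
import subprocess


def build_dot_command(dot_path, png_path):
    """Return the argument list equivalent to `dot -Tpng dot_path -o png_path`."""
    return ["dot", "-Tpng", str(dot_path), "-o", str(png_path)]


def dot_to_png(dot_path, png_path):
    """Convert a .dot file to a .png by calling the Graphviz `dot` executable."""
    # Fail early with a readable message if Graphviz is not installed.
    if shutil.which("dot") is None:
        raise RuntimeError("Graphviz `dot` executable not found on PATH")
    # check=True raises CalledProcessError if the conversion fails.
    subprocess.run(build_dot_command(dot_path, png_path), check=True)


# Example call (hypothetical file names, matching the command above):
# dot_to_png("image_name.dot", "image_name.png")
```

Passing the arguments as a list avoids shell quoting problems if a path contains spaces.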

If the out_file parameter is set to None, the function returns the DOT source as a string. That string can be saved to a variable and passed to graphviz.Source to be rendered directly in the Jupyter Notebook.

Example:

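import graphviz
from sklearn import tree

dot_data = tree.export_graphviz(tree_clf, out_file=None,
                                feature_names=X_train.columns.tolist(),
                                # strings, in the order of tree_clf.classes_
                                class_names=[str(c) for c in tree_clf.classes_],
                                rounded=True,
                                filled=True)

# As the last expression in a cell, this renders the tree inline
graphviz.Source(dot_data)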
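To make the out_file=None behavior concrete, here is a small self-contained sketch. It assumes the iris dataset as a stand-in for the X_train and y_train used in this entry, and an arbitrary max_depth of 2 to keep the tree small. It only builds the DOT string, so it does not need the Graphviz executable.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Stand-in data (an assumption): iris replaces the entry's X_train / y_train.
iris = load_iris(as_frame=True)
X_train, y_train = iris.data, iris.target

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X_train, y_train)

# class_names has to follow the order of tree_clf.classes_, so each fitted
# class label is mapped to its name instead of relying on row order.
class_names = [str(iris.target_names[c]) for c in tree_clf.classes_]

# With out_file=None nothing is written to disk; the DOT source is returned.
dot_data = export_graphviz(tree_clf, out_file=None,
                           feature_names=X_train.columns.tolist(),
                           class_names=class_names,
                           rounded=True,
                           filled=True)

# Peek at the start of the DOT source to see what graphviz.Source receives.
print(dot_data[:200])
```

Printing the first few hundred characters is a quick way to confirm that the feature and class names made it into the node labels before rendering.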

pydotplus

Okay, so you still need export_graphviz to create the DOT data for the tree, but then you can use pydotplus to create a graph object that can be either displayed or saved. To display it directly in the Notebook, use the Image class from the IPython.display module. To save it, the methods graph.write_dot and graph.write_png come in handy.

import pydotplus
from IPython.display import Image

graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())

# Save the rendered tree as a .png file
graph.write_png('images/pydotplus_example.png')

Seeing how this can be done with fewer functions using graphviz.Source, pydotplus feels a little extraneous.

tree.plot_tree

Using tree.plot_tree from the sklearn package is probably the easiest option of the three. It also returns a list of the text annotations for the nodes, which the Notebook prints when the call is the last line of a cell. If you want this extra text, that’s great. However, if you don’t want it, it just clutters up the Notebook. The image is also rather small by default, which makes it hard to read.

from sklearn import tree

tree.plot_tree(tree_clf,
               feature_names=X_train.columns.tolist(),
               # strings, in the order of tree_clf.classes_
               class_names=[str(c) for c in tree_clf.classes_],
               rounded=True,
               filled=True)
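Both complaints about plot_tree can be worked around with matplotlib. The sketch below assumes the iris dataset as stand-in data, and a figure size and output file name that are only illustrative. Drawing onto a larger figure addresses the small image, and assigning the return value to a variable keeps the annotation list out of the cell output.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Stand-in data (an assumption): iris replaces the entry's X_train / y_train.
iris = load_iris(as_frame=True)
X_train, y_train = iris.data, iris.target
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X_train, y_train)

# A larger figure makes the node text readable; the size is arbitrary.
fig, ax = plt.subplots(figsize=(12, 8))

# Capturing the return value stops the Notebook from printing the annotations.
annotations = plot_tree(tree_clf,
                        feature_names=X_train.columns.tolist(),
                        class_names=iris.target_names.tolist(),
                        rounded=True,
                        filled=True,
                        ax=ax)

# The figure can also be saved directly (hypothetical file name).
fig.savefig("plot_tree_example.png", dpi=150, bbox_inches="tight")
```

Because the result is an ordinary matplotlib figure, fig.savefig gives a way to save the image without going through a .dot file, although the styling options are more limited than with Graphviz.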

The Proposed Solution

The export_graphviz function is necessary for creating a visualization that can be customized. As such, for visualizing within the Notebook, my preference would be to use export_graphviz and graphviz.Source.

However, when it comes to saving the image, I’m not a big fan of switching languages within the same Notebook, i.e., using the command line to convert the .dot to .png. As such, my preference would be to use pydotplus to convert the DOT data within the Notebook and save it using graph.write_png.

Up Next

Overfitting, Underfitting, and Data Sensitivity

Resources