Entry 45: Visualizing Decision Trees
A major benefit of tree-based models is how easy they are to visualize. Being able to see the tree is also vital when explaining how trees work.
The code for this entry is in the Entry 45 notebook on my GitHub page.
The Problem
To take full advantage of tree-based models, you need a way to visualize them.
The Options
There are three main options for visualizing a Decision Tree:
export_graphviz
pydotplus
tree.plot_tree
The option you choose depends on what you want to do with the visualization. If you need to save the image, both export_graphviz and pydotplus will work. To simply display the tree in a Jupyter Notebook, all three options will work.
export_graphviz
Judging by the number of tutorials, export_graphviz appears to be the most popular option. Or it may simply be the option that has been around the longest and is therefore the best known.
Example:
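from sklearn.tree import export_graphviz

export_graphviz(tree_clf, out_file='images/graphviz_example.dot',
                feature_names=X_train.columns.tolist(),
                # class names must follow the (sorted) order of tree_clf.classes_,
                # not the order in which labels first appear in y_train
                class_names=[str(c) for c in tree_clf.classes_],
                rounded=True,
                filled=True)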
The out_file parameter accepts the path and name of the output file and saves the tree, written in Graphviz's DOT language, as a .dot file. That file can then be converted to a .png with Graphviz's dot command-line tool. Fortunately, shell commands can be run from a Jupyter Notebook cell by prefixing them with !, so you don’t have to switch back and forth between the Notebook and a terminal.
!dot -Tpng image_name.dot -o image_name.png
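As a sketch of the same workflow kept entirely in Python, the example below writes the .dot file and then calls the dot executable through Python's subprocess module instead of the ! shell escape. It assumes the Graphviz binaries are installed and on the PATH (the conversion is skipped otherwise), and it uses the iris dataset as a hypothetical stand-in for the tree_clf, X_train, and y_train used throughout this entry.

```python
import os
import shutil
import subprocess

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Hypothetical stand-ins for the entry's X_train, y_train, and tree_clf
iris = load_iris(as_frame=True)
X_train = iris.data
y_train = iris.target.map(dict(enumerate(iris.target_names)))
tree_clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_train, y_train)

os.makedirs("images", exist_ok=True)
dot_path = "images/graphviz_example.dot"
png_path = "images/graphviz_example.png"

# With a file path for out_file, export_graphviz writes DOT source to disk
export_graphviz(
    tree_clf,
    out_file=dot_path,
    feature_names=X_train.columns.tolist(),
    class_names=[str(c) for c in tree_clf.classes_],
    rounded=True,
    filled=True,
)

# Same conversion as `!dot -Tpng ...`, issued from Python instead;
# only attempted if the Graphviz `dot` executable can be found on the PATH
if shutil.which("dot") is not None:
    subprocess.run(["dot", "-Tpng", dot_path, "-o", png_path], check=True)
```

Guarding on shutil.which keeps the cell from failing on machines without Graphviz, where only the .dot file is produced.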
If the out_file parameter is set to None, export_graphviz returns the DOT data as a string instead of writing a file. That string can be saved to a variable and passed to graphviz.Source to display the tree directly in the Jupyter Notebook.
Example:
import graphviz
from sklearn import tree

dot_data = tree.export_graphviz(tree_clf, out_file=None,
                                feature_names=X_train.columns.tolist(),
                                class_names=[str(c) for c in tree_clf.classes_],
                                rounded=True,
                                filled=True)
graphviz.Source(dot_data)
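Because out_file=None returns the DOT source as a plain string, it is also easy to inspect what ends up on each node. The toy example below (hypothetical data, not from the original notebook) uses that to show why class_names should follow tree_clf.classes_, which is sorted, rather than y_train.unique(), which is in order of first appearance. With the wrong order, the labels are silently swapped. The root_class helper and its regular expression are illustrative only and not part of scikit-learn.

```python
import re

import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Hypothetical toy data: "dog" appears first, but "cat" sorts first
X_train = pd.DataFrame({"weight": [0, 1, 1, 1]})
y_train = pd.Series(["dog", "cat", "cat", "cat"])
tree_clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

print(list(y_train.unique()))   # order of first appearance
print(list(tree_clf.classes_))  # sorted order the tree uses internally


def root_class(class_names):
    """Export with out_file=None and return the class label drawn on the root node."""
    dot_data = export_graphviz(
        tree_clf,
        out_file=None,
        feature_names=X_train.columns.tolist(),
        class_names=class_names,
    )
    # Node 0 is the root; its label ends with "class = <name>"
    match = re.search(r'^0 \[label="[^"]*class = (\w+)"', dot_data, re.M)
    return match.group(1)


print(root_class(list(y_train.unique())))   # mislabeled: majority class is "cat"
print(root_class(list(tree_clf.classes_)))  # correct label
```

The root holds three "cat" samples and one "dog", so only the classes_ ordering labels it correctly.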
pydotplus
Okay, so you still need export_graphviz (with out_file=None) to create the DOT data for the tree, but you can then use pydotplus to create a graph object that can be either displayed or saved. To display it directly in the Notebook, use the Image class from the IPython.display module. To save it, the graph.write_dot and graph.write_png methods come in handy.
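import pydotplus
from IPython.display import Image, display

# dot_data is the string returned by export_graphviz(..., out_file=None)
graph = pydotplus.graph_from_dot_data(dot_data)
display(Image(graph.create_png()))  # display() is needed since this isn't the cell's last line
graph.write_png('images/pydotplus_example.png')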
Given that graphviz.Source can display the tree with fewer function calls, pydotplus feels a little extraneous.
tree.plot_tree
Using tree.plot_tree from the sklearn package is probably the easiest option of the three. It also outputs text values for each node. If you want this extra text, that’s great; if you don’t, it just clutters up the Notebook. The plot is also rather small at the default figure size, which makes it hard to read.
from sklearn import tree

tree.plot_tree(tree_clf,
               feature_names=X_train.columns.tolist(),
               class_names=[str(c) for c in tree_clf.classes_],
               rounded=True,
               filled=True)
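Both complaints about plot_tree can usually be worked around with plain matplotlib. Here is a minimal sketch, again using iris as a hypothetical stand-in for the entry's data. plot_tree returns a list of matplotlib Annotation objects, and that list is what Jupyter echoes as text under the cell. Assigning the result to a variable (or ending the line with a semicolon) hides it. Passing an ax from a larger figure, together with the fontsize parameter, makes the nodes easier to read, and the figure can then be saved with savefig. The figure size, font size, and file name are arbitrary choices.

```python
import os

import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

# Hypothetical stand-ins for the entry's X_train, y_train, and tree_clf
iris = load_iris(as_frame=True)
X_train = iris.data
y_train = iris.target.map(dict(enumerate(iris.target_names)))
tree_clf = tree.DecisionTreeClassifier(max_depth=3, random_state=42)
tree_clf.fit(X_train, y_train)

# Draw on a larger figure (size in inches) so the node boxes have room
fig, ax = plt.subplots(figsize=(16, 10))

# Assigning the returned list of Annotation objects keeps Jupyter
# from echoing it as text below the plot
annotations = tree.plot_tree(
    tree_clf,
    feature_names=X_train.columns.tolist(),
    class_names=[str(c) for c in tree_clf.classes_],
    rounded=True,
    filled=True,
    fontsize=10,
    ax=ax,
)

# Save it like any other matplotlib figure
os.makedirs("images", exist_ok=True)
fig.savefig("images/plot_tree_example.png", bbox_inches="tight")
```

This also gives plot_tree a way to save the image without Graphviz installed.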
The Proposed Solution
The export_graphviz function is necessary for creating a visualization that can be customized. As such, for visualizing within the Notebook, my preference would be to use export_graphviz and graphviz.Source.
However, when it comes to saving the image, I’m not a big fan of switching languages within the same Notebook, i.e., using the command line to convert the .dot file to a .png. As such, my preference would be to use pydotplus to convert the image within the Notebook, then save it using graph.write_png.
Up Next
Overfitting, Underfitting, and Data Sensitivity