GraphvizでMLPと通常の重回帰分析のイメージ図を作る
作例
GraphvizでMLPと通常重回帰分析のイメージ図(conceptual diagram あるいは schematic diagram かな)を作ってみた。 作例は以下の通り。
プログラミング環境は以下の通り。
Graphvizのインストール
まずはGraphvizのインストール。 brewとpipで2回行う。
brew install graphviz pip install graphviz
MLPのイメージ図作成プログラム
node の表示順を制御するため、gi.node('1','x[0]' ,pos='0,0!')
というように pos
を追記している。
from graphviz import Digraph g = Digraph(format='png') g.attr(rankdir='LR') g.attr(splines='false') g.attr(dpi='300') # Input subgraph gi=Digraph(name='cluster_i') gi.attr(label='Input') gi.attr(penwidth='0') gi.node('1','x[0]' ,pos='0,0!') gi.node('2','x[1]' ,pos='0,1!') gi.node('3','x[..]' ,pos='0,2!') gi.node('4','x[n]',pos='0,3!') # Hidden subgraph 1 gh1=Digraph(name='cluster_h1') gh1.attr(label='Hidden layer 1') gh1.attr(penwidth='0') gh1.node('5', 'h1[0]' ,pos='0,0!') gh1.node('6', 'h1[1]' ,pos='0,1!') gh1.node('7', 'h1[..]' ,pos='0,2!') gh1.node('8', 'h1[n]',pos='0,3!') # Hidden subgraph 2 gh2=Digraph(name='cluster_h2') gh2.attr(label='Hidden layer 2') gh2.attr(penwidth='0') gh2.node('9', 'h2[0]' ,pos='0,0!') gh2.node('10','h2[1]' ,pos='0,1!') gh2.node('11','h2[..]' ,pos='0,2!') gh2.node('12','h2[n]',pos='0,3!') # Output subgraph go=Digraph(name='cluster_o') go.attr(label='Output') go.attr(penwidth='0') go.node('13','y',pos='0,1.5!') g.subgraph(gi) g.subgraph(gh1) g.subgraph(gh2) g.subgraph(go) g.edge('1', '5') g.edge('1', '6') g.edge('1', '7') g.edge('1', '8') g.edge('2', '5') g.edge('2', '6') g.edge('2', '7') g.edge('2', '8') g.edge('3', '5') g.edge('3', '6') g.edge('3', '7') g.edge('3', '8') g.edge('4', '5') g.edge('4', '6') g.edge('4', '7') g.edge('4', '8') g.edge('5', '9') g.edge('5', '10') g.edge('5', '11') g.edge('5', '12') g.edge('6', '9') g.edge('6', '10') g.edge('6', '11') g.edge('6', '12') g.edge('7', '9') g.edge('7', '10') g.edge('7', '11') g.edge('7', '12') g.edge('8', '9') g.edge('8', '10') g.edge('8', '11') g.edge('8', '12') g.edge('9', '13') g.edge('10', '13') g.edge('11', '13') g.edge('12', '13') g.render('fig_3_mlp', view=True)
重回帰のイメージ図作成プログラム
edge(矢印線)に回帰係数を示す w[..]
を表示するため、g.edge('3', '5', label='w[..]')
というように label
を追記している。
from graphviz import Digraph g = Digraph(format='png') g.attr(rankdir='LR') g.attr(splines='false') g.attr(dpi='300') # Input subgraph gi=Digraph(name='cluster_i') gi.attr(label='Input') gi.attr(penwidth='0') gi.node('1','x[0]' ,pos='0,0!') gi.node('2','x[1]' ,pos='0,1!') gi.node('3','x[..]' ,pos='0,2!') gi.node('4','x[n]',pos='0,3!') # Output subgraph go=Digraph(name='cluster_o') go.attr(label='Output') go.attr(penwidth='0') go.node('5','y',pos='0,0!') g.subgraph(gi) g.subgraph(go) g.edge('1', '5', label='w[0]') g.edge('2', '5', label='w[1]') g.edge('3', '5', label='w[..]') g.edge('4', '5', label='w[n]') g.render('fig_3_reg', view=True)
以 上