damyarou

python, GMT などのプログラム

GraphvizでMLPと通常の重回帰分析のイメージ図を作る

作例

GraphvizMLPと通常重回帰分析のイメージ図(conceptual diagram あるいは schematic diagram かな)を作ってみた。 作例は以下の通り。

f:id:damyarou:20200816080901p:plain:w600

f:id:damyarou:20200816080923p:plain:w300

プログラミング環境は以下の通り。

Graphvizのインストール

まずはGraphvizのインストール。 brewとpipで2回行う。

brew install graphviz

pip install graphviz

MLPのイメージ図作成プログラム

node の表示順を制御するため、gi.node('1','x[0]' ,pos='0,0!') というように pos を追記している。

from graphviz import Digraph

g = Digraph(format='png')
g.attr(rankdir='LR')
g.attr(splines='false')
g.attr(dpi='300')

# Input subgraph
gi=Digraph(name='cluster_i')
gi.attr(label='Input')
gi.attr(penwidth='0')
gi.node('1','x[0]'  ,pos='0,0!')
gi.node('2','x[1]'  ,pos='0,1!')
gi.node('3','x[..]' ,pos='0,2!')
gi.node('4','x[n]',pos='0,3!')
# Hidden subgraph 1
gh1=Digraph(name='cluster_h1')
gh1.attr(label='Hidden layer 1')
gh1.attr(penwidth='0')
gh1.node('5', 'h1[0]'  ,pos='0,0!')
gh1.node('6', 'h1[1]'  ,pos='0,1!')
gh1.node('7', 'h1[..]' ,pos='0,2!')
gh1.node('8', 'h1[n]',pos='0,3!')
# Hidden subgraph 2
gh2=Digraph(name='cluster_h2')
gh2.attr(label='Hidden layer 2')
gh2.attr(penwidth='0')
gh2.node('9', 'h2[0]'  ,pos='0,0!')
gh2.node('10','h2[1]'  ,pos='0,1!')
gh2.node('11','h2[..]' ,pos='0,2!')
gh2.node('12','h2[n]',pos='0,3!')
# Output subgraph
go=Digraph(name='cluster_o')
go.attr(label='Output')
go.attr(penwidth='0')
go.node('13','y',pos='0,1.5!')

g.subgraph(gi)
g.subgraph(gh1)
g.subgraph(gh2)
g.subgraph(go)

g.edge('1', '5')
g.edge('1', '6')
g.edge('1', '7')
g.edge('1', '8')
g.edge('2', '5')
g.edge('2', '6')
g.edge('2', '7')
g.edge('2', '8')
g.edge('3', '5')
g.edge('3', '6')
g.edge('3', '7')
g.edge('3', '8')
g.edge('4', '5')
g.edge('4', '6')
g.edge('4', '7')
g.edge('4', '8')

g.edge('5', '9')
g.edge('5', '10')
g.edge('5', '11')
g.edge('5', '12')
g.edge('6', '9')
g.edge('6', '10')
g.edge('6', '11')
g.edge('6', '12')
g.edge('7', '9')
g.edge('7', '10')
g.edge('7', '11')
g.edge('7', '12')
g.edge('8', '9')
g.edge('8', '10')
g.edge('8', '11')
g.edge('8', '12')

g.edge('9', '13')
g.edge('10', '13')
g.edge('11', '13')
g.edge('12', '13')


g.render('fig_3_mlp', view=True)

重回帰のイメージ図作成プログラム

edge(矢印線)に回帰係数を示す w[..] を表示するため、g.edge('3', '5', label='w[..]') というように label を追記している。

from graphviz import Digraph

g = Digraph(format='png')
g.attr(rankdir='LR')
g.attr(splines='false')
g.attr(dpi='300')

# Input subgraph
gi=Digraph(name='cluster_i')
gi.attr(label='Input')
gi.attr(penwidth='0')
gi.node('1','x[0]'  ,pos='0,0!')
gi.node('2','x[1]'  ,pos='0,1!')
gi.node('3','x[..]' ,pos='0,2!')
gi.node('4','x[n]',pos='0,3!')
# Output subgraph
go=Digraph(name='cluster_o')
go.attr(label='Output')
go.attr(penwidth='0')
go.node('5','y',pos='0,0!')

g.subgraph(gi)
g.subgraph(go)

g.edge('1', '5', label='w[0]')
g.edge('2', '5', label='w[1]')
g.edge('3', '5', label='w[..]')
g.edge('4', '5', label='w[n]')

g.render('fig_3_reg', view=True)

以 上