# subter_lenet_arch.py
# Requires running from inside the PlotNeuralNet repo, e.g.: python3 ../subter_lenet_arch.py
import argparse
import sys

sys.path.append("../")  # import pycore from repo root

from pycore.tikzeng import *

# Command-line interface. Parsing happens at module level, so the process
# argv is read as soon as this file is executed or imported.
parser = argparse.ArgumentParser()
parser.add_argument("--rep_dim", type=int, default=1024, help="latent size for FC")
args = parser.parse_args()

# argparse already converts the value through type=int, so no extra cast.
# NOTE(review): REP is not referenced anywhere else in the visible part of
# this file; the figure labels below are fixed strings. Confirm whether the
# latent/FC labels were meant to show this value.
REP = args.rep_dim

# Visual scales so the huge width doesn't dominate the figure.
# These are drawing sizes passed as height= / depth= / width= to the layer
# helpers below, not real tensor sizes. The numeric suffix presumably names
# the real dimension each scale stands in for (e.g. H32 is used on boxes
# labelled "...×32"). Note that W4 and W8 deliberately share the same scale.
H32, H16, H8, H1 = 26, 18, 12, 1
D2048, D1024, D512, D256, D128, D1 = 52, 36, 24, 12, 6, 1
W1, W4, W8, W16, W32 = 1, 2, 2, 4, 8


# Ordered list of figure elements handed to to_generate() in main().
# Each box is positioned relative to the previous one through
# to="(<name>-east)" plus an offset; no to_connection() arrows are emitted.
# The layer helpers (to_Conv, to_Pool, to_fc, ...) come from the
# `pycore.tikzeng` star import.
arch = [
    to_head(".."),
    to_cor(),
    to_begin(),
    # --------------------------- ENCODER ---------------------------
    # Input: single-filter box, labelled 2048×32, placed at the origin.
    to_Conv(
        "input",
        zlabeloffset=0.2,
        s_filer="{{2048×32}}",
        n_filer=1,
        offset="(0,0,0)",
        to="(0,0,0)",
        height=H32,
        depth=D2048,
        width=W1,
        caption="Input",
    ),
    # Conv1 is drawn as two boxes at the input's scale: an unlabelled
    # 1-filter slab (dwconv1) with the 16-filter block (dwconv2) attached
    # to its east anchor at zero offset.
    to_Conv(
        "dwconv1",
        s_filer="",
        n_filer=1,
        offset="(2,0,0)",
        to="(input-east)",
        height=H32,
        depth=D2048,
        width=W1,
        caption="",
    ),
    to_Conv(
        "dwconv2",
        s_filer="{{2048×32}}",
        zlabeloffset=0.15,
        n_filer=16,
        offset="(0,0,0)",
        to="(dwconv1-east)",
        height=H32,
        depth=D2048,
        width=W16,
        caption="Conv1",
    ),
    # Pool1: labelled 512×32; depth scale drops D2048 -> D512 while the
    # height scale stays at H32.
    to_Pool(
        "pool1",
        offset="(0,0,0)",
        zlabeloffset=0.3,
        s_filer="{{512×32}}",
        to="(dwconv2-east)",
        height=H32,
        depth=D512,
        width=W16,
        caption="",
    ),
    # Conv2 uses the same two-box layout: a 1-filter slab (dwconv3) followed
    # by the 32-filter block (dwconv4), both labelled/scaled as 512×32.
    to_Conv(
        "dwconv3",
        s_filer="",
        n_filer=1,
        offset="(2,0,0)",
        to="(pool1-east)",
        height=H32,
        depth=D512,
        width=W1,
        caption="",
    ),
    to_Conv(
        "dwconv4",
        n_filer=32,
        zlabeloffset=0.3,
        s_filer="{{512×32}}",
        offset="(0,0,0)",
        to="(dwconv3-east)",
        height=H32,
        depth=D512,
        width=W32,
        caption="Conv2",
    ),
    # Two back-to-back pooling boxes: 256×16, then 128×8.
    to_Pool(
        "pool2",
        offset="(0,0,0)",
        zlabeloffset=0.45,
        s_filer="{{256×16}}",
        to="(dwconv4-east)",
        height=H16,
        depth=D256,
        width=W32,
        caption="",
    ),
    to_Pool(
        "pool3",
        offset="(0,0,0)",
        zlabeloffset=0.45,
        s_filer="{{128×8}}",
        to="(pool2-east)",
        height=H8,
        depth=D128,
        width=W32,
        caption="",
    ),
    # Squeeze: 8 filters, label stays 128×8.
    to_Conv(
        "squeeze",
        n_filer=8,
        zlabeloffset=0.45,
        s_filer="{{128×8}}",
        offset="(1,0,0)",
        to="(pool3-east)",
        height=H8,
        depth=D128,
        width=W8,
        caption="Squeeze",
    ),
    # FC box: n_filer carries the text label 8×128×8 (not a number).
    to_fc(
        "fc1",
        n_filer="{{8×128×8}}",
        zlabeloffset=0.5,
        offset="(2,-.5,0)",
        to="(squeeze-east)",
        height=H1,
        depth=D512,
        width=W1,
        caption="FC",
        captionshift=0,
    ),
    # --------------------------- LATENT ---------------------------
    # Latent box with the fixed text label "latent dim".
    # NOTE(review): the --rep_dim value (REP) is not shown here; confirm
    # whether the label was meant to display it.
    to_Conv(
        "latent",
        n_filer="",
        s_filer="latent dim",
        offset="(1.3,0.5,0)",
        to="(fc1-east)",
        height=H8 * 1.6,
        depth=D1,
        width=W1,
        caption="Latent Space",
    ),
    to_end(),
]


def main():
    """Pass the module-level `arch` list to `to_generate`.

    The output name is ``subter_lenet_arch.tex``; it is a relative path, so
    where the file lands depends on how `to_generate` resolves it
    (presumably against the current working directory).
    """
    name = "subter_lenet_arch"
    to_generate(arch, f"{name}.tex")


# Only generate the figure when executed as a script, not on import.
if __name__ == "__main__":
    main()