This commit is contained in:
Jan Kowalczyk
2025-09-11 14:50:16 +02:00
parent 35766b9028
commit e4b298cf06
6 changed files with 362 additions and 26 deletions

View File

@@ -23,6 +23,8 @@ dependencies = [
"scipy>=1.16.0",
"seaborn>=0.13.2",
"six>=1.17.0",
"tabulate>=0.9.0",
"thop>=0.1.1.post2209072238",
"torch-receptive-field",
"torchscan>=0.1.1",
"visualtorch>=0.2.4",

View File

@@ -0,0 +1,101 @@
import torch
from thop import profile
from networks.subter_LeNet import SubTer_LeNet, SubTer_LeNet_Autoencoder
from networks.subter_LeNet_rf import SubTer_Efficient_AE, SubTer_EfficientEncoder
# Configuration
LATENT_DIMS = [32, 64, 128, 256, 512, 768, 1024]
BATCH_SIZE = 1
INPUT_SHAPE = (BATCH_SIZE, 1, 32, 2048)
def count_parameters(model, input_shape):
"""Count MACs and parameters for a model."""
model.eval()
with torch.no_grad():
input_tensor = torch.randn(input_shape)
macs, params = profile(model, inputs=(input_tensor,))
return {"MACs": macs, "Parameters": params}
def format_number(num: float) -> str:
"""Format large numbers with K, M, B, T suffixes."""
for unit in ["", "K", "M", "B", "T"]:
if abs(num) < 1000.0 or unit == "T":
return f"{num:3.2f}{unit}"
num /= 1000.0
def main():
# Collect results per latent dimension
results = {} # dim -> dict of 8 values
for dim in LATENT_DIMS:
# Instantiate models for this latent dim
lenet_enc = SubTer_LeNet(rep_dim=dim)
eff_enc = SubTer_EfficientEncoder(rep_dim=dim)
lenet_ae = SubTer_LeNet_Autoencoder(rep_dim=dim)
eff_ae = SubTer_Efficient_AE(rep_dim=dim)
# Profile each
lenet_enc_stats = count_parameters(lenet_enc, INPUT_SHAPE)
eff_enc_stats = count_parameters(eff_enc, INPUT_SHAPE)
lenet_ae_stats = count_parameters(lenet_ae, INPUT_SHAPE)
eff_ae_stats = count_parameters(eff_ae, INPUT_SHAPE)
results[dim] = {
"lenet_enc_params": format_number(lenet_enc_stats["Parameters"]),
"lenet_enc_macs": format_number(lenet_enc_stats["MACs"]),
"eff_enc_params": format_number(eff_enc_stats["Parameters"]),
"eff_enc_macs": format_number(eff_enc_stats["MACs"]),
"lenet_ae_params": format_number(lenet_ae_stats["Parameters"]),
"lenet_ae_macs": format_number(lenet_ae_stats["MACs"]),
"eff_ae_params": format_number(eff_ae_stats["Parameters"]),
"eff_ae_macs": format_number(eff_ae_stats["MACs"]),
}
# Build LaTeX table with tabularx
header = (
"\\begin{table}[!ht]\n"
"\\centering\n"
"\\renewcommand{\\arraystretch}{1.15}\n"
"\\begin{tabularx}{\\linewidth}{lXXXXXXXX}\n"
"\\hline\n"
" & \\multicolumn{4}{c}{\\textbf{Encoders}} & "
"\\multicolumn{4}{c}{\\textbf{Autoencoders}} \\\\\n"
"\\cline{2-9}\n"
"\\textbf{Latent $z$} & "
"\\textbf{LeNet Params} & \\textbf{LeNet MACs} & "
"\\textbf{Eff. Params} & \\textbf{Eff. MACs} & "
"\\textbf{LeNet Params} & \\textbf{LeNet MACs} & "
"\\textbf{Eff. Params} & \\textbf{Eff. MACs} \\\\\n"
"\\hline\n"
)
rows = []
for dim in LATENT_DIMS:
r = results[dim]
row = (
f"{dim} & "
f"{r['lenet_enc_params']} & {r['lenet_enc_macs']} & "
f"{r['eff_enc_params']} & {r['eff_enc_macs']} & "
f"{r['lenet_ae_params']} & {r['lenet_ae_macs']} & "
f"{r['eff_ae_params']} & {r['eff_ae_macs']} \\\\"
)
rows.append(row)
footer = (
"\\hline\n"
"\\end{tabularx}\n"
"\\caption{Parameter and MAC counts for SubTer variants across latent dimensionalities.}\n"
"\\label{tab:subter_counts}\n"
"\\end{table}\n"
)
latex_table = header + "\n".join(rows) + "\n" + footer
print(latex_table)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,155 @@
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
\begin{table}[!ht]
\centering
\renewcommand{\arraystretch}{1.15}
\begin{tabularx}{\linewidth}{lXXXXXXXX}
\hline
& \multicolumn{4}{c}{\textbf{Encoders}} & \multicolumn{4}{c}{\textbf{Autoencoders}} \\
\cline{2-9}
\textbf{Latent $z$} & \textbf{LeNet Params} & \textbf{LeNet MACs} & \textbf{Eff. Params} & \textbf{Eff. MACs} & \textbf{LeNet Params} & \textbf{LeNet MACs} & \textbf{Eff. Params} & \textbf{Eff. MACs} \\
\hline
32 & 525.29K & 27.92M & 263.80K & 29.82M & 1.05M & 54.95M & 532.35K & 168.49M \\
64 & 1.05M & 28.44M & 525.94K & 30.08M & 2.10M & 56.00M & 1.06M & 169.02M \\
128 & 2.10M & 29.49M & 1.05M & 30.61M & 4.20M & 58.10M & 2.11M & 170.07M \\
256 & 4.20M & 31.59M & 2.10M & 31.65M & 8.39M & 62.29M & 4.20M & 172.16M \\
512 & 8.39M & 35.78M & 4.20M & 33.75M & 16.78M & 70.68M & 8.40M & 176.36M \\
768 & 12.58M & 39.98M & 6.29M & 35.85M & 25.17M & 79.07M & 12.59M & 180.55M \\
1024 & 16.78M & 44.17M & 8.39M & 37.95M & 33.56M & 87.46M & 16.79M & 184.75M \\
\hline
\end{tabularx}
\caption{Parameter and MAC counts for SubTer variants across latent dimensionalities.}
\label{tab:subter_counts}
\end{table}

View File

@@ -141,6 +141,8 @@ dependencies = [
{ name = "scipy" },
{ name = "seaborn" },
{ name = "six" },
{ name = "tabulate" },
{ name = "thop" },
{ name = "torch-receptive-field" },
{ name = "torchscan" },
{ name = "visualtorch" },
@@ -166,6 +168,8 @@ requires-dist = [
{ name = "scipy", specifier = ">=1.16.0" },
{ name = "seaborn", specifier = ">=0.13.2" },
{ name = "six", specifier = ">=1.17.0" },
{ name = "tabulate", specifier = ">=0.9.0" },
{ name = "thop", specifier = ">=0.1.1.post2209072238" },
{ name = "torch-receptive-field", git = "https://github.com/Fangyh09/pytorch-receptive-field.git" },
{ name = "torchscan", specifier = ">=0.1.1" },
{ name = "visualtorch", specifier = ">=0.2.4" },
@@ -882,6 +886,26 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" },
]
[[package]]
name = "tabulate"
version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090, upload-time = "2022-10-06T17:21:48.54Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" },
]
[[package]]
name = "thop"
version = "0.1.1.post2209072238"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "torch" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/0f/72beeab4ff5221dc47127c80f8834b4bcd0cb36f6ba91c0b1d04a1233403/thop-0.1.1.post2209072238-py3-none-any.whl", hash = "sha256:01473c225231927d2ad718351f78ebf7cffe6af3bed464c4f1ba1ef0f7cdda27", size = 15443, upload-time = "2022-09-07T14:38:37.211Z" },
]
[[package]]
name = "threadpoolctl"
version = "3.6.0"