2nd subter network arch

This commit is contained in:
Jan Kowalczyk
2025-06-17 07:26:03 +02:00
parent 9298dea329
commit bbd093da0c
9 changed files with 248 additions and 30 deletions

View File

@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_receptive_field
from base.base_net import BaseNet
@@ -29,6 +30,13 @@ class SubTer_LeNet(BaseNet):
x = self.fc1(x)
return x
def summary(self, receptive_field: bool = False):
# first run super method to log parameters and structure
super().summary(receptive_field=receptive_field)
self.logger.info("torch_receptive_field:")
torch_receptive_field.receptive_field(self, input_size=self.input_dim)
# torch_receptive_field.receptive_field_for_unit(rf, "2", (2,2))
class SubTer_LeNet_Decoder(BaseNet):
def __init__(self, rep_dim=1024):