Torchview를 사용한 모델 시각화

다음 공식 깃허브를 참고하여 ubuntu에서 사용할 torchview를 세팅하였다.

 

설치

설치 기준은 데비안 기반 환경이며. 여기서는 우분투를 사용했습니다. 윈도우에 설치하는 과정들은 다른 블로그를 참조해야할 것 같습니다. 윈도우 설치과정을 조금 살펴보았을 때 graphviz 패키지를 설치할 때 해당 문서에 따르면 graphviz-2.38.msi와 같은 파일을 통해 설치하는 것을 권장하는 것처럼 보였습니다. 이는 환경변수 설정때문 같습니다. 아마 graphviz만 제대로 설치된다면 윈도우에서 사용하는 것도 까다롭지 않을 수도 있을 것 같긴한데, 사용하지 않아서 잘모르며 우분투환경에서의 설치만 다룹니다.

 

graphviz 설치

apt-get install graphviz

위처럼 설치를 해야합니다.

 

만약 pip 또는 conda를 통해 graphviz를 설치하였다면, 다음과 같은 에러가 발생하며 torchview 패키지의 model_graph.visual_graph를 통해 jupyter notebook에 모델을 그릴 때 다음과 같이 오류가 생깁니다.

CalledProcessError: Command '[PosixPath('dot'), '-Kdot', '-Tpng', '-O', 'model_graph']' returned non-zero exit status 1.

 

torchview는 graphviz 패키지에 종속성이 있으며, graphviz의 문서를 살펴보면 패키지 설치가 완료된 후 시스템 환경 변수에 등록하는 과정이 필요하다고 합니다.

 

torchview 설치

pip install torchview

 

이로써 필요한 패키지 설치는 모두 완료되었습니다.

 

 

사용법

그래프를 그리고자 하는 모델의 코드: 

class UnetModel(nn.Module):
    def __init__(self,
                 in_channels=3,
                 model_channels=128,
                 out_channels=3,
                 num_res_blocks=2,
                 attention_resolutions=(8,16),
                 dropout=0,
                 channel_mult=(1,2,2,2),
                 conv_resample=True,
                 num_heads=4,
                 class_num=10
                ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
		
        ...
        # 각종 모듈 선언
        ...
        
    def forward(self, x, timesteps, c, mask):
        """
        Apply the model to an input batch.
        :param x: an [N x C x H x W] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param c: a 1-D batch of classes.
        :param mask: a 1-D batch of conditioned/unconditioned.
        :return: an [N x C x ...] Tensor of outputs.
        """
        hs = []
        # time step and class embedding
        t_emb = self.time_emb(timestep_embedding(timesteps, dim=self.model_channels))
        c_emb = self.class_emb(c)
        
        
        # down stage
        h = x
        for module in self.down_blocks:
            h = module(h, t_emb, c_emb, mask)
#             print(h.shape)
            hs.append(h)
        
        # middle stage
        h = self.middle_blocks(h, t_emb, c_emb, mask)
        
        # up stage
        for module in self.up_blocks:
#             print(h.shape, hs[-1].shape)
            cat_in = torch.cat([h, hs.pop()], dim=1)
            h = module(cat_in, t_emb, c_emb, mask)
        
        return self.out(h)

 

보통 복잡한 모델을 그래프로 그리고 싶어할텐데, 이와 같이 모듈의 forward 함수에 많은 매개변수가 들어가는 방식은 공식 문서에서 작성되어있는 방식으로 사용하게 되면 입력 매개변수가 맞지 않게 됩니다. 확인하고자 하는 모델의 forward 함수의 매개변수는 입력 데이터 x를 제외하고 timesteps, c, mask를 입력으로 받는데 아래와 같은 공식문서의 방식으로는 바로 사용하기가 어려울 수 있습니다,

model_graph = draw_graph(model, input_size=(batch_size, 128), device='meta')

 

저는 해당 모델을 감싸는 nn.Module을 상속받은 wrappermodel을 선언한 뒤 timesteps, c, mask를 입력할 수 있도록 사용하였습니다.

class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        batch_size = x.size(0)
        timesteps = torch.randint(0, 500, (batch_size,), dtype=torch.long).to(x.device)
        c = torch.randint(0, 2, (batch_size,), dtype=torch.long).to(x.device)
        mask = torch.zeros(batch_size).int().to(x.device)
        
        return self.model(x, timesteps, c, mask)
import graphviz
graphviz.set_jupyter_format('png')

batch_size = 16
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UnetModel(
    in_channels=1,
    model_channels=96,
    out_channels=1,
    channel_mult=(1, 2, 2),
    attention_resolutions=[],
    class_num=2
)
model.to(device)

wrapped_model = ModelWrapper(model)
model_graph = draw_graph(
    wrapped_model,
    input_size=(batch_size, 1, 64, 64),
    device=device,
    save_graph=True,
    filename="model_graph",
    )

전 해당 그래프 내용이 png로 저장되었으면 좋겠어서, save_graph옵션을 추가하여 그렸습니다. directory 옵션을 사용하면 저장될 png파일의 경로도 지정가능합니다.

 

작성 코드 이미지

 

그래프 결과 이미지

 

 

참고 자료

'Python 관련 > Pytorch' 카테고리의 다른 글

RuntimeError: CUDA error: device-side assert triggered  (0) 2024.04.05
Albumentation Color jitter 사용 오류  (0) 2024.04.05
[Pytorch] 모델 저장 오류  (0) 2024.03.07
[Pytorch] Generator  (0) 2024.02.27
[Pytorch] Randomness 제어  (0) 2024.02.27