[ML] Conditional GAN

Artiper
|2023. 12. 18. 18:16

Related Work

많은 수의 출력 카테고리를 수용할 수 있도록 확장하는 것은 어려운 문제며, 두 번째 문제는 입력과 출력의 일대일 매핑을 학습하는데 초점을 맞췄다는 것이다. 예를 들어 이미지 라벨링의 경우 주어진 이미지에 적절하게 적용될 수 있는 다양한 태그가 있을 수 있으며, 서로 다른 annoatator가 동일한 이미지를 설명하기 위해 서로 다른 용어를 사용할 수 있다.

 

첫 번째 문제를 해결하는 데 도움이 되는 한 가지 방법은 다른 모달리티의 추가 정보를 활용하는 것이다. 예를 들어, 자연어 말뭉치를 사용하여 기하학적 관계가 연관있는 label에 대한 vector representation을 학습하는 것이다.

 

이러한 space에서 예측할 때, 예측 오류가 발생해도 여전히 진실에 근접하는 경우가 많다는 사실(예. 의자 대신 테이블을 예측)과 모델 학습 습 동안 보지못했던 레이블에 대해 자연스럽게 예측 일반화를 할 수 있다는 사실에서 이점을 얻을 수 있다.

 

[3]과 같은 연구에서는 이미지의 feature 공간과 word의 representation sapce의 간단한 선형 매핑으로도 classification 성능이 향상될 수 있음을 보여준다.

 

두 번째 문제를 해결하는 방법 중 하나는 조건부 확률적 생성 모델을 사용하는 것으로, 입력을 조건부 변수로 간주하고 one-to-many 매핑을 조건부 예측 분포로 인스턴스화 하는 것이다. [12]에서는 다중 모드 신경 언어 모델을 훈련하는 방법을 보여 주며, 이를 통해 이미지에 대한 설명 문장을 생성할 수 있다.

 

 

Conditional Adversarial Nets

Generative Adversarial Nets

GAN은 데이터 분포를 포착하는 생성 모델 G와 G가 아닌 훈련 데이터에서 샘플이 나올 확률을 추정하는 Discriminative model D로 구성된다. G와 D 모두 다층 퍼셉트론과 같은 비선형 매핑 함수일 수 있다.

 

데이터 $x$에 대한 Generator 분포 $p_g$를 학습하기 위해, Generator는 사전 노이즈 분포 $P_z(z)$를 $G(z; \theta_{g})$ 데이터 공간으로 매핑하는 함수를 구축한다. 그리고 판별자 $D(x; \theta_{d})$는 $x$가 $p_g$가 아닌 학습 데이터에서 나왔을 확률을 표현하 단일 스칼라를 출력한다.

 

 

Conditional Adversarial Nets

GAN은 생성기, 판별기 둘 다 추가 정보 y를 조건으로 하는 경우 conditional model로 확장할 수 있다. $y$는 class label 또는 다른 양식의 데이터 등 모든 종류의 보조 정보가 될 수 있다. 해당 논문의 저자들은 $y$를 판별기와 생성기 모두에게 추가 입력 레이어로 $y$를 공급하여 컨디셔닝을 수행한다.

 

제너레이터에서 prior 입력 노이즈 $p_z(z)$ 및 $y$는 공동  hidden representation으로 결합되며, GAN 학습 프레임워크는 이 hidden representation의 구성 방식에 상당한 유연성을 부여한다.

 

판별기에서 $x$와 $y$는 입력으로 존재하며 판별 함수(이 경우 MLP로 다시 구현됨)로 표시된다.

 

2인 minmax 게임의 목적 함수는 다음과 같다.

 

 

그림 1은 간단한 조건부 적대적 네트워크의 구조를 보여준다.

 

 

GAN의 loss 이해

여기서 log는 자연로그이다. 

 

판별자 관점

실제 이미지에 대해 1에 가까운 값을, 생성된 이미지에 대해서는 0에 가까운 값을 출력하도록 학습해야 한다. 이진 분류 문제 (실제이미지인지, 생성된 이미지인지) 에선 BCE를 loss함수로 사용한다. 

 

 

Experimental Results

Unimodal

원-핫 벡터로 인코딩된 클래스 레이블에 따라 조건이 주어진 MNIST 이미지에 대해 conditional GAN을 훈련했다. 차원이 100으로 이루어진 노이즈 prior z는 단위 hypercube 내에 균등 분포에서 도출되었다. z와 y 모두  ReLU를 사용하여 레이어 크기가 각각 200과 1000으로 구성된 hidden layer로 매핑된 후 차원이 1200인 두 번째 결합된 hidden ReLU Layer에 매핑된다.

 

판별기의 정확한 아키텍처는 충분한 성능만 있다면 중요하지 않으며, 일반적으로 maxout 단위가 작업에 적합하다는 것이 논문에서 나온 의견이다.

 

100개의 미니 배치, 0.1의 초기 learning rate에서 0.000001까지 기하급수적으로 감소한 decay 계수 1.00004의 SGD를 사용하여 학습. 모멘텀은 초기 값이 0.5에서 0.7까지 증가된 값을 사용한다. 확률 0.5의 드롭 아웃이 생성기와 판별기 모두에 적용된다. 그리고 validation set에 대한 log-likelihood의 추정치가 중단점으로 사용되었다.

 

Multimodal

Flickr와 같은 사진 사이트는 사용자 태그의 형태로 라벨이 지정된 데이터의 풍부한 소스이다. 

 

사용자 생성 메타데이터(UGM)는 일반적으로 더 canonical(설명적)이며, 이미지에 존재하는 객체를 식별하기보다는 인간이 자연어로 이미지를 설명하는 방식에 훨씬 더 가깝다는 점에서 표준 이미지 라벨링 체계와 다르다. UGM의 또 다른 측면은 동어반복이 만연하고 사용자마다 동일한 개념을 설명할 때 서로 다른 어휘를 사용할 수 있기 때문에 이러한 라벨을 효율적으로 정규화하는 방법이 중요해진다. 개념적 단어 임베딩은 관련 개념이 결국 유사한 벡터로 표현되기 때문에 매우 유용할 수 있다.

 

이미지 특징은 ImageNet 데이터 세트를 사용해 4096개의 fully connected layer의 출력을 image representation으로 사용한다.

 

단어 표현을 위하여 YFCC100M 2 데이터 셋 메타데이터에서 사용자 태그, 제목, 설명을 연결하여 텍스트의 말뭉치를 수집한다. 텍스트를 전처리하고 정리한 후, 단어 크기가 200인 스킵그램 모델을 훈련시켰다. 그리고 어휘에서 200번 미만으로 나타나는 단어는 생략하여 247465 크기의 dictionary를 만들었다.

 

GAN을 훈련시키는 동안 conovlutional 모델과 language 모델은 고정된 상태로 유지한다. 그리고 이러한 모델을 통해 backpropagation을 하는 실험은 향후 작업으로 남겨둔다. 

 

평가를 위해 각 이미지에 대해 100개의 샘플을 생성하고 어휘에 포함된 단어의 벡터 표현의 코사인 유사도를 사용하여 각 샘플에 가장 가까운 상위 20개의 단어를 찾는다. 그런 다음 100개의 샘플 중에서는 가장 일반적인 상위 10개의 단어를 선택한다.

 

 

정리

  • 생성기와 판별기에 특정 condition을 나타내는 정보 y를 추가해주며, 이 y는 형태가 정해진 것은 아니라 다양한 형태를 가질 수 있다.
  • 판별기는 라벨 정보가 주어지면 가짜 샘플과 진짜 샘플을 구별하는 방법을 학습한다.
  • 판별기에는 레이블이 포함된 실제 데이터, 가짜 데이터가 모두 제공되며, 실제 데이터와 가짜 데이터를 인식하는 것뿐만 아니라, 일치하는 쌍을 찾아내는 방법도 학습한다.
  • latent vector와 label embedding을 합친 것을 joint representation이라고 칭할 수 있다.

 

판별기는 최종적으로 입력이 진짜인지 가짜인지를 나타내는 확률을 출력한다. 학습의 목표는 다음과 같다.

  • 모든 실제 샘플과 라벨의 pair를 accept한다.
  • 모든 가짜 샘플과 라벨의 pair를 reject한다. (샘플은 label과 일치하지만, 거절하는 듯 하다.)
  • 또한 해당 라벨이 일치하지 않는 경우 모든 fake sample을 reject한다.

 

예시는 다음과 같다.

  • 생성된 이미지가 1이지만 라벨이 2라면 예시가 진짜인지 가짜인지에 관계 없이 reject한다는 것 같다. 주어진 라벨과 매치되지 않기 때문에.
  • 심지어 이미지와 label이 일치하는 경우라도 fake로 판별한다. 

 

 

 

왜 가능한걸까

  • 이미지를 생성할 때 noise와 label값을 랜덤으로 받는데, label값은 임베딩으로 확실하게 나타낼 수 있는 feature이고, noise는 불확실한 특성이다.
  • 이로 인해 noise값이 어떻게 들어오든, 확실한 지표인 label값이 있으니, 해당 label값이 들어오면 어떤 랜덤z가 같이 concat 되더라도 discriminator가 1에 가까운 출력을 내도록 하여 label에 맞는 이미지를 생성할 수 있도록 할 것이다.
  • discriminator는 진짜 이미지를 판별하는 loss값과, 가짜 이미지를 판별하는 loss값을 둘 다 합산한 뒤 1/2를 곱하는 연산을 수행해서 해당 값을 loss로 사용한다.
  • 이를 통해 discriminaotr는 더욱 더 정교해지며, 정교해지는 Discriminator를 속이기 위해 Generator도 label에 맞는 이미지를 생성하며 학습이 진행될 것이다.

 

구현

import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision import utils

from torch.utils.data import DataLoader
from torch import optim

import numpy as np
import matplotlib.pyplot as plt



class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_classes = 10
        self.nz = 100
        self.input_size = (1, 28, 28)
        
        # Noise와 label을 결합하는 용도인 label embedding matrix를 생성
        # 해당 embedding 값 또한 학습 가능한 파라미터임에 유의.
        self.label_emb = nn.Embedding(self.num_classes, self.num_classes) # num embedding, embedding_dim
        
        # Generator
        self.gen = nn.Sequential(
            nn.Linear(self.nz + self.num_classes, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256,512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512,1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, int(np.prod(self.input_size))), # 1024, 784
            nn.Tanh()
        )
    
    def forward(self, noise, labels):
        # noise와 label의 결합     
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        x = self.gen(gen_input)
        x = x.view(x.size(0), *self.input_size)
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_size = (1, 28, 28)
        self.num_classes = 10
        self.label_emb = nn.Embedding(self.num_classes, self.num_classes)
        self.dis = nn.Sequential(
            nn.Linear(self.num_classes + int(np.prod(self.input_size)), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512,512),
            nn.Dropout(0.4),
            
            nn.LeakyReLU(0.2),
            nn.Linear(512,512),
            nn.Dropout(0.4),
            
            nn.LeakyReLU(0.2),
            nn.Linear(512,1),
            nn.Sigmoid()
        )
        
    def forward(self, img, labels):
        dis_input = torch.cat((img.view(img.size(0), -1), self.label_emb(labels)), -1)
        x = self.dis(dis_input)
        return x
    

def initialize_weights(model):
    classname = model.__class__.__name__
    # fc layer
    if classname.find('Linear') != -1:
        nn.init.normal_(model.weight.data, 0.0, 0.02)
        nn.init.constant_(model.bias.data, 0)
    # batchnorm
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(model.weight.data, 1.0, 0.02)
        nn.init.constant_(model.bias.data, 0)


def check_data(train_dataset, img_save_path):
    img, label = train_dataset.data, train_dataset.targets
    
    # Make it to 4D Tensor1
    # 기존 : (#Batch) x (height) x (width) -> (#Batch) x (#channel) x (height) x(width)
    if len(img.shape) == 3:
        img = img.unsqueeze(1)
    
    # Visualize
    img_grid = utils.make_grid(img[:40], ncol=8, padding=2)
    show(img_grid, img_save_path)


def show(img, path):
    img = img.numpy() # Tensor -> numpy array
    img = img.transpose([1,2,0]) # C x H x W -> H x W x C
    plt.imshow(img, interpolation='nearest')
    plt.savefig(path + "/test.png")


def show_generated_image(model_gen, path, device):
    model_gen.eval()
    
    # fake image 생성
    with torch.no_grad():
        fig = plt.figure(figsize=(8,8))
        cols, rows = 4, 4 # row와 col 갯수
        for i in range(rows * cols):
            fixed_noise = torch.randn(16, 100, device=device)
            label = torch.randint(0,10,(16,), device=device)
            img_fake = model_gen(fixed_noise, label).detach().cpu()
            fig.add_subplot(rows, cols, i+1)
            plt.title(label[i].item())
            plt.axis('off')
            plt.imshow(img_fake[i].squeeze(), cmap='gray')
    plt.savefig(path + "/generated_img.png")


def train(model_dis, model_gen, train_loader, device):
    loss_func = nn.BCELoss()
    
    lr = 2e-4
    beta1 = 0.5
    beta2 = 0.999
    
    opt_dis = optim.Adam(model_dis.parameters(), lr=lr, betas=(beta1,beta2))
    opt_gen = optim.Adam(model_gen.parameters(), lr=lr, betas=(beta1,beta2))
    
    nz = 100
    num_epochs = 100
    
    loss_history = {'gen': [], 'dis': []}
    
    batch_count = 0
    start_time = time.time()
    model_dis.train()
    model_gen.train()
    
    for epoch in range(num_epochs):
        for x_batch, y_batch in train_loader:
            batch_size = x_batch.shape[0]
            
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            y_batch_real = torch.Tensor(batch_size, 1).fill_(1.0).to(device)
            y_batch_fake = torch.Tensor(batch_size, 1).fill_(0.0).to(device)
            
            # Generator 학습 시작
            model_gen.zero_grad()
            noise = torch.randn(batch_size, 100).to(device)
            gen_label = torch.randint(0, 10, (batch_size, )).to(device)
            
            # 가짜 이미지 생성
            generated_img = model_gen(noise, gen_label)
            
            # 가짜 이미지 판별
            dis_result = model_dis(generated_img, gen_label)
            
            # discriminator가 1에 가까운 출력을 낼 수 있도록 generator를 학습.
            loss_gen = loss_func(dis_result, y_batch_real)
            loss_gen.backward()
            opt_gen.step()
            
            # Discriminator 학습 시작
            model_dis.zero_grad()
            
            # 진짜 이미지 판별
            dis_result = model_dis(x_batch, y_batch)
            loss_real = loss_func(dis_result, y_batch_real)
            
            # 가짜 이미지 판별
            # Discriminator가 가짜이미지로 분류한 값과, y_batch_fake의 값의 차이를 줄임으로
            # 가짜이미지를 가짜이미지로 분류할 수 있는 성능을 올림
            out_dis = model_dis(generated_img.detach(), gen_label)
            loss_fake = loss_func(out_dis, y_batch_fake)
            
            # 진짜 이미지 판별 loss와 가짜 이미지 판별 loss를 더한 뒤 2를 나누어 loss값을 사용한다. (GAN loss를 구현할 떄는 이와 같은 방식을 따름)
            loss_dis = (loss_real + loss_fake) / 2
            loss_dis.backward()
            opt_dis.step()

            loss_history['gen'].append(loss_gen.item())
            loss_history['dis'].append(loss_dis.item())
            
            batch_count += 1
            if batch_count % 1000 == 0:
                print('Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, loss_gen.item(), loss_dis.item(), (time.time()-start_time)/60))
                
    return model_dis, model_gen
            
    


def main():
    device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')
    img_save_path = r"/workspace/Model_Implementation/GenerativeModel/gan/results"
    
    # Set Data path
    datapath = './data'
    os.makedirs(datapath, exist_ok=True)

    # Pre-process
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    # Laod MNIST
    train_dataset = datasets.MNIST(datapath, train=True, download=True, transform=trans)
    
    # Check data
    check_data(train_dataset, img_save_path)
    
    # DatLoader
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    print(len(train_loader))
    
    model_gen = Generator().to(device)
    model_dis = Discriminator().to(device)
    
    # Apply weight initialization
    model_gen.apply(initialize_weights);
    model_dis.apply(initialize_weights);
    
    model_dis, model_gen = train(model_dis, model_gen, train_loader, device)
    show_generated_image(model_gen, img_save_path, device)
    

if __name__=='__main__':
    main()

 

 

 

 

 

 

https://learnopencv.com/conditional-gan-cgan-in-pytorch-and-tensorflow/

https://velog.io/@tobigs16gm/GAN-DCGAN

https://pseudo-lab.github.io/Tutorial-Book/chapters/GAN/Ch1-Introduction.html

https://arxiv.org/abs/1411.1784

https://velog.io/@wilko97/Conditional-GAN

'Machine Learning' 카테고리의 다른 글

[ML] Autoencoder / FashionMNIST 데이터 셋 사용  (0) 2023.11.01