cutcutcodec.core.nn.model.compression.img_cgavaenn.Discriminator

class cutcutcodec.core.nn.model.compression.img_cgavaenn.Discriminator

Classify the real and generated images.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(img: Tensor) → Tensor

Determine whether each image in the batch is fake or real.

Parameters

img : torch.Tensor

The float image batch of shape (n, 3, h, w), where h and w must each equal 192 + k*32 for some integer k >= 0 (k = 0 gives the minimum size of 192).
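An image whose side is not a valid size can be cropped down to the largest valid size that fits. The sketch below computes that size, assuming the valid sizes are exactly 192 + 32*k with k >= 0 (as the integer output shape (h-160)/32 suggests). The function name `largest_valid_size` is hypothetical and not part of cutcutcodec, and choosing to crop rather than pad or resize is an assumption of this example.

```python
def largest_valid_size(size: int) -> int:
    """Return the largest s <= size with s = 192 + 32*k, k >= 0.

    Hypothetical helper, not part of cutcutcodec. Assumes the valid
    input sizes of Discriminator are exactly 192 + 32*k.
    """
    if size < 192:
        raise ValueError(f"size {size} is below the minimum of 192")
    # Round the excess over 192 down to a multiple of 32.
    return 192 + (size - 192) // 32 * 32


print(largest_valid_size(300))  # 288
```

The result can be used as the target height or width when cropping a batch before calling the discriminator.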

Returns

is_fake : torch.Tensor

Scores in [0, 1], where 0 means the image is real and 1 means it is fake. The output has shape (n, 1, (h-160)/32, (w-160)/32), so each image gets a grid of scores rather than a single scalar.
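The documented shape formula can be checked without instantiating the model. The sketch below computes the expected output shape for a given input size and rejects sizes that are not of the form 192 + 32*k (an assumption drawn from the integer shape formula). The name `discriminator_output_shape` is hypothetical and not part of cutcutcodec.

```python
def discriminator_output_shape(n: int, h: int, w: int) -> tuple[int, int, int, int]:
    """Expected shape of Discriminator.forward(img) for img of shape (n, 3, h, w).

    Hypothetical helper based only on the documented formula
    (n, 1, (h-160)/32, (w-160)/32).
    """
    for name, size in (("h", h), ("w", w)):
        # Assumed validity rule: size = 192 + 32*k with k >= 0.
        if size < 192 or (size - 192) % 32 != 0:
            raise ValueError(f"{name}={size} is not of the form 192 + 32*k")
    return (n, 1, (h - 160) // 32, (w - 160) // 32)


print(discriminator_output_shape(10, 192, 192 + 2 * 32))  # (10, 1, 1, 3)
```

This reproduces the shape shown in the doctest below: a 192-pixel side yields 1 score and a 256-pixel side yields 3.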

Examples

>>> import torch
>>> from cutcutcodec.core.nn.model.compression.img_cgavaenn import Discriminator
>>> discriminator = Discriminator()
>>> discriminator(torch.rand((10, 3, 192, 192+2*32))).shape
torch.Size([10, 1, 1, 3])