cutcutcodec.core.nn.model.compression.img_cgavaenn.Discriminator
- class cutcutcodec.core.nn.model.compression.img_cgavaenn.Discriminator
Classify the real and generated images.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(img: Tensor) → Tensor
Determine whether the images are real or fake.
Parameters
- img : torch.Tensor
The float image batch of shape (n, 3, h, w), where h and w must each be of the form 192 + 32k for some non-negative integer k (192, 224, 256, ...).
Returns
- is_fake : torch.Tensor
Scores in [0, 1], where 0 means the image is judged real and 1 means it is judged fake. The output shape is (n, 1, (h - 160)/32, (w - 160)/32).
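The size rule and the output-shape formula above can be checked without loading the model. The following is a minimal pure-Python sketch, not part of cutcutcodec: the helper name `discriminator_output_shape` is hypothetical, and it assumes the input constraint is exactly "h and w are 192 + 32k with k >= 0", which is equivalent to "at least 192 and a multiple of 32".

```python
def discriminator_output_shape(n: int, h: int, w: int) -> tuple:
    """Hypothetical helper: expected output shape for an input of shape (n, 3, h, w)."""
    for name, side in (("h", h), ("w", w)):
        # Valid sides are 192, 224, 256, ...: at least 192 and a multiple of 32.
        if side < 192 or side % 32 != 0:
            raise ValueError(f"{name}={side} is not of the form 192 + 32*k")
    # Each spatial axis of the score map has (side - 160) / 32 cells.
    return (n, 1, (h - 160) // 32, (w - 160) // 32)


print(discriminator_output_shape(10, 192, 192 + 2 * 32))  # (10, 1, 1, 3)
```

The printed shape matches the `torch.Size([10, 1, 1, 3])` result of the doctest example for the same input size.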
Examples
>>> import torch
>>> from cutcutcodec.core.nn.model.compression.img_cgavaenn import Discriminator
>>> discriminator = Discriminator()
>>> discriminator(torch.rand((10, 3, 192, 192+2*32))).shape
torch.Size([10, 1, 1, 3])
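Images whose height or width is not of the form 192 + 32k must be brought to a valid size before being passed to `forward`. This document does not say how cutcutcodec handles such images. The sketch below assumes padding up to the next valid size is acceptable (cropping or resizing would be alternatives). The helper names `next_valid_size` and `padding_for` are hypothetical.

```python
import math


def next_valid_size(side: int) -> int:
    """Hypothetical helper: smallest 192 + 32*k (k >= 0) that is >= side."""
    if side <= 192:
        return 192
    return 192 + 32 * math.ceil((side - 192) / 32)


def padding_for(h: int, w: int) -> tuple:
    """Hypothetical helper: rows and columns to add so that (h, w) becomes valid."""
    return next_valid_size(h) - h, next_valid_size(w) - w


print(next_valid_size(200))   # 224
print(padding_for(500, 640))  # (12, 0)
```

With PyTorch, the amounts could be applied as `torch.nn.functional.pad(img, (0, pad_w, 0, pad_h))`, which pads the last dimension (width) first and then height. Whether zero padding biases the discriminator's scores near the borders is not addressed by the source.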