From e84808bb6fd64ed03afcefeecf72b4d652f7d0c8 Mon Sep 17 00:00:00 2001
From: sdtblck <sdtblck@gmail.com>
Date: Thu, 18 Jun 2020 03:13:55 +0100
Subject: [PATCH] Apply random spatial augmentations to generator outputs in the loss functions; clean up augmentation utilities

---
 training/loss.py          |  39 ++++++++++++++
 training/misc.py          | 111 ++++++++++++++++++++------------------
 training/training_loop.py |  37 ++++---------
 3 files changed, 108 insertions(+), 79 deletions(-)

diff --git a/training/loss.py b/training/loss.py
index 7ad2fe1..b2faee3 100755
--- a/training/loss.py
+++ b/training/loss.py
@@ -10,16 +10,23 @@ import numpy as np
 import tensorflow as tf
 import dnnlib.tflib as tflib
 from dnnlib.tflib.autosummary import autosummary
+from training import misc
 
 #----------------------------------------------------------------------------
 # Logistic loss from the paper
 # "Generative Adversarial Nets", Goodfellow et al. 2014
 
+augment = True  # Apply random spatial augmentations (misc.apply_random_aug) to generated images before D scores them.
+
 def G_logistic(G, D, opt, training_set, minibatch_size):
     _ = opt
     latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
     labels = training_set.get_random_labels_tf(minibatch_size)
     fake_images_out = G.get_output_for(latents, labels, is_training=True)
+    if augment:
+        fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])  # NCHW -> NHWC
+        fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)  # augment each image independently
+        fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])  # NHWC -> NCHW
     fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
     loss = -tf.nn.softplus(fake_scores_out) # log(1-sigmoid(fake_scores_out)) # pylint: disable=invalid-unary-operand-type
     return loss, None
@@ -29,6 +36,10 @@ def G_logistic_ns(G, D, opt, training_set, minibatch_size):
     latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
     labels = training_set.get_random_labels_tf(minibatch_size)
     fake_images_out = G.get_output_for(latents, labels, is_training=True)
+    if augment:
+        fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
+        fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
+        fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
     fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
     loss = tf.nn.softplus(-fake_scores_out) # -log(sigmoid(fake_scores_out))
     return loss, None
@@ -37,6 +48,10 @@ def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels):
     _ = opt, training_set
     latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
     fake_images_out = G.get_output_for(latents, labels, is_training=True)
+    if augment:
+        fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
+        fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
+        fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
     real_scores_out = D.get_output_for(reals, labels, is_training=True)
     fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
     real_scores_out = autosummary('Loss/scores/real', real_scores_out)
@@ -53,6 +68,10 @@ def D_logistic_r1(G, D, opt, training_set, minibatch_size, reals, labels, gamma=
     _ = opt, training_set
     latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
     fake_images_out = G.get_output_for(latents, labels, is_training=True)
+    if augment:
+        fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
+        fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
+        fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
     real_scores_out = D.get_output_for(reals, labels, is_training=True)
     fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
     real_scores_out = autosummary('Loss/scores/real', real_scores_out)
@@ -71,6 +90,10 @@ def D_logistic_r2(G, D, opt, training_set, minibatch_size, reals, labels, gamma=
     _ = opt, training_set
     latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
     fake_images_out = G.get_output_for(latents, labels, is_training=True)
+    if augment:
+        fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
+        fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
+        fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
     real_scores_out = D.get_output_for(reals, labels, is_training=True)
     fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
     real_scores_out = autosummary('Loss/scores/real', real_scores_out)
@@ -94,6 +117,10 @@ def G_wgan(G, D, opt, training_set, minibatch_size):
     latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
     labels = training_set.get_random_labels_tf(minibatch_size)
     fake_images_out = G.get_output_for(latents, labels, is_training=True)
+    if augment:
+        fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
+        fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
+        fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
     fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
     loss = -fake_scores_out
     return loss, None
@@ -102,6 +129,10 @@ def D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, wgan_epsilon=
     _ = opt, training_set
     latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
     fake_images_out = G.get_output_for(latents, labels, is_training=True)
+    if augment:
+        fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
+        fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
+        fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
     real_scores_out = D.get_output_for(reals, labels, is_training=True)
     fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
     real_scores_out = autosummary('Loss/scores/real', real_scores_out)
@@ -120,6 +151,10 @@ def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, wgan_lambd
     _ = opt, training_set
     latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
     fake_images_out = G.get_output_for(latents, labels, is_training=True)
+    if augment:
+        fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
+        fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
+        fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
     real_scores_out = D.get_output_for(reals, labels, is_training=True)
     fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
     real_scores_out = autosummary('Loss/scores/real', real_scores_out)
@@ -150,6 +185,10 @@ def G_logistic_ns_pathreg(G, D, opt, training_set, minibatch_size, pl_minibatch_
     latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
     labels = training_set.get_random_labels_tf(minibatch_size)
     fake_images_out, fake_dlatents_out = G.get_output_for(latents, labels, is_training=True, return_dlatents=True)
+    if augment:
+        fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
+        fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
+        fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
     fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
     loss = tf.nn.softplus(-fake_scores_out) # -log(sigmoid(fake_scores_out))
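
Each loss function above repeats the same three-line round-trip: G and D work on NCHW batches, while the augmentation ops in training/misc.py operate on a single HWC image, so the batch is transposed to NHWC, mapped through misc.apply_random_aug, and transposed back. A minimal sketch of how that round-trip could be factored into one helper, assuming the tf and misc names already imported at the top of training/loss.py; augment_fakes is a hypothetical name and is not defined anywhere in this patch.

    def augment_fakes(images_nchw):
        # D consumes NCHW batches, but misc.apply_random_aug expects a single HWC image,
        # so transpose, augment each image independently, then transpose back.
        images_nhwc = tf.transpose(images_nchw, [0, 2, 3, 1])
        images_nhwc = tf.map_fn(misc.apply_random_aug, images_nhwc)
        return tf.transpose(images_nhwc, [0, 3, 1, 2])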
 
diff --git a/training/misc.py b/training/misc.py
index f293d39..6e7804b 100755
--- a/training/misc.py
+++ b/training/misc.py
@@ -109,18 +109,29 @@ def apply_mirror_augment_v(minibatch):
     minibatch[mask] = minibatch[mask, :, ::-1, :]
     return minibatch
 
-
-def rand_crop(image, crop_h, crop_w):
+def apply_random_aug(x, seed=None):
+    with tf.name_scope('SpatialAugmentations'):
+        choice = tf.random_uniform([], 0, 6, tf.int32, seed=seed)  # exactly one of the six branches below fires
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(0))), lambda: zoom_in(x, seed=seed), lambda: tf.identity(x))
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(1))), lambda: zoom_out(x, seed=seed), lambda: tf.identity(x))
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(2))), lambda: X_translate(x, seed=seed), lambda: tf.identity(x))
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(3))), lambda: Y_translate(x, seed=seed), lambda: tf.identity(x))
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(4))), lambda: XY_translate(x, seed=seed), lambda: tf.identity(x))
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(5))), lambda: random_cutout(x, seed=seed), lambda: tf.identity(x))
+        return x
+
+
+def rand_crop(image, crop_h, crop_w, seed=None):
     shape = tf.shape(image)
     h, w = shape[0], shape[1]
-    begin = [h - crop_h, w - crop_w] * tf.random.uniform([2], 0, 1)
-    begin = tf.cast(begin, tf.int64)
+    begin = [h - crop_h, w - crop_w] * tf.random.uniform([2], 0, 1, seed=seed)
+    begin = tf.cast(begin, tf.int32)
     begin = tf.concat([begin, [0]], axis=0)  # Add channel dimension.
     image = tf.slice(image, begin, [crop_h, crop_w, 3])
     return image
 
 
-def zoom_in(tf_img, alpha=0.8, target_image_shape=None, seed=None):
+def zoom_in(tf_img, alpha=0.3, target_image_shape=None, seed=None):
     """
     Random zoom in to TF image
     Args:
@@ -130,13 +141,10 @@ def zoom_in(tf_img, alpha=0.8, target_image_shape=None, seed=None):
     Returns:
       Image tensor with shape `target_image_shape`.
     """
-    print('zooming in')
-
     n = tf.random_uniform(shape=[], minval=1 - alpha, maxval=1, dtype=tf.float32, seed=seed, name=None)
     shape = tf.shape(tf_img)
     h = shape[0]
     w = shape[1]
-    c = shape[2]
     h_t = tf.cast(
         h, dtype=tf.float32, name='height')
     w_t = tf.cast(
@@ -151,7 +159,7 @@ def zoom_in(tf_img, alpha=0.8, target_image_shape=None, seed=None):
         rnd_h, dtype=tf.int32, name='height')
     rnd_w = tf.cast(
         rnd_w, dtype=tf.int32, name='width')
-    cropped_img = rand_crop(tf_img, rnd_h, rnd_w)
+    cropped_img = rand_crop(tf_img, rnd_h, rnd_w, seed=seed)
 
     # resize back to original size
     resized_img = tf.image.resize(
@@ -162,7 +170,7 @@ def zoom_in(tf_img, alpha=0.8, target_image_shape=None, seed=None):
     return resized_img
 
 
-def zoom_out(tf_img, alpha=0.8, target_image_shape=None, seed=None):
+def zoom_out(tf_img, alpha=0.3, target_image_shape=None, seed=None):
     """
     Random zoom out of TF image
     Args:
@@ -174,14 +182,11 @@ def zoom_out(tf_img, alpha=0.8, target_image_shape=None, seed=None):
     """
 
     # Set params
-    print('zooming out')
-
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
 
     shape = tf.shape(tf_img)
     h = shape[0]
     w = shape[1]
-    c = shape[2]
 
     if target_image_shape is None:
         target_image_shape = (h, w)
@@ -203,7 +208,7 @@ def zoom_out(tf_img, alpha=0.8, target_image_shape=None, seed=None):
         rnd_h, dtype=tf.int32, name='height')
     rnd_w = tf.cast(
         rnd_w, dtype=tf.int32, name='width')
-    cropped_img = rand_crop(padded_img, rnd_h, rnd_w)
+    cropped_img = rand_crop(padded_img, rnd_h, rnd_w, seed=seed)
 
     # Resize back to original size
     resized_img = tf.image.resize(
@@ -214,7 +219,7 @@ def zoom_out(tf_img, alpha=0.8, target_image_shape=None, seed=None):
     return resized_img
 
 
-def X_translate(tf_img, alpha=0.8, target_image_shape=None, seed=None):
+def X_translate(tf_img, alpha=0.3, target_image_shape=None, seed=None):
     """
     Random X translation within TF image with reflection padding
     Args:
@@ -224,14 +229,11 @@ def X_translate(tf_img, alpha=0.8, target_image_shape=None, seed=None):
     Returns:
       Image tensor with shape `target_image_shape`.
     """
-    print('X translate')
-
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
 
     shape = tf.shape(tf_img)
     h = shape[0]
     w = shape[1]
-    c = shape[2]
 
     if target_image_shape is None:
         target_image_shape = (h, w)
@@ -244,11 +246,11 @@ def X_translate(tf_img, alpha=0.8, target_image_shape=None, seed=None):
     padded_img = tf.pad(tf_img, paddings, 'REFLECT')
 
     # Random crop section at original size
-    X_trans = rand_crop(padded_img, h, w)
+    X_trans = rand_crop(padded_img, target_image_shape[0], target_image_shape[1], seed=seed)
     return X_trans
 
 
-def XY_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
+def XY_translate(tf_img, alpha=0.3, target_image_shape=None, seed=None):
     """
     Random XY translation within TF image with reflection padding
     Args:
@@ -258,13 +260,10 @@ def XY_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
       Image tensor with shape `target_image_shape`.
     """
-    print('XY translate')
-
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
     shape = tf.shape(tf_img)
     h = shape[0]
     w = shape[1]
-    c = shape[2]
     if target_image_shape is None:
         target_image_shape = (h, w)
 
@@ -279,12 +278,11 @@ def XY_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     padded_img = tf.pad(tf_img, paddings, 'REFLECT')
 
     # Random crop section at original size
-    XY_trans = rand_crop(padded_img, h, w)
-
+    XY_trans = rand_crop(padded_img, target_image_shape[0], target_image_shape[1], seed=seed)
     return XY_trans
 
 
-def Y_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
+def Y_translate(tf_img, alpha=0.3, target_image_shape=None, seed=None):
     """
     Random Y translation within TF image with reflection padding
     Args:
@@ -294,14 +292,11 @@ def Y_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
       Image tensor with shape `target_image_shape`.
     """
-    print('Y translate')
-
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
 
     shape = tf.shape(tf_img)
     h = shape[0]
     w = shape[1]
-    c = shape[2]
 
     if target_image_shape is None:
         target_image_shape = (h, w)
@@ -314,16 +309,37 @@ def Y_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     padded_img = tf.pad(tf_img, paddings, 'REFLECT')
 
     # Random crop section at original size
-    Y_trans = rand_crop(padded_img, h, w)
+    Y_trans = rand_crop(padded_img, target_image_shape[0], target_image_shape[1], seed=seed)
     return Y_trans
 
-
-@tf.function
-def pad_fn(erase_area, y, x, h, w):
-    return tf.image.pad_to_bounding_box(erase_area, y, x, h, w)
-
-
-def random_cutout(tf_img, alpha=0.1, seed=None):
+def _pad_to_bounding_box(image, offset_height, offset_width, target_height,
+                        target_width):
+    """Pad `image` with zeros to the specified `height` and `width`.
+    Adds `offset_height` rows of zeros on top, `offset_width` columns of
+    zeros on the left, and then pads the image on the bottom and right
+    with zeros until it has dimensions `target_height`, `target_width`.
+    This op does nothing if `offset_*` is zero and the image already has size
+    `target_height` by `target_width`.
+    Args:
+      image: 3-D Tensor of shape `[height, width, channels]`.
+      offset_height: Number of rows of zeros to add on top.
+      offset_width: Number of columns of zeros to add on the left.
+      target_height: Height of output image.
+      target_width: Width of output image.
+    Returns:
+      3-D float Tensor of shape
+      `[target_height, target_width, channels]`.
+    """
+    shape = tf.shape(image)
+    height = shape[0]
+    width = shape[1]
+    after_padding_width = target_width - offset_width - width
+    after_padding_height = target_height - offset_height - height
+    # Do not pad on the depth dimension.
+    paddings = tf.reshape(tf.stack([offset_height, after_padding_height, offset_width, after_padding_width, 0, 0]), [3, 2])
+    return tf.pad(image, paddings)
+
+def random_cutout(tf_img, alpha=0.3, seed=None):
     """
     Cuts random black square out from TF image
     Args:
@@ -338,17 +354,11 @@ def random_cutout(tf_img, alpha=0.1, seed=None):
     shape = tf.shape(tf_img)
     h = shape[0]
     w = shape[1]
-    c = shape[2]
-    print('SHAPE: ')
-    print(shape)
-    # get square of random shape less than w*a, h*a
-    # max_val = tf.cast(tf.minimum(alpha * tf.cast(w, dtype=tf.int32), alpha * tf.cast(w, dtype=tf.int32)), dtype = tf.int32)
-    print(h)
-    max_val = 100
-    size = tf.random_uniform(shape=[], minval=0, maxval=max_val, dtype=tf.int32, seed=seed, name=None)
 
-    print('SQUARE SIZE: ')
-    print(size)
+    # Draw the square's side length uniformly from [1, alpha * min(h, w)).
+    val = tf.cast(tf.minimum(h, w), dtype=tf.float32)
+    max_val = tf.cast((alpha*val), dtype=tf.int32)
+    size = tf.random_uniform(shape=[], minval=1, maxval=max_val, dtype=tf.int32, seed=seed, name=None)
 
     # get random xy location of square
     x_loc_upper_bound = w - size
@@ -357,13 +367,12 @@ def random_cutout(tf_img, alpha=0.1, seed=None):
     x = tf.random_uniform(shape=[], minval=0, maxval=x_loc_upper_bound, dtype=tf.int32, seed=seed, name=None)
     y = tf.random_uniform(shape=[], minval=0, maxval=y_loc_upper_bound, dtype=tf.int32, seed=seed, name=None)
 
-    erase_area = tf.ones([5, 5, 3], dtype=tf.float32)
-    print('ERASE AREA: ')
-    print(erase_area)
+    erase_area = tf.ones([size, size, 3], dtype=tf.float32)
+
     if erase_area.shape == (0, 0, 3):
         return tf_img
     else:
-        mask = 1.0 - pad_fn(erase_area, y, x, h, w)
+        mask = 1.0 - _pad_to_bounding_box(erase_area, y, x, h, w)
         erased_img = tf.multiply(tf_img, mask)
         return erased_img
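
The cutout path above builds its mask by padding a square of ones out to the full image size with _pad_to_bounding_box, inverting it, and multiplying it into the image (the helper computes the same padding as tf.image.pad_to_bounding_box, just without the @tf.function wrapper that the removed pad_fn relied on). A small self-contained sketch of that masking idea with made-up, fixed sizes:

    import tensorflow as tf

    img = tf.ones([6, 6, 3])                                    # toy 6x6 RGB image
    square = tf.ones([2, 2, 3])                                 # 2x2 erase area
    padded = tf.image.pad_to_bounding_box(square, 1, 3, 6, 6)   # place the square at row 1, col 3
    mask = 1.0 - padded                                         # 0 under the square, 1 elsewhere
    cut = img * mask                                            # pixels under the square become 0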
 
diff --git a/training/training_loop.py b/training/training_loop.py
index 12f2691..2a35c63 100755
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -19,32 +19,16 @@ from metrics import metric_base
 
 #----------------------------------------------------------------------------
 # Just-in-time processing of training images before feeding them to the networks.
-def _random_choice(inputs, n_samples=1):
-    """
-    With replacement.
-    Params:
-      inputs (Tensor): Shape [n_states, n_features]
-      n_samples (int): The number of random samples to take.
-    Returns:
-      sampled_inputs (Tensor): Shape [n_samples, n_features]
-    """
-    # (1, n_states) since multinomial requires 2D logits.
-    uniform_log_prob = tf.expand_dims(tf.zeros(tf.shape(inputs)[0]), 0)
-
-    ind = tf.multinomial(uniform_log_prob, n_samples)
-    ind = tf.squeeze(ind, 0, name="random_choice_ind")  # (n_samples,)
-
-    return tf.gather(inputs, ind, name="random_choice")
-
-def apply_random_aug(x):
+
+def apply_random_aug(x, seed=None):
     with tf.name_scope('SpatialAugmentations'):
-        choice = tf.random_uniform([], 0, 5, tf.int32)
-        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(0))), lambda: misc.zoom_in(x), lambda: tf.identity(x))
-        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(1))), lambda: misc.zoom_out(x), lambda: tf.identity(x))
-        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(2))), lambda: misc.X_translate(x), lambda: tf.identity(x))
-        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(3))), lambda: misc.Y_translate(x), lambda: tf.identity(x))
-        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(4))), lambda: misc.XY_translate(x), lambda: tf.identity(x))
-        # x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(5))), lambda: misc.random_cutout(x), lambda: tf.identity(x))
+        choice = tf.random_uniform([], 0, 6, tf.int32, seed=seed)  # same six-way choice as misc.apply_random_aug
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(0))), lambda: misc.zoom_in(x, seed=seed), lambda: tf.identity(x))
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(1))), lambda: misc.zoom_out(x, seed=seed), lambda: tf.identity(x))
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(2))), lambda: misc.X_translate(x, seed=seed), lambda: tf.identity(x))
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(3))), lambda: misc.Y_translate(x, seed=seed), lambda: tf.identity(x))
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(4))), lambda: misc.XY_translate(x, seed=seed), lambda: tf.identity(x))
+        x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(5))), lambda: misc.random_cutout(x, seed=seed), lambda: tf.identity(x))
         return x
 
 def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augmentations, drange_data, drange_net):
@@ -66,9 +50,6 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm
             tf.summary.image("pre-augment", pre)
             tf.summary.image("post-augment", post)
 
-
-
-
     with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.
         s = tf.shape(x)
         y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])
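
Both copies of apply_random_aug (the one above and the one added to training/misc.py) draw an integer in [0, 6) and thread the image through six chained tf.cond ops, so exactly one augmentation fires and every other branch passes the image through unchanged. If the TensorFlow build in use provides tf.switch_case (available from 1.15), the same dispatch could be written as a single op; the sketch below, reusing the tf and misc names already imported in training/training_loop.py, is only an illustrative alternative, not part of this patch.

    def apply_random_aug_switch(x, seed=None):
        # Hypothetical variant: one six-way dispatch instead of six chained tf.cond ops.
        with tf.name_scope('SpatialAugmentations'):
            choice = tf.random_uniform([], 0, 6, tf.int32, seed=seed)
            return tf.switch_case(choice, branch_fns=[
                lambda: misc.zoom_in(x, seed=seed),
                lambda: misc.zoom_out(x, seed=seed),
                lambda: misc.X_translate(x, seed=seed),
                lambda: misc.Y_translate(x, seed=seed),
                lambda: misc.XY_translate(x, seed=seed),
                lambda: misc.random_cutout(x, seed=seed),
            ])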
-- 
GitLab