From cd6e046455c5927030f1bbbedc7d74b9cb2c6a4b Mon Sep 17 00:00:00 2001
From: sdtblck <sdtblck@gmail.com>
Date: Thu, 18 Jun 2020 14:58:26 +0100
Subject: [PATCH] finishing touches to Zhao et al. augs

---
 training/loss.py          | 11 +++++++----
 training/misc.py          | 23 ++++++++++++++++++++++-
 training/training_loop.py | 15 +++++++++++----
 3 files changed, 40 insertions(+), 9 deletions(-)
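
Notes (not part of the commit message):

The generator-side augmentation in G_logistic_ns_pathreg stays commented out
because the path-length regularizer is itself differentiated during training,
i.e. it needs a second-order gradient, and tf.map_fn is implemented as a
tf.while_loop, which TF1 cannot differentiate twice. A minimal sketch of the
failure, assuming TF 1.x (the tensors below are illustrative, not from this
repo):

    import tensorflow as tf

    x = tf.random_normal([4, 8])
    w = tf.Variable(tf.ones([8, 8]))
    y = tf.map_fn(tf.tanh, tf.matmul(x, w))   # tf.map_fn builds a tf.while_loop
    g = tf.gradients(tf.reduce_sum(y), x)[0]  # first-order gradient: fine
    # Differentiating the gradient again raises:
    # TypeError: Second-order gradient for while loops not supported.
    h = tf.gradients(tf.reduce_sum(tf.square(g)), w)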

diff --git a/training/loss.py b/training/loss.py
index db8b8b2..1883e62 100755
--- a/training/loss.py
+++ b/training/loss.py
@@ -208,10 +208,13 @@ def G_logistic_ns_pathreg(G, D, opt, training_set, minibatch_size, pl_minibatch_
             pl_latents = tf.random_normal([pl_minibatch] + G.input_shapes[0][1:])
             pl_labels = training_set.get_random_labels_tf(pl_minibatch)
             fake_images_out, fake_dlatents_out = G.get_output_for(pl_latents, pl_labels, is_training=True, return_dlatents=True)
-            if augment:
-                fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
-                fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
-                fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
+            # TODO: applying augmentations here fails with the following error:
+            # TypeError: Second-order gradient for while loops not supported.
+            # Setting pl_minibatch_shrink to 1 would work, but at the cost of higher memory usage.
+            # if augment:
+            #     fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
+            #     fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
+            #     fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
 
         # Compute |J*y|.
         pl_noise = tf.random_normal(tf.shape(fake_images_out)) / np.sqrt(np.prod(G.output_shape[2:]))
diff --git a/training/misc.py b/training/misc.py
index 9e29e7f..02dc5c7 100755
--- a/training/misc.py
+++ b/training/misc.py
@@ -95,6 +95,16 @@ def convert_to_pil_image(image, drange=[0, 1]):
 def save_image_grid(images, filename, drange=[0, 1], grid_size=None):
     convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename)
 
+# ----------------------------------------------------------------------------
+# Image Augmentations.
+
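+# Optional global override of each augmentation's alpha, read once from the
+# SPATIAL_AUGS_ALPHA environment variable when this module is imported.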
+alpha_override = float(os.environ.get('SPATIAL_AUGS_ALPHA', '0'))
+if alpha_override >= 1:
+    alpha_override = 0.999  # keep the override strictly below 1
+elif alpha_override < 0:
+    alpha_override = 0  # negative values disable the override
 
 def apply_mirror_augment(minibatch):
     mask = np.random.rand(minibatch.shape[0]) < 0.5
@@ -142,6 +152,8 @@ def zoom_in(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
       Image tensor with shape `target_image_shape`.
     """
+    if alpha_override > 0:
+        alpha = alpha_override
     n = tf.random_uniform(shape=[], minval=1 - alpha, maxval=1, dtype=tf.float32, seed=seed, name=None)
     shape = tf.shape(tf_img)
     h = shape[0]
@@ -181,7 +193,8 @@ def zoom_out(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
       Image tensor with shape `target_image_shape`.
     """
-
+    if alpha_override > 0:
+        alpha = alpha_override
     # Set params
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
 
@@ -230,6 +243,8 @@ def X_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
       Image tensor with shape `target_image_shape`.
     """
+    if alpha_override > 0:
+        alpha = alpha_override
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
 
     shape = tf.shape(tf_img)
@@ -261,6 +276,8 @@ def XY_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
       Image tensor with shape `target_image_shape`.
     """
+    if alpha_override > 0:
+        alpha = alpha_override
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
     shape = tf.shape(tf_img)
     h = shape[0]
@@ -293,6 +310,8 @@ def Y_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
       Image tensor with shape `target_image_shape`.
     """
+    if alpha_override > 0:
+        alpha = alpha_override
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
 
     shape = tf.shape(tf_img)
@@ -350,6 +369,8 @@ def random_cutout(tf_img, alpha=0.1, seed=None):
     Returns:
     Cutout Image tensor
     """
+    if alpha_override > 0:
+        alpha = alpha_override
 
     # get img shape
     shape = tf.shape(tf_img)
diff --git a/training/training_loop.py b/training/training_loop.py
index 23ee9de..941cf15 100755
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -42,7 +42,6 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm
             with tf.name_scope('ImageSummaries'), tf.device('/cpu:0'):
                 tf.summary.image("reals_pre-augment", pre)
                 tf.summary.image("reals_post-augment", post)
-
     with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.
         s = tf.shape(x)
         y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])
@@ -303,10 +302,18 @@ def training_loop(
         G.setup_weight_histograms(); D.setup_weight_histograms()
     metrics = metric_base.MetricGroup(metric_arg_list)
 
-
-    print('Training for %d kimg...\n' % total_kimg)
     if spatial_augmentations:
-        print('Augmenting fakes and reals\n')
+        print('Augmenting fakes and reals')
+        alpha_override = float(os.environ.get('SPATIAL_AUGS_ALPHA', '0'))
+        if alpha_override == 0.0:
+            print('Augmentation alpha at default of 0.1 - override with the SPATIAL_AUGS_ALPHA environment variable')
+        else:
+            if alpha_override >= 1:
+                alpha_override = 0.999
+            print('Augmentation alpha set to %g' % alpha_override)
+        if save_image_summaries:
+            print('Saving image summaries to TensorBoard')
+    print('Training for %d kimg...\n' % total_kimg)
     dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
     maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
     cur_nimg = int(resume_kimg * 1000)
-- 
GitLab
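
Usage sketch: the run_training.py flags below are the upstream StyleGAN2 ones
and assume this fork enables spatial augmentations through its existing
config; adjust to your setup:

    # misc.py reads SPATIAL_AUGS_ALPHA once, at import time, so set it before
    # launching. Values >= 1 are clamped to 0.999; unset or 0 keeps each
    # augmentation's own default alpha of 0.1.
    SPATIAL_AUGS_ALPHA=0.3 python run_training.py --num-gpus=1 \
        --data-dir=~/datasets --dataset=my_dataset --config=config-f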