diff --git a/training/loss.py b/training/loss.py
index db8b8b242f7577bd2b707723c013e94e29fa2d73..1883e62d33d0341b3d5927735df00e7ed4f47e8b 100755
--- a/training/loss.py
+++ b/training/loss.py
@@ -208,10 +208,13 @@ def G_logistic_ns_pathreg(G, D, opt, training_set, minibatch_size, pl_minibatch_
             pl_latents = tf.random_normal([pl_minibatch] + G.input_shapes[0][1:])
             pl_labels = training_set.get_random_labels_tf(pl_minibatch)
             fake_images_out, fake_dlatents_out = G.get_output_for(pl_latents, pl_labels, is_training=True, return_dlatents=True)
-            if augment:
-                fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
-                fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
-                fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
+            # TODO: applying augmentations here fails with the following error:
+            # TypeError: Second-order gradient for while loops not supported.
+            # setting pl_minibatch_shrink to 1 would work - but will have a higher memory usage
+            # if augment:
+            #     fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
+            #     fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
+            #     fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
 
         # Compute |J*y|.
         pl_noise = tf.random_normal(tf.shape(fake_images_out)) / np.sqrt(np.prod(G.output_shape[2:]))
diff --git a/training/misc.py b/training/misc.py
index 9e29e7f1c200b86dceddf3ad1b096ee2a337f422..02dc5c7563b41b1c962797119c95b50a987c74b3 100755
--- a/training/misc.py
+++ b/training/misc.py
@@ -95,6 +95,14 @@ def convert_to_pil_image(image, drange=[0, 1]):
 
 def save_image_grid(images, filename, drange=[0, 1], grid_size=None):
     convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename)
+# ----------------------------------------------------------------------------
+# Image Augmentations.
+
+alpha_override = float(os.environ.get('SPATIAL_AUGS_ALPHA', '0'))
+if alpha_override >= 1:
+    alpha_override = 0.999
+elif alpha_override == 0.0:
+    alpha_override = 0
 
 def apply_mirror_augment(minibatch):
     mask = np.random.rand(minibatch.shape[0]) < 0.5
@@ -142,6 +150,8 @@ def zoom_in(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
         Image tensor with shape `target_image_shape`.
     """
+    if alpha_override > 0:
+        alpha = alpha_override
     n = tf.random_uniform(shape=[], minval=1 - alpha, maxval=1, dtype=tf.float32, seed=seed, name=None)
     shape = tf.shape(tf_img)
     h = shape[0]
@@ -181,7 +191,8 @@ def zoom_out(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
         Image tensor with shape `target_image_shape`.
     """
-
+    if alpha_override > 0:
+        alpha = alpha_override
     # Set params
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32,
                           seed=seed, name=None)
@@ -230,6 +241,8 @@ def X_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
         Image tensor with shape `target_image_shape`.
     """
+    if alpha_override > 0:
+        alpha = alpha_override
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32,
                           seed=seed, name=None)
     shape = tf.shape(tf_img)
@@ -261,6 +274,8 @@ def XY_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
         Image tensor with shape `target_image_shape`.
     """
+    if alpha_override > 0:
+        alpha = alpha_override
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
     shape = tf.shape(tf_img)
     h = shape[0]
@@ -293,6 +308,8 @@ def Y_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
     Returns:
         Image tensor with shape `target_image_shape`.
     """
+    if alpha_override > 0:
+        alpha = alpha_override
     n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32,
                           seed=seed, name=None)
     shape = tf.shape(tf_img)
@@ -350,6 +367,8 @@ def random_cutout(tf_img, alpha=0.1, seed=None):
     Returns:
         Cutout Image tensor
     """
+    if alpha_override > 0:
+        alpha = alpha_override
 
     # get img shape
     shape = tf.shape(tf_img)
diff --git a/training/training_loop.py b/training/training_loop.py
index 23ee9defd5e201a328e3d776505e8a06299652bb..941cf15c06aef72a8e18621cb32905c43b86ef0d 100755
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -42,7 +42,6 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm
     with tf.name_scope('ImageSummaries'), tf.device('/cpu:0'):
         tf.summary.image("reals_pre-augment", pre)
         tf.summary.image("reals_post-augment", post)
-
     with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.
         s = tf.shape(x)
         y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])
@@ -303,10 +302,21 @@ def training_loop(
         G.setup_weight_histograms(); D.setup_weight_histograms()
     metrics = metric_base.MetricGroup(metric_arg_list)
 
-
-    print('Training for %d kimg...\n' % total_kimg)
     if spatial_augmentations:
-        print('Augmenting fakes and reals\n')
+        print('Augmenting fakes and reals')
+        alpha_override = float(os.environ.get('SPATIAL_AUGS_ALPHA', '0'))
+        if alpha_override == 0.0:
+            print('Augmentation alpha at default setting of 0.1 - change by setting SPATIAL_AUGS_ALPHA environment variable')
+        else:
+            if alpha_override >= 1:
+                alpha_override = 0.999
+            print(f'Augmentation alpha set to {alpha_override}')
+    if save_image_summaries:
+        print('Saving image summaries to tensorboard')
+    print('Training for %d kimg...\n' % total_kimg)
+
+
+
     dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
     maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
     cur_nimg = int(resume_kimg * 1000)