Skip to content
Snippets Groups Projects
Commit cd6e0464 authored by sdtblck's avatar sdtblck
Browse files

finishing touches to zhao et al augs

parent 1c81c269
Branches
No related tags found
No related merge requests found
...@@ -208,10 +208,13 @@ def G_logistic_ns_pathreg(G, D, opt, training_set, minibatch_size, pl_minibatch_ ...@@ -208,10 +208,13 @@ def G_logistic_ns_pathreg(G, D, opt, training_set, minibatch_size, pl_minibatch_
pl_latents = tf.random_normal([pl_minibatch] + G.input_shapes[0][1:]) pl_latents = tf.random_normal([pl_minibatch] + G.input_shapes[0][1:])
pl_labels = training_set.get_random_labels_tf(pl_minibatch) pl_labels = training_set.get_random_labels_tf(pl_minibatch)
fake_images_out, fake_dlatents_out = G.get_output_for(pl_latents, pl_labels, is_training=True, return_dlatents=True) fake_images_out, fake_dlatents_out = G.get_output_for(pl_latents, pl_labels, is_training=True, return_dlatents=True)
if augment: # TODO: applying augmentations here fails with the following error:
fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1]) # TypeError: Second-order gradient for while loops not supported.
fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment) # setting pl_minibatch_shrink to 1 would work - but will have a higher memory usage
fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2]) # if augment:
# fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
# fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
# fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
# Compute |J*y|. # Compute |J*y|.
pl_noise = tf.random_normal(tf.shape(fake_images_out)) / np.sqrt(np.prod(G.output_shape[2:])) pl_noise = tf.random_normal(tf.shape(fake_images_out)) / np.sqrt(np.prod(G.output_shape[2:]))
......
...@@ -95,6 +95,14 @@ def convert_to_pil_image(image, drange=[0, 1]): ...@@ -95,6 +95,14 @@ def convert_to_pil_image(image, drange=[0, 1]):
def save_image_grid(images, filename, drange=[0, 1], grid_size=None): def save_image_grid(images, filename, drange=[0, 1], grid_size=None):
convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename) convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename)
# ----------------------------------------------------------------------------
# Image Augmentations.
alpha_override = float(os.environ.get('SPATIAL_AUGS_ALPHA', '0'))
if alpha_override >= 1:
alpha_override = 0.999
elif alpha_override == 0.0:
alpha_override = 0
def apply_mirror_augment(minibatch): def apply_mirror_augment(minibatch):
mask = np.random.rand(minibatch.shape[0]) < 0.5 mask = np.random.rand(minibatch.shape[0]) < 0.5
...@@ -142,6 +150,8 @@ def zoom_in(tf_img, alpha=0.1, target_image_shape=None, seed=None): ...@@ -142,6 +150,8 @@ def zoom_in(tf_img, alpha=0.1, target_image_shape=None, seed=None):
Returns: Returns:
Image tensor with shape `target_image_shape`. Image tensor with shape `target_image_shape`.
""" """
if alpha_override > 0:
alpha = alpha_override
n = tf.random_uniform(shape=[], minval=1 - alpha, maxval=1, dtype=tf.float32, seed=seed, name=None) n = tf.random_uniform(shape=[], minval=1 - alpha, maxval=1, dtype=tf.float32, seed=seed, name=None)
shape = tf.shape(tf_img) shape = tf.shape(tf_img)
h = shape[0] h = shape[0]
...@@ -181,7 +191,8 @@ def zoom_out(tf_img, alpha=0.1, target_image_shape=None, seed=None): ...@@ -181,7 +191,8 @@ def zoom_out(tf_img, alpha=0.1, target_image_shape=None, seed=None):
Returns: Returns:
Image tensor with shape `target_image_shape`. Image tensor with shape `target_image_shape`.
""" """
if alpha_override > 0:
alpha = alpha_override
# Set params # Set params
n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None) n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
...@@ -230,6 +241,8 @@ def X_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None): ...@@ -230,6 +241,8 @@ def X_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
Returns: Returns:
Image tensor with shape `target_image_shape`. Image tensor with shape `target_image_shape`.
""" """
if alpha_override > 0:
alpha = alpha_override
n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None) n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
shape = tf.shape(tf_img) shape = tf.shape(tf_img)
...@@ -261,6 +274,8 @@ def XY_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None): ...@@ -261,6 +274,8 @@ def XY_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
Returns: Returns:
Image tensor with shape `target_image_shape`. Image tensor with shape `target_image_shape`.
""" """
if alpha_override > 0:
alpha = alpha_override
n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None) n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
shape = tf.shape(tf_img) shape = tf.shape(tf_img)
h = shape[0] h = shape[0]
...@@ -293,6 +308,8 @@ def Y_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None): ...@@ -293,6 +308,8 @@ def Y_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
Returns: Returns:
Image tensor with shape `target_image_shape`. Image tensor with shape `target_image_shape`.
""" """
if alpha_override > 0:
alpha = alpha_override
n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None) n = tf.random_uniform(shape=[], minval=0, maxval=alpha, dtype=tf.float32, seed=seed, name=None)
shape = tf.shape(tf_img) shape = tf.shape(tf_img)
...@@ -350,6 +367,8 @@ def random_cutout(tf_img, alpha=0.1, seed=None): ...@@ -350,6 +367,8 @@ def random_cutout(tf_img, alpha=0.1, seed=None):
Returns: Returns:
Cutout Image tensor Cutout Image tensor
""" """
if alpha_override > 0:
alpha = alpha_override
# get img shape # get img shape
shape = tf.shape(tf_img) shape = tf.shape(tf_img)
......
...@@ -42,7 +42,6 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm ...@@ -42,7 +42,6 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm
with tf.name_scope('ImageSummaries'), tf.device('/cpu:0'): with tf.name_scope('ImageSummaries'), tf.device('/cpu:0'):
tf.summary.image("reals_pre-augment", pre) tf.summary.image("reals_pre-augment", pre)
tf.summary.image("reals_post-augment", post) tf.summary.image("reals_post-augment", post)
with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail. with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.
s = tf.shape(x) s = tf.shape(x)
y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2]) y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])
...@@ -303,10 +302,21 @@ def training_loop( ...@@ -303,10 +302,21 @@ def training_loop(
G.setup_weight_histograms(); D.setup_weight_histograms() G.setup_weight_histograms(); D.setup_weight_histograms()
metrics = metric_base.MetricGroup(metric_arg_list) metrics = metric_base.MetricGroup(metric_arg_list)
print('Training for %d kimg...\n' % total_kimg)
if spatial_augmentations: if spatial_augmentations:
print('Augmenting fakes and reals\n') print('Augmenting fakes and reals')
alpha_override = float(os.environ.get('SPATIAL_AUGS_ALPHA', '0'))
if alpha_override == 0.0:
print('Augmentation alpha at default setting of 0.1 - change by setting SPATIAL_AUGS_ALPHA environment variable')
else:
if alpha_override >= 1:
alpha_override = 0.999
print(f'Augmentation alpha set to {alpha_override}')
if save_image_summaries:
print('Saving image summaries to tensorboard')
print('Training for %d kimg...\n' % total_kimg)
dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
maintenance_time = dnnlib.RunContext.get().get_last_update_interval() maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
cur_nimg = int(resume_kimg * 1000) cur_nimg = int(resume_kimg * 1000)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment