diff --git a/training/training_loop.py b/training/training_loop.py
index daf1130004e4a20a4a89c57878002ecc68bb8ba1..51a8f787d9f8d32bae25f5c9644a72a23ea092b3 100755
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -19,6 +19,46 @@ from metrics import metric_base
 #----------------------------------------------------------------------------
 # Just-in-time processing of training images before feeding them to the networks.
 
+def _random_choice(inputs, n_samples=1):
+    """
+    With replacement.
+    Params:
+      inputs (Tensor): Shape [n_states, n_features]
+      n_samples (int): The number of random samples to take.
+    Returns:
+      sampled_inputs (Tensor): Shape [n_samples, n_features]
+    """
+    # (1, n_states) since multinomial requires 2D logits.
+    uniform_log_prob = tf.expand_dims(tf.zeros(tf.shape(inputs)[0]), 0)
+
+    ind = tf.multinomial(uniform_log_prob, n_samples)
+    ind = tf.squeeze(ind, 0, name="random_choice_ind")  # (n_samples,)
+
+    return tf.gather(inputs, ind, name="random_choice")
+
+def apply_random_aug(x):
+    with tf.name_scope('SpatialAugmentations'):
+        choice = np.random.randint(6)
+        print(choice)
+        if choice == 0:
+            print('zooming in')
+            x = misc.zoom_in(x)
+        elif choice == 1:
+            print('zooming out')
+            x = misc.zoom_out(x)
+        elif choice == 2:
+            print('x trans')
+            x = misc.X_translate(x)
+        elif choice == 3:
+            print('y trans')
+            x = misc.Y_translate(x)
+        elif choice == 4:
+            print('xy trans')
+            x = misc.XY_translate(x)
+        elif choice == 5:
+            print('cutout')
+            x = misc.random_cutout(x)
+    return x
 
 def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augmentations, drange_data, drange_net):
     with tf.name_scope('DynamicRange'):
@@ -32,22 +72,15 @@
             x = tf.where(tf.random_uniform([tf.shape(x)[0]]) < 0.5, x, tf.reverse(x, [2]))
     if spatial_augmentations:
         with tf.name_scope('SpatialAugmentations'):
-            choices = ['zoom in']
-            choice = choices[tf.random_uniform(shape=[], minval=0, maxval=len(choices), dtype=tf.int32, seed=None, name=None).eval()]
-            print(choice)
-            if choice == 'zoom in':
-                x = misc.zoom_in(x)
-            elif choice == 'zoom out':
-                x = misc.zoom_out(x)
-            elif choice == 'x_trans':
-                x = misc.X_translate(x)
-            elif choice == 'y_trans':
-                x = misc.Y_translate(x)
-            elif choice == 'xy_trans':
-                x = misc.XY_translate(x)
-            elif choice == 'cutout':
-                x = misc.random_cutout(x)
-
+            imgs = tf.data.Dataset.from_tensor_slices(x)
+            grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(imgs)
+            misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals_test.jpg'), drange=[-1,1],
+                                 grid_size=grid_size)
+
+            imgs = imgs.map(apply_random_aug)
+            grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(imgs)
+            misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals_augmented_test.jpg'), drange=[-1,1],
+                                 grid_size=grid_size)
     with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.
         s = tf.shape(x)
         y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])
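
Note on the new `_random_choice` helper: it is not called in either hunk shown here. It samples rows uniformly at random, with replacement, by drawing indices from `tf.multinomial` over all-zero (uniform) logits and gathering them. A minimal TF 1.x usage sketch, with illustrative tensor names (`states`, `picked`) that are not part of the patch:

import tensorflow as tf

# 6 "states" with 2 features each; _random_choice is the helper added above.
states = tf.reshape(tf.range(12, dtype=tf.float32), [6, 2])
picked = _random_choice(states, n_samples=4)  # shape [4, 2]; rows may repeat

with tf.Session() as sess:
    print(sess.run(picked))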
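
A caveat on `apply_random_aug` as used with `Dataset.map`: `choice = np.random.randint(6)` and the `print` calls run in Python when the map function is traced, so a single augmentation branch is baked into the graph and applied identically to every element rather than re-drawn per image. Note also that `x` is not reassigned from the mapped dataset in the second hunk, so the batch fed onward is still the unaugmented one; the two saved grids serve as a visual check. If per-image randomness inside the graph is wanted, a sketch along these lines could work; it reuses the module's existing `tf` and `misc` imports and assumes the `misc.*` augmentation helpers accept the unbatched image tensors produced by `from_tensor_slices` (their signatures are not shown in this diff):

def apply_random_aug_in_graph(x):
    # Sketch only: draw the augmentation choice as a TF op so it is re-sampled
    # per element at run time, not fixed once at trace time.
    with tf.name_scope('SpatialAugmentations'):
        choice = tf.random_uniform([], minval=0, maxval=6, dtype=tf.int32)
        x = tf.case([
            (tf.equal(choice, 0), lambda: misc.zoom_in(x)),
            (tf.equal(choice, 1), lambda: misc.zoom_out(x)),
            (tf.equal(choice, 2), lambda: misc.X_translate(x)),
            (tf.equal(choice, 3), lambda: misc.Y_translate(x)),
            (tf.equal(choice, 4), lambda: misc.XY_translate(x)),
        ], default=lambda: misc.random_cutout(x))
    return x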