diff --git a/training/training_loop.py b/training/training_loop.py
index daf1130004e4a20a4a89c57878002ecc68bb8ba1..51a8f787d9f8d32bae25f5c9644a72a23ea092b3 100755
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -19,6 +19,46 @@ from metrics import metric_base
 
 #----------------------------------------------------------------------------
 # Just-in-time processing of training images before feeding them to the networks.
+def _random_choice(inputs, n_samples=1):
+    """
+    With replacement.
+    Params:
+      inputs (Tensor): Shape [n_states, n_features]
+      n_samples (int): The number of random samples to take.
+    Returns:
+      sampled_inputs (Tensor): Shape [n_samples, n_features]
+    """
+    # (1, n_states) since multinomial requires 2D logits.
+    uniform_log_prob = tf.expand_dims(tf.zeros(tf.shape(inputs)[0]), 0)
+
+    ind = tf.multinomial(uniform_log_prob, n_samples)
+    ind = tf.squeeze(ind, 0, name="random_choice_ind")  # (n_samples,)
+
+    return tf.gather(inputs, ind, name="random_choice")
+
+def apply_random_aug(x):
+    with tf.name_scope('SpatialAugmentations'):
+        choice = np.random.randint(6)
+        print(choice)
+        if choice == 0:
+            print('zooming in')
+            x = misc.zoom_in(x)
+        elif choice == 1:
+            print('zooming out')
+            x = misc.zoom_out(x)
+        elif choice == 2:
+            print('x trans')
+            x = misc.X_translate(x)
+        elif choice == 3:
+            print('y trans')
+            x = misc.Y_translate(x)
+        elif choice == 4:
+            print('xy trans')
+            x = misc.XY_translate(x)
+        elif choice == 5:
+            print('cutout')
+            x = misc.random_cutout(x)
+        return x
 
 def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augmentations, drange_data, drange_net):
     with tf.name_scope('DynamicRange'):
@@ -32,22 +72,15 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm
             x = tf.where(tf.random_uniform([tf.shape(x)[0]]) < 0.5, x, tf.reverse(x, [2]))
     if spatial_augmentations:
         with tf.name_scope('SpatialAugmentations'):
-            choices = ['zoom in']
-            choice = choices[tf.random_uniform(shape=[], minval=0, maxval=len(choices), dtype=tf.int32, seed=None, name=None).eval()]
-            print(choice)
-            if choice == 'zoom in':
-                x = misc.zoom_in(x)
-            elif choice == 'zoom out':
-                x = misc.zoom_out(x)
-            elif choice == 'x_trans':
-                x = misc.X_translate(x)
-            elif choice == 'y_trans':
-                x = misc.Y_translate(x)
-            elif choice == 'xy_trans':
-                x = misc.XY_translate(x)
-            elif choice == 'cutout':
-                x = misc.random_cutout(x)
-
+            imgs = tf.data.Dataset.from_tensor_slices(x)
+            grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(imgs)
+            misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals_test.jpg'), drange=[-1,1],
+                                 grid_size=grid_size)
+
+            imgs = imgs.map(apply_random_aug)
+            grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(imgs)
+            misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals_augmented_test.jpg'), drange=[-1,1],
+                                 grid_size=grid_size)
     with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.
         s = tf.shape(x)
         y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])