diff --git a/training/training_loop.py b/training/training_loop.py
index 2a9cde2addf8b3d92eb718bbc9f52e312d27a1dc..12f2691623a7d6e393a93875d885b056a977b75b 100755
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -47,7 +47,7 @@ def apply_random_aug(x):
         # x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(5))), lambda: misc.random_cutout(x), lambda: tf.identity(x))
         return x
 
-def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augmentations, drange_data, drange_net, dshape):
+def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augmentations, drange_data, drange_net):
     with tf.name_scope('DynamicRange'):
         x = tf.cast(x, tf.float32)
         x = misc.adjust_dynamic_range(x, drange_data, drange_net)
@@ -59,10 +59,16 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm
             x = tf.where(tf.random_uniform([tf.shape(x)[0]]) < 0.5, x, tf.reverse(x, [2]))
     if spatial_augmentations:
         with tf.name_scope('SpatialAugmentations'):
-            x.set_shape(dshape)
-            x = tf.transpose(x, [0, 2, 3, 1])
-            x = tf.map_fn(apply_random_aug, x)
-            x = tf.transpose(x, [0, 3, 1, 2])
+            pre = tf.transpose(x, [0, 2, 3, 1])
+            post = tf.map_fn(apply_random_aug, pre)
+            x = tf.transpose(post, [0, 3, 1, 2])
+        with tf.name_scope('ImageSummaries'), tf.device('/cpu:0'):
+            tf.summary.image("pre-augment", pre)
+            tf.summary.image("post-augment", post)
+
+
+
+
     with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.
         s = tf.shape(x)
         y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])
@@ -271,7 +277,7 @@ def training_loop(
                 reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                 labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                 reals_write, labels_write = training_set.get_minibatch_tf()
-                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, spatial_augmentations, training_set.dynamic_range, drange_net, reals_var.shape)
+                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, spatial_augmentations, training_set.dynamic_range, drange_net)
                 reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                 labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                 data_fetch_ops += [tf.assign(reals_var, reals_write)]