diff --git a/training/training_loop.py b/training/training_loop.py
index c287fa7684ad0597b378c07c0b7eddc615bac7c6..7c4454fc6a9e818462a53f8873194c1a31c87230 100755
--- a/training/training_loop.py
+++ b/training/training_loop.py
@@ -73,9 +73,17 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm
             x = tf.where(tf.random_uniform([tf.shape(x)[0]]) < 0.5, x, tf.reverse(x, [2]))
     if spatial_augmentations:
         with tf.name_scope('SpatialAugmentations'):
-            s = tf.shape(x)
-            x.set_shape(s)
+            print('PRE TRANSPOSE :')
+            print(x.get_shape())
+            x = tf.transpose(x, [0, 2, 3, 1])
+            print('POST TRANSPOSE :')
+            print(x.get_shape())
             x = tf.map_fn(apply_random_aug, x)
+            print('POST AUGMENT :')
+            print(x.get_shape())
+            x = tf.transpose(x, [0, 3, 1, 2])
+            print('POST AUGMENT/TRANSPOSE :')
+            print(x.get_shape())
     with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.
         s = tf.shape(x)
         y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])