diff --git a/training/training_loop.py b/training/training_loop.py index c287fa7684ad0597b378c07c0b7eddc615bac7c6..7c4454fc6a9e818462a53f8873194c1a31c87230 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -73,9 +73,17 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm x = tf.where(tf.random_uniform([tf.shape(x)[0]]) < 0.5, x, tf.reverse(x, [2])) if spatial_augmentations: with tf.name_scope('SpatialAugmentations'): - s = tf.shape(x) - x.set_shape(s) + print('PRE TRANSPOSE :') + print(x.get_shape()) + x = tf.transpose(x, [0, 2, 3, 1]) + print('POST TRANSPOSE :') + print(x.get_shape()) x = tf.map_fn(apply_random_aug, x) + print('POST AUGMENT :') + print(x.get_shape()) + x = tf.transpose(x, [0, 3, 1, 2]) + print('POST AUGMENT/TRANSPOSE :') + print(x.get_shape()) with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail. s = tf.shape(x) y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])