diff --git a/training/training_loop.py b/training/training_loop.py index 2a9cde2addf8b3d92eb718bbc9f52e312d27a1dc..12f2691623a7d6e393a93875d885b056a977b75b 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -47,7 +47,7 @@ def apply_random_aug(x): # x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(5))), lambda: misc.random_cutout(x), lambda: tf.identity(x)) return x -def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augmentations, drange_data, drange_net, dshape): +def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augmentations, drange_data, drange_net): with tf.name_scope('DynamicRange'): x = tf.cast(x, tf.float32) x = misc.adjust_dynamic_range(x, drange_data, drange_net) @@ -59,10 +59,16 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm x = tf.where(tf.random_uniform([tf.shape(x)[0]]) < 0.5, x, tf.reverse(x, [2])) if spatial_augmentations: with tf.name_scope('SpatialAugmentations'): - x.set_shape(dshape) - x = tf.transpose(x, [0, 2, 3, 1]) - x = tf.map_fn(apply_random_aug, x) - x = tf.transpose(x, [0, 3, 1, 2]) + pre = tf.transpose(x, [0, 2, 3, 1]) + post = tf.map_fn(apply_random_aug, pre) + x = tf.transpose(post, [0, 3, 1, 2]) + with tf.name_scope('ImageSummaries'), tf.device('/cpu:0'): + tf.summary.image("pre-augment", pre) + tf.summary.image("post-augment", post) + + + + with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail. s = tf.shape(x) y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2]) @@ -271,7 +277,7 @@ def training_loop( reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size])) reals_write, labels_write = training_set.get_minibatch_tf() - reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, spatial_augmentations, training_set.dynamic_range, drange_net, reals_var.shape) + reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, spatial_augmentations, training_set.dynamic_range, drange_net) reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)]