diff --git a/training/training_loop.py b/training/training_loop.py index 9a652093e2f783e4c1e80c605c98644597e76534..a31d9edabdaf94077e45cc343286cc25cae276bd 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -38,32 +38,13 @@ def _random_choice(inputs, n_samples=1): def apply_random_aug(x): with tf.name_scope('SpatialAugmentations'): - choice = tf.random_uniform([], 0, 2, tf.int32) + choice = tf.random_uniform([], 0, 6, tf.int32) x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(0))), lambda: misc.zoom_in(x), lambda: tf.identity(x)) x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(1))), lambda: misc.zoom_out(x), lambda: tf.identity(x)) x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(2))), lambda: misc.X_translate(x), lambda: tf.identity(x)) x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(3))), lambda: misc.Y_translate(x), lambda: tf.identity(x)) x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(4))), lambda: misc.XY_translate(x), lambda: tf.identity(x)) x = tf.cond(tf.reduce_all(tf.equal(choice, tf.constant(5))), lambda: misc.random_cutout(x), lambda: tf.identity(x)) - - # if : - # print('zooming in') - # x = misc.zoom_in(x) - # elif tf.reduce_all(tf.equal(choice, tf.constant(0))): - # print('zooming out') - # x = misc.zoom_out(x) - # elif choice == 2: - # print('x trans') - # x = misc.X_translate(x) - # elif choice == 3: - # print('y trans') - # x = misc.Y_translate(x) - # elif choice == 4: - # print('xy trans') - # x = misc.XY_translate(x) - # elif choice == 5: - # print('cutout') - # x = misc.random_cutout(x) return x def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augmentations, drange_data, drange_net, dshape):