Skip to content
Snippets Groups Projects
Commit 810359e7 authored by Peter Baylies's avatar Peter Baylies
Browse files

Merge remote-tracking branch 'shawwn/swarm' into swarm

parents a04797ac 10ac7ac5
Branches
No related tags found
No related merge requests found
...@@ -281,11 +281,12 @@ def run_wrapper(submit_config: SubmitConfig) -> None: ...@@ -281,11 +281,12 @@ def run_wrapper(submit_config: SubmitConfig) -> None:
else: else:
run_func_obj(**submit_config.run_func_kwargs) run_func_obj(**submit_config.run_func_kwargs)
if 'TPU_NAME' not in os.environ:
thunk()
else:
kws = submit_config.run_func_kwargs kws = submit_config.run_func_kwargs
tf_config = kws['tf_config'] if 'tf_config' in kws else {} tf_config = kws['tf_config'] if 'tf_config' in kws else {}
if 'TPU_NAME' not in os.environ or 'NO_SWARM' in os.environ:
tflib.init_tf(tf_config)
thunk()
else:
threads = [] threads = []
tflex.trainers = [] tflex.trainers = []
tpu_core_count = 8 tpu_core_count = 8
......
...@@ -16,6 +16,9 @@ from dnnlib.tflib.ops.fused_bias_act import fused_bias_act ...@@ -16,6 +16,9 @@ from dnnlib.tflib.ops.fused_bias_act import fused_bias_act
# NOTE: Do not import any application-specific modules here! # NOTE: Do not import any application-specific modules here!
# Specify all network parameters as kwargs. # Specify all network parameters as kwargs.
def _i(x):
    # NCHW -> NHWC: move the channel axis to the last position so that
    # NHWC-only convolution paths can consume the tensor.
    return tf.transpose(x, perm=[0, 2, 3, 1])
def _o(x):
    # NHWC -> NCHW: inverse permutation of _i, restoring channel-first layout.
    return tf.transpose(x, perm=[0, 3, 1, 2])
#---------------------------------------------------------------------------- #----------------------------------------------------------------------------
# Get/create weight tensor for a convolution or fully-connected layer. # Get/create weight tensor for a convolution or fully-connected layer.
...@@ -53,11 +56,11 @@ def conv2d_layer(x, fmaps, kernel, up=False, down=False, resample_kernel=None, g ...@@ -53,11 +56,11 @@ def conv2d_layer(x, fmaps, kernel, up=False, down=False, resample_kernel=None, g
assert kernel >= 1 and kernel % 2 == 1 assert kernel >= 1 and kernel % 2 == 1
w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var=weight_var) w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var=weight_var)
if up: if up:
x = upsample_conv_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel) x = _o(upsample_conv_2d(_i(x), tf.cast(w, x.dtype), data_format='NHWC', k=resample_kernel))
elif down: elif down:
x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel) x = _o(conv_downsample_2d(_i(x), tf.cast(w, x.dtype), data_format='NHWC', k=resample_kernel))
else: else:
x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1,1,1,1], padding='SAME') x = _o(tf.nn.conv2d(_i(x), tf.cast(w, x.dtype), data_format='NHWC', strides=[1,1,1,1], padding='SAME'))
return x return x
#---------------------------------------------------------------------------- #----------------------------------------------------------------------------
...@@ -113,11 +116,11 @@ def modulated_conv2d_layer(x, y, fmaps, kernel, up=False, down=False, demodulate ...@@ -113,11 +116,11 @@ def modulated_conv2d_layer(x, y, fmaps, kernel, up=False, down=False, demodulate
# Convolution with optional up/downsampling. # Convolution with optional up/downsampling.
if up: if up:
x = upsample_conv_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel) x = _o(upsample_conv_2d(_i(x), tf.cast(w, x.dtype), data_format='NHWC', k=resample_kernel))
elif down: elif down:
x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel) x = _o(conv_downsample_2d(_i(x), tf.cast(w, x.dtype), data_format='NHWC', k=resample_kernel))
else: else:
x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1,1,1,1], padding='SAME') x = _o(tf.nn.conv2d(_i(x), tf.cast(w, x.dtype), data_format='NHWC', strides=[1,1,1,1], padding='SAME'))
# Reshape/scale output. # Reshape/scale output.
if fused_modconv: if fused_modconv:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment