Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
stylegan2
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
dsenkin
stylegan2
Commits
cd6e0464
Commit
cd6e0464
authored
5 years ago
by
sdtblck
Browse files
Options
Downloads
Patches
Plain Diff
finishing touches to zhao et al augs
parent
1c81c269
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
training/loss.py
+7
-4
7 additions, 4 deletions
training/loss.py
training/misc.py
+20
-1
20 additions, 1 deletion
training/misc.py
training/training_loop.py
+14
-4
14 additions, 4 deletions
training/training_loop.py
with
41 additions
and
9 deletions
training/loss.py
+
7
−
4
View file @
cd6e0464
...
@@ -208,10 +208,13 @@ def G_logistic_ns_pathreg(G, D, opt, training_set, minibatch_size, pl_minibatch_
...
@@ -208,10 +208,13 @@ def G_logistic_ns_pathreg(G, D, opt, training_set, minibatch_size, pl_minibatch_
pl_latents
=
tf
.
random_normal
([
pl_minibatch
]
+
G
.
input_shapes
[
0
][
1
:])
pl_latents
=
tf
.
random_normal
([
pl_minibatch
]
+
G
.
input_shapes
[
0
][
1
:])
pl_labels
=
training_set
.
get_random_labels_tf
(
pl_minibatch
)
pl_labels
=
training_set
.
get_random_labels_tf
(
pl_minibatch
)
fake_images_out
,
fake_dlatents_out
=
G
.
get_output_for
(
pl_latents
,
pl_labels
,
is_training
=
True
,
return_dlatents
=
True
)
fake_images_out
,
fake_dlatents_out
=
G
.
get_output_for
(
pl_latents
,
pl_labels
,
is_training
=
True
,
return_dlatents
=
True
)
if
augment
:
# TODO: applying augmentations here fails with the following error:
fake_images_out_pre_augment
=
tf
.
transpose
(
fake_images_out
,
[
0
,
2
,
3
,
1
])
# TypeError: Second-order gradient for while loops not supported.
fake_images_out_post_augment
=
tf
.
map_fn
(
misc
.
apply_random_aug
,
fake_images_out_pre_augment
)
# setting pl_minibatch_shrink to 1 would work - but will have a higher memory usage
fake_images_out
=
tf
.
transpose
(
fake_images_out_post_augment
,
[
0
,
3
,
1
,
2
])
# if augment:
# fake_images_out_pre_augment = tf.transpose(fake_images_out, [0, 2, 3, 1])
# fake_images_out_post_augment = tf.map_fn(misc.apply_random_aug, fake_images_out_pre_augment)
# fake_images_out = tf.transpose(fake_images_out_post_augment, [0, 3, 1, 2])
# Compute |J*y|.
# Compute |J*y|.
pl_noise
=
tf
.
random_normal
(
tf
.
shape
(
fake_images_out
))
/
np
.
sqrt
(
np
.
prod
(
G
.
output_shape
[
2
:]))
pl_noise
=
tf
.
random_normal
(
tf
.
shape
(
fake_images_out
))
/
np
.
sqrt
(
np
.
prod
(
G
.
output_shape
[
2
:]))
...
...
This diff is collapsed.
Click to expand it.
training/misc.py
+
20
−
1
View file @
cd6e0464
...
@@ -95,6 +95,14 @@ def convert_to_pil_image(image, drange=[0, 1]):
...
@@ -95,6 +95,14 @@ def convert_to_pil_image(image, drange=[0, 1]):
def
save_image_grid
(
images
,
filename
,
drange
=
[
0
,
1
],
grid_size
=
None
):
def
save_image_grid
(
images
,
filename
,
drange
=
[
0
,
1
],
grid_size
=
None
):
convert_to_pil_image
(
create_image_grid
(
images
,
grid_size
),
drange
).
save
(
filename
)
convert_to_pil_image
(
create_image_grid
(
images
,
grid_size
),
drange
).
save
(
filename
)
# ----------------------------------------------------------------------------
# Image Augmentations.
alpha_override
=
float
(
os
.
environ
.
get
(
'
SPATIAL_AUGS_ALPHA
'
,
'
0
'
))
if
alpha_override
>=
1
:
alpha_override
=
0.999
elif
alpha_override
==
0.0
:
alpha_override
=
0
def
apply_mirror_augment
(
minibatch
):
def
apply_mirror_augment
(
minibatch
):
mask
=
np
.
random
.
rand
(
minibatch
.
shape
[
0
])
<
0.5
mask
=
np
.
random
.
rand
(
minibatch
.
shape
[
0
])
<
0.5
...
@@ -142,6 +150,8 @@ def zoom_in(tf_img, alpha=0.1, target_image_shape=None, seed=None):
...
@@ -142,6 +150,8 @@ def zoom_in(tf_img, alpha=0.1, target_image_shape=None, seed=None):
Returns:
Returns:
Image tensor with shape `target_image_shape`.
Image tensor with shape `target_image_shape`.
"""
"""
if
alpha_override
>
0
:
alpha
=
alpha_override
n
=
tf
.
random_uniform
(
shape
=
[],
minval
=
1
-
alpha
,
maxval
=
1
,
dtype
=
tf
.
float32
,
seed
=
seed
,
name
=
None
)
n
=
tf
.
random_uniform
(
shape
=
[],
minval
=
1
-
alpha
,
maxval
=
1
,
dtype
=
tf
.
float32
,
seed
=
seed
,
name
=
None
)
shape
=
tf
.
shape
(
tf_img
)
shape
=
tf
.
shape
(
tf_img
)
h
=
shape
[
0
]
h
=
shape
[
0
]
...
@@ -181,7 +191,8 @@ def zoom_out(tf_img, alpha=0.1, target_image_shape=None, seed=None):
...
@@ -181,7 +191,8 @@ def zoom_out(tf_img, alpha=0.1, target_image_shape=None, seed=None):
Returns:
Returns:
Image tensor with shape `target_image_shape`.
Image tensor with shape `target_image_shape`.
"""
"""
if
alpha_override
>
0
:
alpha
=
alpha_override
# Set params
# Set params
n
=
tf
.
random_uniform
(
shape
=
[],
minval
=
0
,
maxval
=
alpha
,
dtype
=
tf
.
float32
,
seed
=
seed
,
name
=
None
)
n
=
tf
.
random_uniform
(
shape
=
[],
minval
=
0
,
maxval
=
alpha
,
dtype
=
tf
.
float32
,
seed
=
seed
,
name
=
None
)
...
@@ -230,6 +241,8 @@ def X_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
...
@@ -230,6 +241,8 @@ def X_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
Returns:
Returns:
Image tensor with shape `target_image_shape`.
Image tensor with shape `target_image_shape`.
"""
"""
if
alpha_override
>
0
:
alpha
=
alpha_override
n
=
tf
.
random_uniform
(
shape
=
[],
minval
=
0
,
maxval
=
alpha
,
dtype
=
tf
.
float32
,
seed
=
seed
,
name
=
None
)
n
=
tf
.
random_uniform
(
shape
=
[],
minval
=
0
,
maxval
=
alpha
,
dtype
=
tf
.
float32
,
seed
=
seed
,
name
=
None
)
shape
=
tf
.
shape
(
tf_img
)
shape
=
tf
.
shape
(
tf_img
)
...
@@ -261,6 +274,8 @@ def XY_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
...
@@ -261,6 +274,8 @@ def XY_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
Returns:
Returns:
Image tensor with shape `target_image_shape`.
Image tensor with shape `target_image_shape`.
"""
"""
if
alpha_override
>
0
:
alpha
=
alpha_override
n
=
tf
.
random_uniform
(
shape
=
[],
minval
=
0
,
maxval
=
alpha
,
dtype
=
tf
.
float32
,
seed
=
seed
,
name
=
None
)
n
=
tf
.
random_uniform
(
shape
=
[],
minval
=
0
,
maxval
=
alpha
,
dtype
=
tf
.
float32
,
seed
=
seed
,
name
=
None
)
shape
=
tf
.
shape
(
tf_img
)
shape
=
tf
.
shape
(
tf_img
)
h
=
shape
[
0
]
h
=
shape
[
0
]
...
@@ -293,6 +308,8 @@ def Y_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
...
@@ -293,6 +308,8 @@ def Y_translate(tf_img, alpha=0.1, target_image_shape=None, seed=None):
Returns:
Returns:
Image tensor with shape `target_image_shape`.
Image tensor with shape `target_image_shape`.
"""
"""
if
alpha_override
>
0
:
alpha
=
alpha_override
n
=
tf
.
random_uniform
(
shape
=
[],
minval
=
0
,
maxval
=
alpha
,
dtype
=
tf
.
float32
,
seed
=
seed
,
name
=
None
)
n
=
tf
.
random_uniform
(
shape
=
[],
minval
=
0
,
maxval
=
alpha
,
dtype
=
tf
.
float32
,
seed
=
seed
,
name
=
None
)
shape
=
tf
.
shape
(
tf_img
)
shape
=
tf
.
shape
(
tf_img
)
...
@@ -350,6 +367,8 @@ def random_cutout(tf_img, alpha=0.1, seed=None):
...
@@ -350,6 +367,8 @@ def random_cutout(tf_img, alpha=0.1, seed=None):
Returns:
Returns:
Cutout Image tensor
Cutout Image tensor
"""
"""
if
alpha_override
>
0
:
alpha
=
alpha_override
# get img shape
# get img shape
shape
=
tf
.
shape
(
tf_img
)
shape
=
tf
.
shape
(
tf_img
)
...
...
This diff is collapsed.
Click to expand it.
training/training_loop.py
+
14
−
4
View file @
cd6e0464
...
@@ -42,7 +42,6 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm
...
@@ -42,7 +42,6 @@ def process_reals(x, labels, lod, mirror_augment, mirror_augment_v, spatial_augm
with
tf
.
name_scope
(
'
ImageSummaries
'
),
tf
.
device
(
'
/cpu:0
'
):
with
tf
.
name_scope
(
'
ImageSummaries
'
),
tf
.
device
(
'
/cpu:0
'
):
tf
.
summary
.
image
(
"
reals_pre-augment
"
,
pre
)
tf
.
summary
.
image
(
"
reals_pre-augment
"
,
pre
)
tf
.
summary
.
image
(
"
reals_post-augment
"
,
post
)
tf
.
summary
.
image
(
"
reals_post-augment
"
,
post
)
with
tf
.
name_scope
(
'
FadeLOD
'
):
# Smooth crossfade between consecutive levels-of-detail.
with
tf
.
name_scope
(
'
FadeLOD
'
):
# Smooth crossfade between consecutive levels-of-detail.
s
=
tf
.
shape
(
x
)
s
=
tf
.
shape
(
x
)
y
=
tf
.
reshape
(
x
,
[
-
1
,
s
[
1
],
s
[
2
]
//
2
,
2
,
s
[
3
]
//
2
,
2
])
y
=
tf
.
reshape
(
x
,
[
-
1
,
s
[
1
],
s
[
2
]
//
2
,
2
,
s
[
3
]
//
2
,
2
])
...
@@ -303,10 +302,21 @@ def training_loop(
...
@@ -303,10 +302,21 @@ def training_loop(
G
.
setup_weight_histograms
();
D
.
setup_weight_histograms
()
G
.
setup_weight_histograms
();
D
.
setup_weight_histograms
()
metrics
=
metric_base
.
MetricGroup
(
metric_arg_list
)
metrics
=
metric_base
.
MetricGroup
(
metric_arg_list
)
print
(
'
Training for %d kimg...
\n
'
%
total_kimg
)
if
spatial_augmentations
:
if
spatial_augmentations
:
print
(
'
Augmenting fakes and reals
\n
'
)
print
(
'
Augmenting fakes and reals
'
)
alpha_override
=
float
(
os
.
environ
.
get
(
'
SPATIAL_AUGS_ALPHA
'
,
'
0
'
))
if
alpha_override
==
0.0
:
print
(
'
Augmentation alpha at default setting of 0.1 - change by setting SPATIAL_AUGS_ALPHA environment variable
'
)
else
:
if
alpha_override
>=
1
:
alpha_override
=
0.999
print
(
f
'
Augmentation alpha set to
{
alpha_override
}
'
)
if
save_image_summaries
:
print
(
'
Saving image summaries to tensorboard
'
)
print
(
'
Training for %d kimg...
\n
'
%
total_kimg
)
dnnlib
.
RunContext
.
get
().
update
(
''
,
cur_epoch
=
resume_kimg
,
max_epoch
=
total_kimg
)
dnnlib
.
RunContext
.
get
().
update
(
''
,
cur_epoch
=
resume_kimg
,
max_epoch
=
total_kimg
)
maintenance_time
=
dnnlib
.
RunContext
.
get
().
get_last_update_interval
()
maintenance_time
=
dnnlib
.
RunContext
.
get
().
get_last_update_interval
()
cur_nimg
=
int
(
resume_kimg
*
1000
)
cur_nimg
=
int
(
resume_kimg
*
1000
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment