Skip to content
Snippets Groups Projects
Unverified Commit 90d54824 authored by Peter Baylies's avatar Peter Baylies Committed by GitHub
Browse files

Update runway_model.py

parent 0831b541
No related branches found
No related tags found
No related merge requests found
...@@ -20,18 +20,23 @@ def setup(opts): ...@@ -20,18 +20,23 @@ def setup(opts):
generate_inputs = { generate_inputs = {
'z': runway.vector(512, sampling_std=0.5), 'z': runway.vector(512, sampling_std=0.5),
'truncation': runway.number(min=0, max=1, default=0.8, step=0.01) 'label': runway.number(min=0, max=100000, default=0, step=1), # generate random labels
'scale': runway.number(min=-2, max=2, default=0, step=0.05), # magnitude of labels - 0 = no labels
'truncation': runway.number(min=-1.5, max=1.5, default=1, step=0.05)
} }
@runway.command('generate', inputs=generate_inputs, outputs={'image': runway.image}) @runway.command('generate', inputs=generate_inputs, outputs={'image': runway.image})
def convert(model, inputs): def convert(model, inputs):
z = inputs['z'] z = inputs['z']
label = int(inputs['label'])
scale = inputs['scale']
truncation = inputs['truncation'] truncation = inputs['truncation']
latents = z.reshape((1, 512)) latents = z.reshape((1, 512))
images = model.run(latents, None, truncation_psi=truncation, randomize_noise=False, output_transform=fmt) labels = scale * np.random.RandomState(label).randn(167)
labels = labels.reshape((1,167)).astype(np.float32)
images = model.run(latents, labels, truncation_psi=truncation, randomize_noise=False, output_transform=fmt)
output = np.clip(images[0], 0, 255).astype(np.uint8) output = np.clip(images[0], 0, 255).astype(np.uint8)
return {'image': output} return {'image': output}
if __name__ == '__main__': if __name__ == '__main__':
runway.run() runway.run()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment