diff --git a/runway_model.py b/runway_model.py index 4bf35e067dff03ff5ea239db494f1663cb6be01e..a54e7fe2d356e68dc6848699fab3a08615a3c88f 100644 --- a/runway_model.py +++ b/runway_model.py @@ -20,18 +20,23 @@ def setup(opts): generate_inputs = { 'z': runway.vector(512, sampling_std=0.5), - 'truncation': runway.number(min=0, max=1, default=0.8, step=0.01) + 'label': runway.number(min=0, max=100000, default=0, step=1), # generate random labels + 'scale': runway.number(min=-2, max=2, default=0, step=0.05), # magnitude of labels - 0 = no labels + 'truncation': runway.number(min=-1.5, max=1.5, default=1, step=0.05) } @runway.command('generate', inputs=generate_inputs, outputs={'image': runway.image}) def convert(model, inputs): z = inputs['z'] + label = int(inputs['label']) + scale = inputs['scale'] truncation = inputs['truncation'] latents = z.reshape((1, 512)) - images = model.run(latents, None, truncation_psi=truncation, randomize_noise=False, output_transform=fmt) + labels = scale * np.random.RandomState(label).randn(167) + labels = labels.reshape((1,167)).astype(np.float32) + images = model.run(latents, labels, truncation_psi=truncation, randomize_noise=False, output_transform=fmt) output = np.clip(images[0], 0, 255).astype(np.uint8) return {'image': output} - if __name__ == '__main__': runway.run()