Skip to content
Snippets Groups Projects
Commit 938ac88f authored by Jeremiah's avatar Jeremiah
Browse files

fixed encoder

parent 358a0468
Branches
No related tags found
No related merge requests found
......@@ -389,6 +389,7 @@ class Network:
minibatch_size: int = None,
num_gpus: int = 1,
assume_frozen: bool = False,
custom_inputs: Any = None,
**dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
"""Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
......@@ -404,6 +405,7 @@ class Network:
minibatch_size: Maximum minibatch size to use, None = disable batching.
num_gpus: Number of GPUs to use.
assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls.
custom_inputs: Allow to use another tensor as input instead of default placeholders.
dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
"""
assert len(in_arrays) == self.num_inputs
......@@ -428,6 +430,11 @@ class Network:
# Build graph.
if key not in self._run_cache:
with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
if custom_inputs is not None:
with tf.device("/gpu:0"):
in_expr = [input_builder(name) for input_builder, name in zip(custom_inputs, self.input_names)]
in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
else:
with tf.device("/cpu:0"):
in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
......
......@@ -10,6 +10,7 @@ import dnnlib.tflib as tflib
import pretrained_networks
from encoder.generator_model import Generator
from encoder.perceptual_model import PerceptualModel
from encoder.perceptual_model import load_images
from keras.models import load_model
from keras.applications.resnet50 import preprocess_input
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment