| | from keras.layers import concatenate |
| | from keras.layers.core import Lambda |
| | from keras.models import Model |
| |
|
| | import tensorflow as tf |
| |
|
def make_parallel(model, gpu_count):
    """Replicate `model` across `gpu_count` GPUs with data parallelism.

    Each input batch is split along axis 0 into `gpu_count` slices; one
    replica ("tower") of `model` runs on each GPU, and the per-tower
    outputs are concatenated back together on the CPU, so the merged
    model accepts the same inputs as the original.

    Args:
        model: Keras `Model` to replicate.
        gpu_count: number of GPUs (towers) to spread each batch over.

    Returns:
        A new Keras `Model` with the same inputs as `model` whose outputs
        are the batch-axis concatenation of every tower's outputs.

    NOTE(review): this uses TF1-style `tf.device`/`tf.name_scope` graph
    placement — presumably intended for graph-mode TF; confirm before
    using under eager execution.
    """
    def get_slice(data, idx, parts):
        """Return slice `idx` of `data` along the batch axis (split into `parts`).

        The last slice absorbs any remainder, so batch sizes that are not
        divisible by `parts` do not silently drop samples.
        """
        shape = tf.shape(data)
        batch = shape[:1]
        step = batch // parts
        # Start offset: `step * idx` on the batch axis, 0 on all others.
        start = tf.concat([step * idx, shape[1:] * 0], axis=0)
        if idx == parts - 1:
            # Last tower takes whatever remains of the batch.
            size = tf.concat([batch - step * idx, shape[1:]], axis=0)
        else:
            size = tf.concat([step, shape[1:]], axis=0)
        return tf.slice(data, start, size)

    # One accumulator list per model output; each tower appends its
    # corresponding output tensor here.
    outputs_all = [[] for _ in model.outputs]

    # Build one replica of the model per GPU, each consuming its own
    # slice of every input tensor.
    for i in range(gpu_count):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                inputs = []
                for x in model.inputs:
                    # output_shape excludes the batch dimension, so it is
                    # the same for every slice.
                    input_shape = tuple(x.get_shape().as_list())[1:]
                    slice_n = Lambda(
                        get_slice,
                        output_shape=input_shape,
                        arguments={'idx': i, 'parts': gpu_count})(x)
                    inputs.append(slice_n)

                outputs = model(inputs)
                # A single-output model returns a bare tensor; normalize
                # to a list so the bookkeeping below is uniform.
                if not isinstance(outputs, list):
                    outputs = [outputs]

                for bucket, out in zip(outputs_all, outputs):
                    bucket.append(out)

    # Re-assemble each output by concatenating the tower slices along the
    # batch axis, on the CPU.
    with tf.device('/cpu:0'):
        merged = [concatenate(tower_outs, axis=0) for tower_outs in outputs_all]

    return Model(inputs=model.inputs, outputs=merged)
| |
|
| |
|