I'm trying to train a small TF2.x model on 4 GPUs (AWS g4dn.12xlarge) that takes both dense and sparse tensors as its input. When I used only dense features, without the sparse ones, my distributed training code worked well without any performance degradation. After including the sparse features, however, I found numerous unexpected blocks in the TensorBoard Profiler's trace_viewer. I have attached the profiler screenshot.
The main problem is that, although all the GPUs seem to compute their given batches fine, there is a large time gap between each pair of computation blocks on the host side. It contains 17x4 instances of EagerExecute:DeserializeSparse, each ending with the terminal op _Send input 0 from /job:localhost/replica:0/task:0/device:GPU:{gpu_number} to /job:localhost/replica:0/task:0/device:CPU:0. Here, 17 is the number of sparse features that the model receives, and 4 is the number of GPUs being used. In addition, a large number of MemcpyD2H ops (the small pink blocks in the screenshot) occupy each GPU and are not parallelized. That gap lasts about 6x as long as the actual forward pass.
Below is how the model treats sparse tensor inputs:
def call(self, inputs: tf.sparse.SparseTensor):
    # Keep all sparse-related ops on the CPU.
    with tf.device("/CPU:0"):
        x = self.hash_inputs_from_static_hash_table(inputs)
        x = self.embedding_lookup_sparse(x)
    return self.prediction_head(x)
The data is never large (batch size = 128 per replica, and the sparse feature embedding dimension is < 10). I tried moving all sparse-related operations to the CPU so as not to burden the GPUs, but the problem persists exactly as it did before I placed those ops on the CPU manually.
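For concreteness, here is a minimal, standalone sketch of what the sparse path described above amounts to: a static hash table lookup followed by a pooled sparse embedding lookup, both pinned to the CPU. The vocabulary, the embedding dimension of 8, the "mean" combiner, and the names `table`, `embeddings`, and `sparse_path` are illustrative assumptions, not the real model.

```python
import tensorflow as tf

# Hypothetical vocabulary and size; the real model has 17 sparse features.
VOCAB = ["a", "b", "c"]
EMBEDDING_DIM = 8  # assumed; the post only states "< 10"

with tf.device("/CPU:0"):
    # Maps string values to integer ids; unknown keys fall back to id 0.
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(VOCAB),
            values=tf.range(len(VOCAB), dtype=tf.int64),
        ),
        default_value=0,
    )
    embeddings = tf.Variable(tf.random.normal([len(VOCAB), EMBEDDING_DIM]))


def sparse_path(inputs: tf.sparse.SparseTensor) -> tf.Tensor:
    """Hash string values to ids, then mean-pool embeddings per row on CPU."""
    with tf.device("/CPU:0"):
        ids = tf.sparse.SparseTensor(
            indices=inputs.indices,
            values=table.lookup(inputs.values),
            dense_shape=inputs.dense_shape,
        )
        return tf.nn.embedding_lookup_sparse(
            embeddings, ids, sp_weights=None, combiner="mean"
        )


# Two rows: row 0 holds "a" and "c", row 1 holds only "b".
example = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=["a", "c", "b"],
    dense_shape=[2, 2],
)
pooled = sparse_path(example)  # dense tensor, one pooled embedding per row
```

The result is a small dense tensor with one row per example, which is what would then be passed on to the prediction head.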
I want to know why those blocks appear after the GPU computations and, ideally, how to remove them so that I can fully benefit from distributed training with multiple GPUs.
It seems like I'm still missing something that can be optimized, and this situation might not be that unique in distributed training, so I'm asking a broader audience for help.