python - Multiplication of a tensor with batch data and a matrix
Given a tensor a of shape [?, n, m] and a tensor w of shape [m, m], I want to multiply each [n, m] slice of a by w, resulting in a tensor of shape [?, n, m].
I thought of somehow reshaping w to shape [tf.shape(a)[0], n, m], but that did not give me a tensor of shape [?, n, m].
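Yes, you can indeed do this by reshaping: flatten a to shape [-1, m] so that every row of every sample becomes one row of a single matrix, multiply that matrix by w, then reshape the result back to [-1, n, m]: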
tf.reshape(tf.matmul(tf.reshape(a, [-1, m]), w), [-1, n, m])
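The reshape trick works because reshaping is row-major: each sample's n rows stay contiguous in the flattened [-1, m] matrix, and matrix multiplication transforms each row independently, so reshaping back restores the per-sample grouping. The sketch below uses NumPy as a stand-in for TensorFlow (np.reshape and the @ operator follow the same row-major semantics as tf.reshape and tf.matmul); the function name batch_times_matrix and the example sizes are illustrative assumptions, not part of the original answer. It checks the result against a per-sample loop and against einsum:

```python
import numpy as np


def batch_times_matrix(a: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Multiply every [n, m] slice of a ([batch, n, m]) by w ([m, m]).

    Hypothetical helper: a NumPy analogue of the answer's
    tf.reshape / tf.matmul / tf.reshape pattern.
    """
    n, m = a.shape[1], a.shape[2]
    flat = a.reshape(-1, m)        # [batch * n, m]: all rows of all samples stacked
    prod = flat @ w                # [batch * n, m]: each row multiplied by w
    return prod.reshape(-1, n, m)  # back to [batch, n, m]; -1 infers the batch size


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 3, 5))  # batch=4, n=3, m=5 (arbitrary example sizes)
w = rng.standard_normal((5, 5))

out = batch_times_matrix(a, w)
# Reference 1: multiply each sample separately and stack the results.
loop = np.stack([a[i] @ w for i in range(a.shape[0])])
# Reference 2: einsum expresses the same contraction over the shared m axis.
ein = np.einsum("bnm,mk->bnk", a, w)

print(out.shape)                                      # (4, 3, 5)
print(np.allclose(out, loop), np.allclose(out, ein))  # True True
```

In TensorFlow itself, tf.einsum("bnm,mk->bnk", a, w) should express the same contraction without explicit reshapes. If w were [m, k] rather than square, the final reshape would need to be [-1, n, k].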