首先,明确一点,tf.argmax可以认为就是np.argmax。tensorflow使用numpy实现的这个API。
简单的说,tf.argmax就是返回最大的那个数值所在的下标。 这个很好理解,只是tf.argmax()的参数让人有些迷惑,比如,tf.argmax(array, 1)和tf.argmax(array, 0)有啥区别呢? 这里面就涉及到一个概念:axis。上面例子中的1和0就是axis。我先笼统的解释这个问题,设置axis的主要原因是方便我们进行多个维度的计算。在实例面前,再多的语言都是苍白的呀,上例子!
比如:
test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])np.argmax(test, 0) #输出:array([3, 3, 1]np.argmax(test, 1) #输出:array([2, 2, 0, 0]