tf.nn.top_k()
- 2019 年 10 月 28 日
- 筆記
版權聲明:本文為部落客原創文章,遵循 CC 4.0 BY-SA 版權協議,轉載請附上原文出處鏈接和本聲明。
本文鏈接:https://blog.csdn.net/weixin_36670529/article/details/102711031
一、函數原型
tf.nn.top_k( input, k=1, sorted=True, name=None )
為了找到輸入的張量的最後的一個維度的最大的k個值和它的下標!
如果輸入的是一個向量,也就是rank=1,找到最大的k個數在這個向量,則輸出最大的k個數字和最大的這k個數字的下標。如果輸入的張量是一個更高rank的矩陣,那麼我們只要找到每一行的最大的k個數字,以及他們的下標。如果兩個元素相同,那麼低一點的下標先出現。
參數:
- input:輸入的tensor,不能是array這些啊!要麼輸入1-D,要是更高維度必須保證最後的一個維度長度必須大於等於K
- k:0-D的int32的數字張量。
- sorted:如果sorted=True,那麼選出來的k個數字就需要按照降序的順序排序
- name:可選項,也就是這個操作的名字
返回:
- values:也就是每一行的最大的k個數字
- indices:這裡的下標是在輸入的張量的最後一個維度的下標
二、例子
import tensorflow as tf import numpy as np #選出每一行的最大的前兩個數字 #返回的是最大的k個數字,同時返回的是最大的k個數字在最後的一個維度的下標 a=tf.constant(np.random.rand(3,4)) b=tf.nn.top_k(a,k=2) with tf.Session() as sess: print(sess.run(a)) print(sess.run(b)) Output: [[0.73731748 0.13455566 0.20236765 0.92909052] [0.7923021 0.46949081 0.31521194 0.2999236 ] [0.19102823 0.01301476 0.70615716 0.68501807]] TopKV2(values=array([[0.92909052, 0.73731748], [0.7923021 , 0.46949081], [0.70615716, 0.68501807]]), indices=array([[3, 0], [0, 1], [2, 3]], dtype=int32))