tf.nn.top_k()

  • 2019 年 10 月 28 日
  • 筆記

版權聲明:本文為部落客原創文章,遵循 CC 4.0 BY-SA 版權協議,轉載請附上原文出處鏈接和本聲明。

本文鏈接:https://blog.csdn.net/weixin_36670529/article/details/102711031

一、函數原型

tf.nn.top_k(      input,      k=1,      sorted=True,      name=None  )

為了找到輸入的張量的最後的一個維度的最大的k個值和它的下標!

如果輸入的是一個向量,也就是rank=1,找到最大的k個數在這個向量,則輸出最大的k個數字和最大的這k個數字的下標。如果輸入的張量是一個更高rank的矩陣,那麼我們只要找到每一行的最大的k個數字,以及他們的下標。如果兩個元素相同,那麼低一點的下標先出現。

參數:

  • input:輸入的tensor,不能是array這些啊!要麼輸入1-D,要是更高維度必須保證最後的一個維度長度必須大於等於K
  • k:0-D的int32的數字張量。
  • sorted:如果sorted=True,那麼選出來的k個數字就需要按照降序的順序排序
  • name:可選項,也就是這個操作的名字

返回:

  • values:也就是每一行的最大的k個數字
  • indices:這裡的下標是在輸入的張量的最後一個維度的下標

二、例子

import tensorflow as tf  import numpy as np    #選出每一行的最大的前兩個數字  #返回的是最大的k個數字,同時返回的是最大的k個數字在最後的一個維度的下標  a=tf.constant(np.random.rand(3,4))  b=tf.nn.top_k(a,k=2)  with tf.Session() as sess:      print(sess.run(a))      print(sess.run(b))      Output:        [[0.73731748 0.13455566 0.20236765 0.92909052]       [0.7923021  0.46949081 0.31521194 0.2999236 ]       [0.19102823 0.01301476 0.70615716 0.68501807]]      TopKV2(values=array([[0.92909052, 0.73731748],             [0.7923021 , 0.46949081],             [0.70615716, 0.68501807]]), indices=array([[3, 0],             [0, 1],             [2, 3]], dtype=int32))