如何記錄分析你的煉丹流程—可視化神器Wandb使用筆記【1】

本節主要記錄使用wandb記錄訓練曲線以及上傳一些格式的數據將其展示在wandb中以便分析的方法,略過註冊安裝部分(可使用pip intall wandb安裝,註冊相關issue可上網搜索),文章着重於wandb的基本用法。

初始化

首先創建在wandb頁面中中創建需要可視化的project,然後在代碼裏面只要指定好team和project,便可以把數據傳輸到對應的project下:

import wandb
# notes:一些文字描述實驗發現或備註,也可以在wandb網頁的individual experiment panel中添加
# dir:本地文件寫入的路徑,(環境變量WANDB_DIR或wandb.init的關鍵字參數dir)
run_dir = Path("../results") / all_args.project_name / all_args.experiment_name
if not run_dir.exists():
    os.makedirs(str(run_dir))
wandb.init(config=all_args,
               project=your_project_name,
               entity=your_team_name,
               notes=socket.gethostname(),
               name=all_args.experiment_name + "_" + str(all_args.seed), 
               dir=run_dir,
               group=all_args.scenario_name,
               job_type="training",
               reinit=True)

基本使用

wandb的核心功能就是跟蹤訓練過程,展示訓練流程以供我們觀察展示和分析,該節以黃世宇代碼示例圖為例,說明wandb如何使用wandb.log()做到展示包括訓練曲線、圖片、matplotlib可視化結果、視頻、表格、甚至html在內的不同結構的數據。(顯示媒體文件時不需要在本地進行文件讀寫,可以直接用wandb的函數將展示對象處理為對應的格式就可以顯示。)

訓練曲線展示

total_step_num = 1000
for step in range(total_step_num):
    wandb.log({'random_curve':step/100+random.random()},step=step)
    wandb.log({'log_curve': math.log(step+1)},step=step)
wandb.finish()

Matplotlib可視化展示

# figure就是一個圖,axes表示圖上的一個畫圖區域,一個圖上可以有多個畫圖區域,即一個圖上可以有多個子圖
# 用函數gcf()與gca()分別得到當前的figure與axes。(get current figure, get current axes)
x = np.arange(1, 11)
for step in range(4):
    frames = []
    y = step * x + step
    plt.title("Matplotlib Demo")
    plt.xlabel("x axis caption")
    plt.ylabel("y axis caption")
    plt.plot(x, y)
    wandb.log({"plt":wandb.Plotly(plt.gcf())},step=step)

圖片展示

env = gym.make("PongNoFrameskip-v4")
env.reset()
for step in range(4):
    frames = [] # 每個step輸出一個由4張圖片組成的列表
    for i in range(4):
        obs,r,done,_=env.step(env.action_space.sample())
        # wandb.Image將numpy arrays或PILImage的實例轉化為PNG以供展示
        frames.append(wandb.Image(obs, caption="Pong"))
    wandb.log({"frames": frames},step=step)
    if done:
        env.reset()

視頻展示

env = gym.make("PongNoFrameskip-v4")
for episode in range(3):
    env.reset()
    done = False
    frames = []
    while not done:
        for _ in range(4):
            obs,r,done,_=env.step(env.action_space.sample())
            if done:
                break
        frames.append(obs)
    sequence = np.stack(frames, -1).transpose(3,2,0,1) # time, channels, height, width
    print(sequence.shape)
    video = wandb.Video(sequence, fps=10, format="gif",caption="Pong")
    wandb.log({"video": video},step=episode)

表格展示

columns = ["Name", "Age", "Score"]

data = [["ZhuZhu", 1, 0], ["MaoMao",2,1]]
table = wandb.Table(data=data, columns=columns)
wandb.log({"table": table})
wandb.finish()

展示html

html1 = wandb.Html('<a href="//tartrl.cn">TARTRL</a>')
html2 = wandb.Html(open('test.html'))
wandb.log({"html1": html1,"html2":html2})
wandb.finish()

參考

wandb使用教程(一):基礎用法 – 知乎 (zhihu.com)