Pytorch項目基本結構

梳理一下Pytorch項目的基本結構(其實TF的也差不多是這樣,這種思路可以遷移到別的深度學習框架中)

結構樹

——-checkpoints #存放訓練完成的模型文件

​ —-xxx.pkl #模型文件

——–data #存放數據文件(如txt)或者數據預處理文件

​ —__ init __.py

​ —xxx.txt #數據

​ —dataset.py #數據集相關

​ —get_data.sh #一般用於下載某些數據

——–models #存放模型,一般一個模型對應一個.py文件

​ —__ init __.py

​ —xxxNet.py

​ —xxxModel.py

——–utils #存放一些工具函數,如可視化等

​ —__ init __.py

​ —visualize.py

——–config.py #配置文件

——–train.py #用於訓練模型,可視為主文件

——–test.py #用於測試模型

流程

1、獲取數據

使用.sh文件下載或者其他方法獲得數據

2、數據載入

一般會有一個文件把數據處理成適合的格式,然後通過加載器(Dataloader)載入模型中使用,這個Dataloader可能是獨立的,也可能集成在train.py裏面

3、訓練

顧名思義,使用載入的數據對定義的模型進行訓練。這個過程基本上是使用train.py進行,結果是你會得到一個.pkl結尾的模型文件

4、測試

用一部分數據對訓練好的模型進行測試(這些數據可以來自之前導入的數據,也可以是新的數據),使用test.py進行,調用損失函數,打印日誌(就是你看到的那些在console里刷新的log)

5、使用模型

就是調用即可,先給出我們存放模型的位置,然後加載即可(沒有實操,後續再更新)

註:

  • 模型.py文件中,一般是用一個函數或者一個類來承載一個具體模型,其中定義着模型的不同層
  • train.py是工程的核心,裏面定義了訓練時需要的各項參數、訓練次數等重要信息