用華為MindSpore框架訓練資料庫類型的數據集
技術背景
在前面一篇部落格我們講到三種用python去讀取一個文件的指定行的操作,最終給出的一個結論大概是,對於大型的數據而言,最快的找到指定行的方法是Linux系統自帶的sed
指令,那麼是否只有這一種辦法了呢?很顯然不是,之所以採用這些方法,是因為我們被局限在數據的存儲格式上,如果在處理數據或者產生數據的階段,就把數據按照特定的數據結構進行存儲,那麼就能夠大大的提高數據讀取的效率。這裡我們要介紹一個用sqlite3
來讀取數據用於MindSpore的訓練的案例,在有限的記憶體空間中避免完整的去載入整個數據集。
Sqlite3產生隨機數據
因為大部分的Python中是預裝了sqlite3的,這就避免了我們自己再去重複安裝的麻煩,比如Spark和PySpark就是安裝起來比較麻煩的典型案例,當然其性能和分散式的處理也是非常具有優越性的。這裡我們看一個用sqlite3產生訓練數據的案例,這個案例的原型來自於這篇部落格,其函數表達形式為:
\]
# store_data_to_db.py
import numpy as np
import sqlite3
from tqdm import trange
conn = sqlite3.connect('xyz.db') # 創建或者鏈接一個已有db文件
cur = conn.cursor()
try:
sql_1 = '''CREATE TABLE number
(i NUMBER,
x NUMBER,
y NUMBER,
z NUMBER);'''
cur.execute(sql_1) # 執行資料庫指令,創建一個新的表單
except:
pass
def get_data(num, a=2.0, b=3.0, c=5.0):
for _ in trange(num):
x = np.random.uniform(-1.0, 1.0)
y = np.random.uniform(-1.0, 1.0)
noise = np.random.normal(0, 0.03)
z = a * x ** 2 + b * y ** 3 + c + noise # 計算數據
# 將一行數據寫入資料庫
cur.execute("INSERT INTO number VALUES({},{},{},{})".format(_, x**2, y**3, z))
get_data(100) # 產生100組數據
conn.commit()
cur.close()
conn.close()
在這個案例中我們一共產生了100組的測試數據,運行過程如下:
(base) dechin@ubuntu2004:~/projects/gitlab/dechin/src/mindspore$ python3 store_data_to_db.py
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 29504.11it/s]
(base) dechin@ubuntu2004:~/projects/gitlab/dechin/src/mindspore$ ll | grep xyz.db
-rw-r--r-- 1 dechin dechin 8192 6月 1 15:43 xyz.db
運行完成後,會在當前目錄下產生一個名為xyz.db
的資料庫文件,在可遷移性上是比較靈活的。需要特別提及的是,這裡我們不僅存儲了x,y,z這3個變數,同時也存儲了index數據,這樣方便我們對數據進行檢索和查找。在程式的最後一步,一定要執行commit
才能夠將數據保存到資料庫文件中,否則不會被保存。
資料庫文件的讀取
接著上一個章節的內容,我們用Ipython來測試一下是否成功的將數據寫入到了資料庫文件中(這裡number
是表單的名字):
(base) dechin@ubuntu2004:~/projects/gitlab/dechin/src/mindspore$ ipython
Python 3.8.5 (default, Sep 4 2020, 07:30:14)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.19.0 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import sqlite3
In [2]: conn = sqlite3.connect('xyz.db')
In [3]: cur = conn.cursor()
In [4]: cur.execute('SELECT * FROM number WHERE i=0')
Out[4]: <sqlite3.Cursor at 0x7fd08bd5cc70>
In [5]: print (cur.fetchall())
[(0, 0.0099305893254821, -0.003805282402773131, 5.014158221453069)]
In [6]: cur.execute('SELECT * FROM number WHERE i=99')
Out[6]: <sqlite3.Cursor at 0x7fd08bd5cc70>
In [7]: print (cur.fetchall())
[(99, 0.1408058052492868, -0.5207606243222331, 3.7101686456005116)]
In [8]: cur.execute('SELECT * FROM number WHERE i=100')
Out[8]: <sqlite3.Cursor at 0x7fd08bd5cc70>
In [9]: print (cur.fetchall())
[]
In [10]: cur.close()
...: conn.close()
在這個案例中我們可以看到,成功的讀取了第0個數據和第99個數據,如果超過這個範圍去檢索,會返回一個空的值。返回的結果是被包在一個list中的tuple,所以注意讀取的方式要用cur.fetchall()[0][0]
才能夠讀取到這一列中的第一個元素。
與MindSpore的結合
在介紹完數據的產生和存儲、資料庫文件的讀取兩個工作後,結合起來我們可以嘗試從資料庫文件中去載入訓練數據,用於MindSpore的模型訓練。這裡我們不展開去介紹MindSpore的模型和程式碼,在前面的這一篇部落格中有介紹相關的細節,讓我們直接看一下程式碼:
# dataset_test.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
import numpy as np
from mindspore import dataset as ds
from mindspore import nn, Tensor, Model
import time
from mindspore.train.callback import Callback, LossMonitor
import sqlite3
conn = sqlite3.connect('xyz.db')
cur = conn.cursor()
def get_data(num, a=2.0, b=3.0, c=5.0):
for _ in range(num):
cur.execute('SELECT * FROM number WHERE i={}'.format(_))
data = cur.fetchall()[0]
yield np.array([[float(data[1])],
[float(data[2])]],dtype=np.float32).reshape(1,2), np.array([float(data[3])]).astype(np.float32)
def create_dataset(num_data, batch_size=16, repeat_size=1):
input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['xy','z'])
input_data = input_data.batch(batch_size)
input_data = input_data.repeat(repeat_size)
return input_data
data_number = 100
batch_number = 10
repeat_number = 20
ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
class LinearNet(nn.Cell):
def __init__(self):
super(LinearNet, self).__init__()
self.fc = nn.Dense(2, 1, 0.02, 0.02)
def construct(self, x):
x = self.fc(x)
return x
net = LinearNet()
model_params = net.trainable_params()
print ('Param Shape is: {}'.format(len(model_params)))
for net_param in net.trainable_params():
print(net_param, net_param.asnumpy())
net_loss = nn.loss.MSELoss()
optim = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.6)
model = Model(net, net_loss, optim)
epoch = 1
model.train(epoch, ds_train, callbacks=[LossMonitor(10)], dataset_sink_mode=False)
for net_param in net.trainable_params():
print(net_param, net_param.asnumpy())
cur.close()
conn.close()
跟前面的部落格中類似的是,我們還是用了MindSpore的GeneratorDataset
這個方法來構造數據,並且通過get_data
函數逐個的返回資料庫中對應位置的數據。可以看一下是否能夠成功訓練:
(base) dechin@ubuntu2004:~/projects/gitlab/dechin/src/mindspore$ singularity exec --nv /home/dechin/tools/singularity/mindspore-gpu_1.2.0.sif python dataset_test.py
Param Shape is: 2
Parameter (name=fc.weight, shape=(1, 2), dtype=Float32, requires_grad=True) [[0.02 0.02]]
Parameter (name=fc.bias, shape=(1,), dtype=Float32, requires_grad=True) [0.02]
epoch: 1 step: 10, loss is 15.289665
epoch: 1 step: 20, loss is 4.292768
epoch: 1 step: 30, loss is 2.199254
epoch: 1 step: 40, loss is 0.558127
epoch: 1 step: 50, loss is 1.2218236
epoch: 1 step: 60, loss is 2.0977945
epoch: 1 step: 70, loss is 2.0961792
epoch: 1 step: 80, loss is 1.107859
epoch: 1 step: 90, loss is 1.1687267
epoch: 1 step: 100, loss is 1.166467
epoch: 1 step: 110, loss is 0.73308593
epoch: 1 step: 120, loss is 1.2287892
epoch: 1 step: 130, loss is 1.2843382
epoch: 1 step: 140, loss is 1.996727
epoch: 1 step: 150, loss is 1.9126663
epoch: 1 step: 160, loss is 1.1095876
epoch: 1 step: 170, loss is 1.1662093
epoch: 1 step: 180, loss is 2.1144183
epoch: 1 step: 190, loss is 1.6211499
epoch: 1 step: 200, loss is 2.0198507
Parameter (name=fc.weight, shape=(1, 2), dtype=Float32, requires_grad=True) [[1.0131103 0.27144054]]
Parameter (name=fc.bias, shape=(1,), dtype=Float32, requires_grad=True) [5.3248053]
訓練完成,雖然我們看到最終擬合出來的數據效果不是很好,但是從流程上來說我們已經達成了通過資料庫格式的數據來構造MindSpore的訓練數據輸入的目的。
總結概要
本文按照數據流的順序,分別介紹了:使用sqlite3資料庫存儲數據、從sqlite3資料庫中讀取數據、使用從sqlite3資料庫中的數據構造MindSpore可識別的訓練數據集。對於輸入的數據量比較大的場景,我們不太可能將全部的數據都載入到記憶體中,這就要考慮各種可以快速存儲和讀取的方案,資料庫就是一種比較常見的方案。而sqlite3作為一款非常輕量級的資料庫,在大部分的Python3中都是內置的,省去了很多編譯安裝的繁瑣。當然性能表現可能不如其他的資料庫,但是在我們這邊給定的場景下,表現還是非常優秀的!
版權聲明
本文首發鏈接為://www.cnblogs.com/dechinphy/p/ms-sql.html
作者ID:DechinPhy
更多原著文章請參考://www.cnblogs.com/dechinphy/
打賞專用鏈接://www.cnblogs.com/dechinphy/gallery/image/379634.html
騰訊雲專欄同步://cloud.tencent.com/developer/column/91958