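Function Fitting with MindSpore
Technical Background
In an earlier post we described how to set up a Docker-based MindSpore programming environment. Building on that environment, this post uses MindSpore to fit a linear function and demonstrates the framework's basic usage.
Environment Setup
On Manjaro Linux, start the Docker service with the first command below, then check its state with status: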
[dechin-manjaro mindspore]# systemctl start docker
[dechin-manjaro mindspore]# systemctl status docker
● docker.service - Docker Application Container Engine
Loaded: loaded (/usr/lib/systemd/system/docker.service; disabled; vendor preset: disabled)
Active: active (running) since Wed 2021-04-14 16:32:38 CST; 9s ago
TriggeredBy: ● docker.socket
Docs: //docs.docker.com
Main PID: 298485 (dockerd)
Tasks: 99 (limit: 47875)
Memory: 186.0M
CGroup: /system.slice/docker.service
├─298485 /usr/bin/dockerd -H fd://
└─298496 containerd --config /var/run/docker/containerd/containerd.toml --log-level info
After pulling the MindSpore container image as described in that earlier post, the image shows up in the local image list:
[dechin-root mindspore]# docker images
REPOSITORY TAG IMAGE ID
swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-cpu 1.1.1 98a3f041e3d4
The container can be started as follows:
[dechin-root mindspore]# docker run -it 98a3
root@2a6c33894e53:~# python
Python 3.7.5 (default, Feb 8 2021, 02:21:05)
[GCC 7.5.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>>
As shown, the image comes with Python 3.7.5 and a matching MindSpore build preinstalled. You can verify this from the Python prompt:
>>> from mindspore import context
WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.
[WARNING] ME(20:139876984823936,MainProcess):2021-04-14-08:37:40.331.840 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
>>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
Besides MindSpore itself, we often need third-party libraries such as matplotlib, which we can install ourselves:
root@2a6c33894e53:~# python -m pip install matplotlib
Looking in indexes: //mirrors.aliyun.com/pypi/simple/
Collecting matplotlib
Downloading //mirrors.aliyun.com/pypi/packages/ce/63/74c0b6184b6b169b121bb72458818ee60a7d7c436d7b1907bd5874188c55/matplotlib-3.4.1-cp37-cp37m-manylinux1_x86_64.whl (10.3MB)
|████████████████████████████████| 10.3MB 4.4MB/s
Collecting cycler>=0.10 (from matplotlib)
Downloading //mirrors.aliyun.com/pypi/packages/f7/d2/e07d3ebb2bd7af696440ce7e754c59dd546ffe1bbe732c8ab68b9c834e61/cycler-0.10.0-py2.py3-none-any.whl
Collecting kiwisolver>=1.0.1 (from matplotlib)
Downloading //mirrors.aliyun.com/pypi/packages/d2/46/231de802ade4225b76b96cffe419cf3ce52bbe92e3b092cf12db7d11c207/kiwisolver-1.3.1-cp37-cp37m-manylinux1_x86_64.whl (1.1MB)
|████████████████████████████████| 1.1MB 13.9MB/s
Collecting python-dateutil>=2.7 (from matplotlib)
Downloading //mirrors.aliyun.com/pypi/packages/d4/70/d60450c3dd48ef87586924207ae8907090de0b306af2bce5d134d78615cb/python_dateutil-2.8.1-py2.py3-none-any.whl (227kB)
|████████████████████████████████| 235kB 4.6MB/s
Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (2.4.7)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (8.1.0)
Requirement already satisfied: numpy>=1.16 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (1.17.5)
Requirement already satisfied: six in /usr/local/python-3.7.5/lib/python3.7/site-packages (from cycler>=0.10->matplotlib) (1.15.0)
Installing collected packages: cycler, kiwisolver, python-dateutil, matplotlib
Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1 python-dateutil-2.8.1
WARNING: You are using pip version 19.2.3, however version 21.0.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.
root@2a6c33894e53:~# python -m pip install --upgrade pip
Looking in indexes: //mirrors.aliyun.com/pypi/simple/
Collecting pip
Downloading //mirrors.aliyun.com/pypi/packages/fe/ef/60d7ba03b5c442309ef42e7d69959f73aacccd0d86008362a681c4698e83/pip-21.0.1-py3-none-any.whl (1.5MB)
|████████████████████████████████| 1.5MB 1.3MB/s
Installing collected packages: pip
Found existing installation: pip 19.2.3
Uninstalling pip-19.2.3:
Successfully uninstalled pip-19.2.3
Successfully installed pip-21.0.1
We install IPython in the same way:
root@b8955ba28950:/home# python -m pip install IPython
Looking in indexes: //mirrors.aliyun.com/pypi/simple/
Collecting IPython
Downloading //mirrors.aliyun.com/pypi/packages/c9/b1/82cbe2b856386f44f37fdae54d9b425813bd86fe33385c9d658d64826098/ipython-7.22.0-py3-none-any.whl (785 kB)
|████████████████████████████████| 785 kB 1.8 MB/s
Collecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0
Downloading //mirrors.aliyun.com/pypi/packages/eb/e6/4b4ca4fa94462d4560ba2f4e62e62108ab07be2e16a92e594e43b12d3300/prompt_toolkit-3.0.18-py3-none-any.whl (367 kB)
|████████████████████████████████| 367 kB 818 kB/s
Collecting pickleshare
Downloading //mirrors.aliyun.com/pypi/packages/9a/41/220f49aaea88bc6fa6cba8d05ecf24676326156c23b991e80b3f2fc24c77/pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)
Collecting pygments
Downloading //mirrors.aliyun.com/pypi/packages/3a/80/a52c0a7c5939737c6dca75a831e89658ecb6f590fb7752ac777d221937b9/Pygments-2.8.1-py3-none-any.whl (983 kB)
|████████████████████████████████| 983 kB 2.7 MB/s
Requirement already satisfied: decorator in /usr/local/python-3.7.5/lib/python3.7/site-packages (from IPython) (4.4.2)
Collecting traitlets>=4.2
Downloading //mirrors.aliyun.com/pypi/packages/f6/7d/3ecb0ebd0ce8dcdfa7bd47ab85c1d4a521e6770ef283d0824f5804994dfe/traitlets-5.0.5-py3-none-any.whl (100 kB)
|████████████████████████████████| 100 kB 4.0 MB/s
Collecting pexpect>4.3
Downloading //mirrors.aliyun.com/pypi/packages/39/7b/88dbb785881c28a102619d46423cb853b46dbccc70d3ac362d99773a78ce/pexpect-4.8.0-py2.py3-none-any.whl (59 kB)
|████████████████████████████████| 59 kB 5.9 MB/s
Collecting jedi>=0.16
Downloading //mirrors.aliyun.com/pypi/packages/f9/36/7aa67ae2663025b49e8426ead0bad983fee1b73f472536e9790655da0277/jedi-0.18.0-py2.py3-none-any.whl (1.4 MB)
|████████████████████████████████| 1.4 MB 3.7 MB/s
Collecting backcall
Downloading //mirrors.aliyun.com/pypi/packages/4c/1c/ff6546b6c12603d8dd1070aa3c3d273ad4c07f5771689a7b69a550e8c951/backcall-0.2.0-py2.py3-none-any.whl (11 kB)
Requirement already satisfied: setuptools>=18.5 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from IPython) (41.2.0)
Collecting parso<0.9.0,>=0.8.0
Downloading //mirrors.aliyun.com/pypi/packages/a9/c4/d5476373088c120ffed82f34c74b266ccae31a68d665b837354d4d8dc8be/parso-0.8.2-py2.py3-none-any.whl (94 kB)
|████████████████████████████████| 94 kB 6.0 MB/s
Collecting ptyprocess>=0.5
Downloading //mirrors.aliyun.com/pypi/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)
Collecting wcwidth
Downloading //mirrors.aliyun.com/pypi/packages/59/7c/e39aca596badaf1b78e8f547c807b04dae603a433d3e7a7e04d67f2ef3e5/wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)
Collecting ipython-genutils
Downloading //mirrors.aliyun.com/pypi/packages/fa/bc/9bd3b5c2b4774d5f33b2d544f1460be9df7df2fe42f352135381c347c69a/ipython_genutils-0.2.0-py2.py3-none-any.whl (26 kB)
Installing collected packages: wcwidth, ptyprocess, parso, ipython-genutils, traitlets, pygments, prompt-toolkit, pickleshare, pexpect, jedi, backcall, IPython
WARNING: The script pygmentize is installed in '/usr/local/python-3.7.5/bin' which is not on PATH.
Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.
WARNING: The scripts iptest, iptest3, ipython and ipython3 are installed in '/usr/local/python-3.7.5/bin' which is not on PATH.
Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.
Successfully installed IPython-7.22.0 backcall-0.2.0 ipython-genutils-0.2.0 jedi-0.18.0 parso-0.8.2 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.18 ptyprocess-0.7.0 pygments-2.8.1 traitlets-5.0.5 wcwidth-0.2.5
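Before saving the container, it can be useful to confirm that the interpreter actually sees the packages. The following is a minimal sketch using only the standard library; check_packages is a hypothetical helper name, and "json" is included only as a module that should always be present.

```python
import importlib.util

def check_packages(names):
    """Return {name: True/False} telling whether each module can be located."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

# Module names used in this post; "json" is a stdlib module that should always be found
report = check_packages(["mindspore", "matplotlib", "IPython", "json"])
for name, found in report.items():
    print(f"{name}: {'found' if found else 'missing'}")
```

Inside the container described here, all four entries would be expected to report as found.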
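Neither installation ran into dependency problems. Next, we save the installed libraries into a Docker image so that they do not have to be installed again next time. After leaving the container with exit, use docker ps to list the most recent containers: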
[dechin-root mindspore]# docker ps -n 3
CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
2a6c33894e53 98a3 "/bin/bash" 13 minutes ago Exited (0) 7 seconds ago upbeat_tharp
625ee5f4ee95 ea1c "bash" 9 days ago Exited (0) 9 days ago zealous_mccarthy
ded2cb29290a kivy/buildozer "buildozer bash -c '…" 9 days ago Exited (1) 9 days ago exciting_lumiere
The first entry is the MindSpore container we want to keep, so we use docker commit to save its changes into a new image:
[dechin-root mindspore]# docker commit 2a6c mindspore
sha256:3a6951d9b9009f93027748ecec78078efff1fb36599a5786bcbc667e72119392
The returned hash indicates that the commit succeeded. Listing the local images again:
[dechin-root mindspore]# docker images
REPOSITORY TAG IMAGE ID CREATED SIZE
mindspore latest 3a6951d9b900 31 seconds ago 1.22GB
swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-cpu 1.1.1 98a3f041e3d4 2 months ago 1.18GB
Our base image is now ready; it takes about 40 MB more space than the original image. To close this section, note that the image provided by MindSpore is built on Ubuntu 18.04:
root@b8955ba28950:/home# cat /etc/issue
Ubuntu 18.04.5 LTS \n \l
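Fitting a Linear Function with MindSpore
Suppose we have a series of scattered points, shown as red dots in the figure, which serve as the data we train on. The green line in the figure is the true function: the points were generated from this linear function with random noise added. Our goal is to recover the linear function from these points so that it can be used to predict the function value at a new position. Applied to quantitative finance, the same idea could be used to predict the next move of a stock price, although the function involved would be far more complex.
The function behind the figure is:
\[f(x) = 2x + 3\]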
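Generating the Scatter Dataset
The noise is added in the get_data function. The dataset is generated as follows: first draw a series of random values of the independent variable \(x\) in the range \([-10,10]\), then draw normally distributed random numbers as noise, and add that noise to the corresponding function values \(f(x)\) to obtain the raw data. Note that the function does not hand the data back with return; it yields the points one at a time with yield.
The second step is to convert the data into a format MindSpore understands, mindspore.dataset.GeneratorDataset. Besides assigning a column name to each of \(x\) and \(y\), we can also set the batch size and the number of repeats; the batch size can affect the final training speed.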
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
from mindspore import dataset as ds

def get_data(num, w=2.0, b=3.0):
    # Yield num noisy samples of y = w * x + b, one (x, y) pair at a time
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)

eval_data = list(get_data(50))  # generate 50 noisy random points
x_target_label = np.arange(-10, 10, 0.1)  # x grid for the target line
y_target_label = x_target_label * 2 + 3  # expected function values
x_eval_label, y_eval_label = zip(*eval_data)

def create_dataset(num_data, batch_size=16, repeat_size=1):
    # Wrap the generated samples in a GeneratorDataset, then batch and repeat
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data

data_number = 1600
batch_number = 16
repeat_number = 1
ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
print("The dataset size of ds_train:", ds_train.get_dataset_size())
dict_datasets = next(ds_train.create_dict_iterator())
print(dict_datasets.keys())
print("The x label value shape:", dict_datasets["data"].shape)
print("The y label value shape:", dict_datasets["label"].shape)
Running the code above gives:
root@b8955ba28950:/home# python test_linear.py
The dataset size of ds_train: 100
dict_keys(['data', 'label'])
The x label value shape: (16, 1)
The y label value shape: (16, 1)
At this point we have built a training set of 1600 samples, split into 100 batches of 16 samples each.
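The batch count can be reproduced without MindSpore. The sketch below is plain Python: get_data mirrors the generator above using the standard random module, and batched is a hypothetical helper that stands in for the dataset's batch step; it is not how MindSpore implements batching.

```python
import random

def get_data(num, w=2.0, b=3.0):
    """Yield noisy samples of y = w * x + b, one (x, y) pair at a time."""
    for _ in range(num):
        x = random.uniform(-10.0, 10.0)
        noise = random.gauss(0.0, 1.0)
        yield x, w * x + b + noise

def batched(samples, batch_size):
    """Group an iterable into lists of batch_size items (the last one may be shorter)."""
    batch = []
    for sample in samples:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

random.seed(0)
batches = list(batched(get_data(1600), 16))
print(len(batches), len(batches[0]))  # prints: 100 16
```

Because get_data is a generator, the samples are produced only as batched consumes them, which is the practical difference between yield and return mentioned earlier.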
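Building the Model and Initial Parameters
With mindspore.nn.Dense we can build a model for linear fitting:
\[f(x) = wx + b\]
See the official documentation of mindspore.nn.Dense for a full description of this layer.
The initial values of weight and bias are defined by tensors. The call nn.Dense(1, 1, Normal(0.02), Normal(0.02)) specifies one input channel, one output channel, and an initializer for each of the two parameters: the weight is a tensor of shape (1, 1) and the bias a tensor of shape (1,). Their elements are random initial values drawn from a normal distribution \(N(0,\sigma)\) with \(\sigma = 0.02\). In this example we can print the initial parameters: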
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindspore import nn

def get_data(num, w=2.0, b=3.0):
    # Yield num noisy samples of y = w * x + b
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)

eval_data = list(get_data(50))
x_target_label = np.arange(-10, 10, 0.1)  # x grid for the target line
y_target_label = x_target_label * 2 + 3
x_eval_label, y_eval_label = zip(*eval_data)

def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data

data_number = 1600
batch_number = 16
repeat_number = 1
ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
dict_datasets = next(ds_train.create_dict_iterator())

class LinearNet(nn.Cell):
    # A single Dense layer with one input and one output: f(x) = w * x + b
    def __init__(self):
        super(LinearNet, self).__init__()
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        x = self.fc(x)
        return x

net = LinearNet()
model_params = net.trainable_params()
# Print the randomly initialized weight and bias
for param in model_params:
    print(param, param.asnumpy())
The output is shown below: the weight is an array of shape (1, 1) and the bias an array of shape (1,).
root@b8955ba28950:/home# python test_linear.py
Parameter (name=fc.weight) [[-0.00252427]]
Parameter (name=fc.bias) [0.00694926]
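To see what these two numbers mean, the linear model can be evaluated by hand. This is a plain-Python sketch, not MindSpore code: dense_forward and target are hypothetical helper names, and the parameter values are copied from the printed output above.

```python
# Initial parameters copied from the printed output
weight = [[-0.00252427]]   # shape (1, 1)
bias = [0.00694926]        # shape (1,)

def dense_forward(x, weight, bias):
    """Compute w * x + b for one input channel and one output channel."""
    return weight[0][0] * x + bias[0]

def target(x):
    """The true function used to generate the data."""
    return 2.0 * x + 3.0

for x in (-10.0, 0.0, 10.0):
    print(x, dense_forward(x, weight, bias), target(x))
```

With a weight and bias this close to zero, the model's predictions stay near 0 across the whole range, while the target runs from -17 to 23.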
Printing the two parameter values is not very intuitive, so we draw the function they define on top of the earlier scatter plot:
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindspore import nn, Tensor

def get_data(num, w=2.0, b=3.0):
    # Yield num noisy samples of y = w * x + b
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)

eval_data = list(get_data(50))
x_target_label = np.arange(-10, 10, 0.1)  # x grid for the target line
y_target_label = x_target_label * 2 + 3
x_eval_label, y_eval_label = zip(*eval_data)

def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data

data_number = 1600
batch_number = 16
repeat_number = 1
ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
dict_datasets = next(ds_train.create_dict_iterator())

class LinearNet(nn.Cell):
    def __init__(self):
        super(LinearNet, self).__init__()
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        x = self.fc(x)
        return x

net = LinearNet()
model_params = net.trainable_params()
# Line defined by the initial (untrained) weight and bias
x_model_label = np.arange(-10, 10, 0.1)
y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0] +
                 Tensor(model_params[1]).asnumpy()[0])
plt.axis([-10, 10, -20, 25])
plt.scatter(x_eval_label, y_eval_label, color="red", s=5)  # noisy samples
plt.plot(x_model_label, y_model_label, color="blue")       # initial model
plt.plot(x_target_label, y_target_label, color="green")    # target function
plt.savefig('initial.png')
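Running the script writes an image named initial.png to the current directory.
At this point the function defined by the parameters (the blue line) is still far from the one we expect (the green line).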
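Training and Visualization
With the groundwork laid, we can finally start training. In machine learning we first define a function that measures how good a result is, usually called the loss function. The smaller the loss, the better the result; for our fitting problem, that means a better fit. Here we use the mean squared error (MSE).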
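In a common notation (assumed here: \(y_i\) is the noisy observation at \(x_i\), \(f(x_i) = wx_i + b\) is the model prediction, and \(n\) is the number of samples), the mean squared error reads:

```latex
\[\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(f(x_i) - y_i\right)^2\]
```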
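Mean squared error is one of the most commonly used loss functions, because a deviation in either direction makes the loss grow quadratically. With the loss function chosen, we need a forward network that computes it. Here we use the interface MindSpore already provides, mindspore.nn.loss.MSELoss.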
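As a rough illustration of the quantity being computed, here is a plain-Python version of the mean squared error. The function name mse is a hypothetical stand-in; this is not MindSpore's implementation of MSELoss.

```python
def mse(predictions, labels):
    """Mean of the squared differences between predictions and labels."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    return sum((p - y) ** 2 for p, y in zip(predictions, labels)) / len(labels)

labels = [1.0, 3.0, 5.0]
print(mse([1.0, 3.0, 5.0], labels))  # perfect prediction: 0.0
print(mse([2.0, 4.0, 6.0], labels))  # every prediction 1 too high: 1.0
print(mse([0.0, 2.0, 4.0], labels))  # every prediction 1 too low: 1.0
print(mse([3.0, 5.0, 7.0], labels))  # every prediction 2 too high: 4.0
```

The last three calls show the two properties mentioned above: the sign of the deviation does not matter, and doubling the deviation quadruples the loss.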
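Once the loss for the current parameters has been computed, we need to update the parameters and compute the loss for the next set, in order to decide which direction to move in to reach the minimum of the loss. This parameter update is carried out by the backward network. Common update algorithms include gradient descent, which is covered in more detail in an earlier post (see the reference links). Its basic idea is to move the parameters a small step against the gradient of the loss at every iteration.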
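In a common notation (assumed here: \(\theta = (w, b)\) collects the parameters, \(L\) is the loss, and \(\eta\) is the learning rate), the gradient descent update reads:

```latex
\[\theta_{t+1} = \theta_t - \eta\,\nabla_\theta L(\theta_t)\]
```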
In MindSpore, the optimizer interface used here is mindspore.nn.Momentum.
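To give a feel for what a momentum optimizer does, the sketch below fits the same line in plain Python with mini-batch gradient descent plus a classic momentum term. The helper names make_batches and train_momentum are hypothetical, the learning rate and momentum are taken from the script in this post, and the exact update inside nn.Momentum may differ in details.

```python
import random

def make_batches(num, batch_size, w=2.0, b=3.0):
    """Generate noisy samples of y = w * x + b and group them into batches."""
    data = []
    for _ in range(num):
        x = random.uniform(-10.0, 10.0)
        data.append((x, w * x + b + random.gauss(0.0, 1.0)))
    return [data[i:i + batch_size] for i in range(0, num, batch_size)]

def train_momentum(batches, lr=0.005, momentum=0.9):
    """Fit y = w * x + b with mini-batch gradient descent plus momentum."""
    w, b = 0.0, 0.0      # parameters, starting from zero
    vw, vb = 0.0, 0.0    # accumulated velocities
    for batch in batches:
        n = len(batch)
        # Gradients of the batch MSE with respect to w and b
        grad_w = sum(2.0 * (w * x + b - y) * x for x, y in batch) / n
        grad_b = sum(2.0 * (w * x + b - y) for x, y in batch) / n
        # The velocity keeps a decaying sum of past gradients
        vw = momentum * vw + grad_w
        vb = momentum * vb + grad_b
        w -= lr * vw
        b -= lr * vb
    return w, b

random.seed(0)
w_fit, b_fit = train_momentum(make_batches(1600, 16))
print(w_fit, b_fit)
```

After one pass over 100 batches the parameters should land near 2 and 3, with some scatter left over from the noisy mini-batch gradients.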
Once these pieces are defined, they can be wrapped and trained with mindspore.Model.
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindspore import nn, Tensor, Model
import time
from IPython import display
from mindspore.train.callback import Callback, LossMonitor

def get_data(num, w=2.0, b=3.0):
    # Yield num noisy samples of y = w * x + b
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)

eval_data = list(get_data(50))
x_target_label = np.arange(-10, 10, 0.1)  # x grid for the target line
y_target_label = x_target_label * 2 + 3
x_eval_label, y_eval_label = zip(*eval_data)

def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data

data_number = 1600
batch_number = 16
repeat_number = 1
ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
dict_datasets = next(ds_train.create_dict_iterator())

class LinearNet(nn.Cell):
    def __init__(self):
        super(LinearNet, self).__init__()
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        x = self.fc(x)
        return x

net = LinearNet()
model_params = net.trainable_params()
x_model_label = np.arange(-10, 10, 0.1)
y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0] +
                 Tensor(model_params[1]).asnumpy()[0])

# Network, loss function and optimizer, wrapped in a Model for training
net = LinearNet()
net_loss = nn.loss.MSELoss()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
model = Model(net, net_loss, opt)

fig = plt.figure()
ims = []  # artists collected for the animation

def plot_model_and_datasets(net, eval_data):
    # Draw the current model line, the target line and the sample points
    weight = net.trainable_params()[0]
    bias = net.trainable_params()[1]
    x = np.arange(-10, 10, 0.1)
    y = x * Tensor(weight).asnumpy()[0][0] + Tensor(bias).asnumpy()[0]
    x1, y1 = zip(*eval_data)
    x_target = x
    y_target = x_target * 2 + 3
    plt.axis([-11, 11, -20, 25])
    plt.scatter(x1, y1, color="red", s=5)
    im = plt.plot(x, y, color="blue")
    ims.append(im)
    im1 = plt.plot(x_target, y_target, color="green")
    ims.append(im1)
    time.sleep(0.2)

class ImageShowCallback(Callback):
    # Redraw the plot at the end of every training step
    def __init__(self, net, eval_data):
        self.net = net
        self.eval_data = eval_data

    def step_end(self, run_context):
        plot_model_and_datasets(self.net, self.eval_data)
        display.clear_output(wait=True)

epoch = 1
imageshow_cb = ImageShowCallback(net, eval_data)
model.train(epoch, ds_train, callbacks=[imageshow_cb], dataset_sink_mode=False)
plot_model_and_datasets(net, eval_data)
# Print the trained weight and bias
for net_param in net.trainable_params():
    print(net_param, net_param.asnumpy())
ani = animation.ArtistAnimation(fig, ims, interval=500, repeat_delay=1000)
ani.save('train.gif', writer='pillow')
The output is:
root@b8955ba28950:/home# python test_linear.py
WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.
[WARNING] ME(444:140374496206976,MainProcess):2021-04-14-09:28:58.738.627 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
Parameter (name=fc.weight) [[1.8964282]]
Parameter (name=fc.bias) [3.0266616]
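Once the script finishes, an animated image named train.gif is written to the current directory; it shows the whole training process.
In it, the red points are the noisy sample data, the green line is the original function, and the blue line is the function being trained; the two lines move closer and closer together. The function finally fitted is:
\[f(x) = 1.8964282x + 3.0266616\]
compared with the expected:
\[f(x) = 2x + 3\]
There is still a small gap. Many factors could cause it: the particular random points that were generated, or fitting over this range may simply carry an error of this size. We do not go into it here. With that, we have successfully used MindSpore to complete a function-fitting task.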
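One way to tell these causes apart is to compare against a closed-form least-squares fit on data drawn the same way. The sketch below is plain Python and uses its own randomly generated points, not the data from the run above; least_squares_line is a hypothetical helper name. If it lands much closer to slope 2 and intercept 3 than the trained parameters did, the remaining gap more likely comes from the short training run than from the data.

```python
import random

def least_squares_line(points):
    """Closed-form ordinary least-squares slope and intercept for (x, y) pairs."""
    n = len(points)
    mean_x = sum(x for x, _ in points) / n
    mean_y = sum(y for _, y in points) / n
    sxx = sum((x - mean_x) ** 2 for x, _ in points)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in points)
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x

random.seed(0)
points = []
for _ in range(1600):
    x = random.uniform(-10.0, 10.0)
    points.append((x, 2.0 * x + 3.0 + random.gauss(0.0, 1.0)))

slope, intercept = least_squares_line(points)
print(slope, intercept)
```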
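Drawing Animated Function Plots in Python
In the previous section we used MindSpore to fit a linear function, and the final script already drew an animated plot. Here we pull that part out and describe it on its own. The tool is matplotlib.animation. The first step is to create the figure object and a frame list outside the training loop: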
fig = plt.figure()
ims = []
Here ims stores the artists drawn for each frame. The second step is to append the plot objects that change during training to ims:
im = plt.plot(x, y, color="blue")
ims.append(im)
im1 = plt.plot(x_target, y_target, color="green")
ims.append(im1)
Finally, build the animation from the figure object fig and the frame collection ims, and save it to a local file:
ani = animation.ArtistAnimation(fig, ims, interval=500, repeat_delay=1000)
ani.save('train.gif', writer='pillow')
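The parameters of animation.ArtistAnimation are described in the matplotlib documentation.
Here interval=500 sets the delay between frames to 500 ms, and repeat_delay=1000 adds a 1000 ms pause before the animation starts over. Since repeat defaults to True, the animation keeps looping. The resulting image was already shown in the previous section, so we do not repeat it here. Note that generating the animation can take quite a while. Also note that plt.savefig only writes a single static image, so an animated gif has to be produced and saved through the animation module.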
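One detail worth knowing: each entry of the list passed to ArtistAnimation is the set of artists visible in one frame, and artists belonging to other frames are hidden. Appending the blue and green lines as separate entries therefore shows them in alternating frames. The sketch below is one possible variation, not the script used above, that puts both lines into the same frame. The slopes are made-up values standing in for the weight at successive training steps, and demo.gif is a hypothetical output name.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

fig = plt.figure()
x = np.arange(-10, 10, 0.1)
frames = []
# Hypothetical slopes standing in for the model weight at successive training steps
for w in (0.0, 0.5, 1.0, 1.5, 2.0):
    model_line = plt.plot(x, w * x + 3, color="blue")    # list holding one Line2D
    target_line = plt.plot(x, 2 * x + 3, color="green")  # list holding one Line2D
    frames.append(model_line + target_line)  # both artists belong to the same frame

ani = animation.ArtistAnimation(fig, frames, interval=500, repeat_delay=1000)
ani.save("demo.gif", writer="pillow")
```

Since plt.plot returns a list of line objects, concatenating the two lists gives one frame containing both lines.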
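Summary
Many machine learning algorithms rest on function fitting. Here we considered one of the simplest and most common cases, fitting a linear function, and carried out the training with MindSpore. By constructing a mean squared error loss and combining the forward and backward networks, we largely recovered the given linear function. At the end we also showed how matplotlib's animation module can generate an animated image that visualizes the whole training process.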
Copyright Notice
This post was first published at: //www.cnblogs.com/dechinphy/p/linear.html
Author ID: DechinPhy
For more original posts, see: //www.cnblogs.com/dechinphy/
References
- //www.mindspore.cn/tutorial/training/zh-CN/master/quick_start/linear_regression.html
- //blog.csdn.net/clksjx/article/details/105720120
- //www.cnblogs.com/dechinphy/p/gradient.html