使用文本數據預測一個人的性格
- 2019 年 12 月 31 日
- 筆記
我們使用的用 邁爾斯布里格斯類型(MBTI人格)標註的數據集。

一共有4個維度,每個維度有兩個類型,所以常人的性格從MBTI指標來看,一共有16種性格。
讀取數據
mbti數據集中有兩個字段
- type: 性格類型
- posts: 每個用戶的最近的50條推文,推文與推文之間用
|||
間隔開
先查看前5行數據
import pandas as pd import warnings warnings.filterwarnings('ignore') df = pd.read_csv('data/mbti.csv') df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 8675 entries, 0 to 8674 Data columns (total 2 columns): type 8675 non-null object posts 8675 non-null object dtypes: object(2) memory usage: 135.7+ KB
mbti數據集一共有8675條數據
數據探索
這裡我計算出每個推文的長度(沒啥大用,複習apply和seaborn可視化)
df['words_per_comment'] = df['posts'].apply(lambda x: len(x.split()))/50 df['posts'] = df['posts'].apply(lambda x:x.lower()) df.head()

小提琴圖show一下各個性格的wordspercomment信息
import seaborn as sns import matplotlib.pyplot as plt #畫布設置及尺寸 sns.set(style='white', font_scale=1.5) plt.figure(figsize=(15, 10)) #繪製小提琴圖 sns.violinplot(x='type', y='words_per_comment', data=df, color='lightgray') #繪製分類三點圖,疊加到小提琴圖圖層上方 sns.stripplot(x='type', y='words_per_comment', data=df, size=2, jitter=True) #標題及y軸名 plt.title('The Violin Plot of Words Per Comment', size=18) plt.ylabel('Words Per Comment') #顯示 plt.show()

分割數據
將數據集分為訓練集和測試集
from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(df['posts'], df['type'], test_size=0.2, random_state=123)
文本向量化
機器不理解文本,需要先編碼為數字,這裡使用tfidf方法進行編碼。不熟悉的可以看看這個介紹
from sklearn.feature_extraction.text import TfidfVectorizer tfidf = TfidfVectorizer(stop_words='english') X_train = tfidf.fit_transform(X_train) X_test = tfidf.transform(X_test)
訓練模型及模型得分
這裡我選來三種模型,使用score得分評價模型表現
from sklearn.linear_model import LogisticRegression model1 = LogisticRegression() model1.fit(X_train, y_train) model1.score(X_test, y_test)
0.6357348703170029
from sklearn.linear_model import SGDClassifier model2 = SGDClassifier() model2.fit(X_train, y_train) model2.score(X_test, y_test)
0.6824207492795389
from sklearn.linear_model import Perceptron model3 = Perceptron() model3.fit(X_train, y_train) model3.score(X_test, y_test)
0.5994236311239193
找到的這個數據集標註的可能有問題,如果是經典的數據集,一般跑出來都能達到80+%的準確率。