tSNE降维 样例代码

tSNE降维 样例代码

import numpy as np

from sklearn.manifold import TSNE
# For the UCI ML handwritten digits dataset
from sklearn.datasets import load_digits

# Import matplotlib for plotting graphs and seaborn for attractive graphics.
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import seaborn as sns

def plot(x, colors):
    """Scatter-plot a 2-D embedding coloured by digit class and save it as a PNG.

    Every point is drawn in the pastel colour of its class. Each class that
    has at least one point gets its digit written at the median position of
    its points, outlined in white so it stays readable on top of the markers.

    Args:
        x: Array with one row per sample; ``x[:, 0]`` and ``x[:, 1]`` are the
            two plot coordinates.
        colors: 1-D array of class labels in ``0..9``, aligned with the rows
            of ``x``.

    Returns:
        Tuple ``(figure, axes, texts)`` where ``texts`` is the list of label
        artists that were drawn (one per class that has samples).

    Side effects:
        Writes the figure to ``./digits_tsne-pastel.png`` at 120 dpi.
    """
    # Ten-colour palette (one colour per digit), kept as an array so it can be
    # indexed with the whole label vector at once. Other palettes such as
    # "husl" work as well, see
    # https://seaborn.pydata.org/generated/seaborn.color_palette.html
    palette = np.array(sns.color_palette("pastel", 10))

    # Create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    ax.scatter(x[:, 0], x[:, 1], lw=0, s=40, c=palette[colors.astype(np.int8)])

    # Add the labels for each digit.
    txts = []
    for i in range(10):
        members = x[colors == i, :]
        if members.size == 0:
            # The median of an empty selection is NaN (and emits a
            # RuntimeWarning); a class without samples simply gets no label.
            continue
        # Position of each label: the median of the class's points.
        xtext, ytext = np.median(members, axis=0)
        txt = ax.text(xtext, ytext, str(i), fontsize=24)
        txt.set_path_effects([pe.Stroke(linewidth=5, foreground="w"), pe.Normal()])
        txts.append(txt)
    plt.savefig('./digits_tsne-pastel.png', dpi=120)
    return f, ax, txts


digits = load_digits()
print(digits.data.shape)
# The dataset has 10 classes (digits 0 to 9) with almost 180 images per class.
# Every image is 8x8 pixels, hence 64 features (dimensions) per sample.

# Group the samples by class. A stable sort on the target keeps the original
# order inside each class, so the rows come out as all 0s, then all 1s, and so
# on up to the 9s.
_class_order = np.argsort(digits.target, kind="stable")
# X holds the pixel data, Y the matching digit labels, both in grouped order.
X = digits.data[_class_order]
Y = digits.target[_class_order]

# Project the 64-dimensional samples down to 2-D with t-SNE.
# Vary parameters such as perplexity or random_state to get different plots.
_tsne_model = TSNE(perplexity=30)
digits_final = _tsne_model.fit_transform(X)

plot(digits_final, Y)



def plot2(data, x='x', y='y'):
    """Draw a seaborn scatter of the embedding, coloured by label, and save it.

    Args:
        data: Table passed to ``sns.lmplot``; it must contain the columns
            named by ``x`` and ``y`` plus a ``'Label'`` column used as hue.
        x: Name of the column plotted on the horizontal axis; also used as
            the x-axis label.
        y: Name of the column plotted on the vertical axis; also used as
            the y-axis label.

    Side effects:
        Changes the global seaborn context and style, and writes the figure
        to ``./digits_tsne-plot2.png`` at 120 dpi.
    """
    sns.set_context("notebook", font_scale=1.1)
    sns.set_style("ticks")

    # Large, translucent markers so overlapping points remain visible.
    marker_style = {"s": 200, "alpha": 0.3}
    # fit_reg=False turns lmplot into a plain scatter without a regression line.
    sns.lmplot(
        data=data,
        x=x,
        y=y,
        hue='Label',
        fit_reg=False,
        legend=True,
        height=9,
        scatter_kws=marker_style,
    )

    heading = plt.title('t-SNE Results: Digits', weight='bold')
    heading.set_fontsize('14')
    # Both axis labels share the same weight and font size.
    for set_axis_label, label_text in ((plt.xlabel, x), (plt.ylabel, y)):
        set_axis_label(label_text, weight='bold').set_fontsize('10')
    plt.savefig('./digits_tsne-plot2.png', dpi=120)

import pandas as pd
# Put the two t-SNE coordinates and the digit labels into one table with the
# column names that plot2 expects ('x', 'y' and 'Label').
data = pd.DataFrame(
    {
        'x': digits_final[:, 0],
        'y': digits_final[:, 1],
        'Label': Y,
    }
)
plot2(data)


结果如下所示:

在这里插入图片描述
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值