import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
# Load scikit-learn's handwritten-digits dataset: X holds one flattened
# image per row and y holds the matching digit label.
digits = datasets.load_digits()
X, y = digits.data, digits.target
# Corrupt every value with independent Gaussian noise (mean 0, std 4).
# No seed is set, so the noise differs from run to run.
noisy_digits = X + np.random.normal(loc=0, scale=4, size=X.shape)
# Build a gallery of the first 10 noisy samples of each digit label 0-9,
# stacked in label order (rows 0-9 are label 0, rows 10-19 label 1, ...).
example_digits = np.vstack(
    [noisy_digits[y == num, :][:10] for num in range(10)]
)
# Sanity check: 10 labels x 10 samples each, 64 values per sample.
assert example_digits.shape == (100, 64)
def plot_digits(data, nrows=10, ncols=10):
    """Display flattened 8x8 images on an ``nrows`` x ``ncols`` grid.

    Each row of ``data`` is reshaped to 8x8 and drawn with the ``binary``
    colormap and fixed colour limits (0, 16), so every panel shares the
    same intensity scale. Axis ticks are removed. The first
    ``nrows * ncols`` rows are drawn in row-major order. If ``data`` has
    fewer rows, the remaining panels are hidden instead of raising
    ``IndexError``.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array with 64 values per row (e.g. shape ``(100, 64)``).
    nrows, ncols : int, optional
        Grid dimensions. The default 10 x 10 grid shows 100 images.
    """
    # squeeze=False keeps ``axes`` a 2-D array even for a 1 x N or 1 x 1
    # grid, so ``axes.flat`` is always available.
    _, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols, nrows),  # one inch per panel
        squeeze=False,
        subplot_kw={"xticks": [], "yticks": []},
        gridspec_kw={"hspace": 0.1, "wspace": 0.1},
    )
    n_images = len(data)
    for i, ax in enumerate(axes.flat):
        if i >= n_images:
            # More grid cells than images: blank out the unused panel.
            ax.set_axis_off()
            continue
        ax.imshow(
            data[i].reshape(8, 8),
            cmap="binary",
            interpolation="nearest",
            clim=(0, 16),
        )
    plt.show()