Foundations

Data Visualization for Analysis

Lesson 3 of 4 · Estimated time: 40 min

Data visualization is your first line of defense against bad models. A single well-crafted plot can reveal patterns that would take hours to find in a spreadsheet. More importantly, it can immediately show you when something is wrong—a data quality issue, a data leak, or a faulty assumption about your domain.
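For example, a histogram will often expose placeholder values that summary statistics quietly absorb. A minimal sketch on synthetic data (the -999 sentinel is a hypothetical missing-data convention, not from any real dataset):

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic sensor readings with a few -999 "missing" sentinels mixed in
rng = np.random.default_rng(0)
readings = np.concatenate([rng.normal(20, 3, 995), np.full(5, -999.0)])

# The mean alone looks plausible; the histogram makes the sentinel obvious
print(f"Mean: {readings.mean():.2f}")  # dragged well below 20 by the sentinels
plt.hist(readings, bins=60)
plt.title('Sensor readings: note the isolated bar near -999')
plt.show()
```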

Matplotlib: The Foundation

Matplotlib is the foundational visualization library in Python. While it can be verbose, understanding it gives you complete control:

import matplotlib.pyplot as plt
import numpy as np

# Create a simple line plot
x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.figure(figsize=(10, 6))
plt.plot(x, y, linewidth=2, label='sin(x)')
plt.xlabel('X values', fontsize=12)
plt.ylabel('Y values', fontsize=12)
plt.title('Simple Sine Wave', fontsize=14)
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()
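The pyplot calls above operate on an implicit "current figure". For anything beyond a quick sketch, the object-oriented interface, where you hold explicit Figure and Axes objects, is easier to reason about and is what the subplot examples below use. The same plot in that style:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

# Explicit Figure and Axes objects instead of implicit pyplot state
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y, linewidth=2, label='sin(x)')
ax.set_xlabel('X values', fontsize=12)
ax.set_ylabel('Y values', fontsize=12)
ax.set_title('Simple Sine Wave', fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()
```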

Multiple plots in one figure (subplots) are essential for comparison:

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot in different positions
axes[0, 0].plot(x, np.sin(x), label='sin(x)')
axes[0, 0].set_title('Sine')
axes[0, 0].legend()

axes[0, 1].plot(x, np.cos(x), label='cos(x)', color='orange')
axes[0, 1].set_title('Cosine')
axes[0, 1].legend()

axes[1, 0].plot(x, x**2, label='x²', color='green')
axes[1, 0].set_title('Quadratic')
axes[1, 0].legend()

axes[1, 1].scatter(np.random.randn(100), np.random.randn(100), alpha=0.6)
axes[1, 1].set_title('Random Scatter')

plt.tight_layout()
plt.show()

Distributions: Understanding Your Data

When you first load data, always visualize distributions:

import matplotlib.pyplot as plt
import numpy as np

# Generate sample data
data = np.random.normal(loc=100, scale=15, size=10000)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Histogram
axes[0].hist(data, bins=50, alpha=0.7, color='blue', edgecolor='black')
axes[0].set_title('Histogram')
axes[0].set_xlabel('Value')
axes[0].set_ylabel('Frequency')

# Density plot (smoothed histogram with KDE overlay)
from scipy import stats
axes[1].hist(data, bins=50, density=True, alpha=0.5, color='blue')
kde = stats.gaussian_kde(data)
density_x = np.linspace(data.min(), data.max(), 100)
axes[1].plot(density_x, kde(density_x), color='red')
axes[1].set_title('Histogram with KDE')
axes[1].set_ylabel('Density')

# Box plot (shows quartiles and outliers)
axes[2].boxplot(data, vert=True)
axes[2].set_title('Box Plot')
axes[2].set_ylabel('Value')

plt.tight_layout()
plt.show()

# Statistics revealed by the plot
print(f"Mean: {data.mean():.2f}")
print(f"Median: {np.median(data):.2f}")
print(f"Std: {data.std():.2f}")
print(f"Q1: {np.percentile(data, 25):.2f}")
print(f"Q3: {np.percentile(data, 75):.2f}")
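These summaries are most useful when compared against each other: on right-skewed data the mean sits well above the median, a gap the histogram makes visible as a long right tail. A small sketch on synthetic skewed data:

```python
import numpy as np

# Exponential data is right-skewed: a long tail of large values
rng = np.random.default_rng(42)
skewed = rng.exponential(scale=50, size=10000)

mean, median = skewed.mean(), np.median(skewed)
print(f"Mean: {mean:.2f}, Median: {median:.2f}")
# Mean > median signals right skew; consider a log transform before modeling
print(f"Mean/median ratio: {mean / median:.2f}")
```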

Seaborn: Statistical Visualization Made Easy

Seaborn builds on Matplotlib with beautiful defaults and statistical capabilities:

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

# Set style for better-looking plots
sns.set_theme(style="whitegrid")

# Load example dataset
df = sns.load_dataset('tips')  # Sample dataset
print(df.head())

# Distribution plots
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Histogram with KDE overlay
sns.histplot(df['total_bill'], kde=True, ax=axes[0])
axes[0].set_title('Distribution of Total Bill')

# Violin plot (combines distribution with quartiles)
sns.violinplot(data=df, x='day', y='total_bill', ax=axes[1])
axes[1].set_title('Total Bill by Day')

plt.tight_layout()
plt.show()

Relationships: Correlation and Dependence

Understanding relationships between variables is crucial for feature engineering:

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

df = sns.load_dataset('tips')

# Scatter plot with regression line
plt.figure(figsize=(10, 6))
sns.regplot(data=df, x='total_bill', y='tip', scatter_kws={'alpha': 0.5})
plt.title('Relationship: Total Bill vs Tip')
plt.show()

# Correlation matrix heatmap
fig, ax = plt.subplots(figsize=(8, 6))
numeric_df = df.select_dtypes(include=['float64', 'int64'])
corr = numeric_df.corr()

sns.heatmap(corr, annot=True, fmt='.2f', cmap='coolwarm', center=0, ax=ax)
plt.title('Feature Correlation Matrix')
plt.show()

# Joint plot (scatter + marginal distributions)
sns.jointplot(data=df, x='total_bill', y='tip', kind='hex')
plt.show()

For categorical relationships:

# Box plot: numeric vs categorical
sns.boxplot(data=df, x='day', y='total_bill', hue='sex')
plt.title('Total Bill by Day and Gender')
plt.show()

# Count plot: categorical frequencies
sns.countplot(data=df, x='day', hue='sex')
plt.title('Counts of Observations')
plt.show()

# Cross-tabulation
crosstab = pd.crosstab(df['day'], df['sex'])
sns.heatmap(crosstab, annot=True, fmt='d', cmap='YlOrRd')
plt.title('Cross-tabulation: Day vs Sex')
plt.show()

Time Series Visualization

Time series data requires special visualization approaches:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Create synthetic time series
dates = pd.date_range('2023-01-01', periods=365, freq='D')
values = np.cumsum(np.random.randn(365)) + 100

df = pd.DataFrame({'date': dates, 'value': values})

# Line plot with date axis
plt.figure(figsize=(14, 5))
plt.plot(df['date'], df['value'], linewidth=1.5)
plt.xlabel('Date')
plt.ylabel('Value')
plt.title('Time Series Over Year')
plt.xticks(rotation=45)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Rolling mean (smoothing)
df['rolling_mean'] = df['value'].rolling(window=30).mean()

plt.figure(figsize=(14, 5))
plt.plot(df['date'], df['value'], alpha=0.5, label='Original')
plt.plot(df['date'], df['rolling_mean'], linewidth=2, label='30-day MA')
plt.xlabel('Date')
plt.ylabel('Value')
plt.legend()
plt.title('Time Series with Rolling Average')
plt.tight_layout()
plt.show()

# Seasonal decomposition
from statsmodels.tsa.seasonal import seasonal_decompose

decomposition = seasonal_decompose(df['value'], model='additive', period=30)

fig, axes = plt.subplots(4, 1, figsize=(14, 10))
decomposition.observed.plot(ax=axes[0], title='Observed')
decomposition.trend.plot(ax=axes[1], title='Trend')
decomposition.seasonal.plot(ax=axes[2], title='Seasonal')
decomposition.resid.plot(ax=axes[3], title='Residual')
plt.tight_layout()
plt.show()
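Another standard diagnostic is the autocorrelation plot, which shows how strongly the series correlates with lagged copies of itself: slow decay (as in the random walk above) indicates trend, while regular peaks indicate seasonality. pandas ships a one-line version; a sketch on the same kind of series:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Random-walk series like the one above
values = pd.Series(np.cumsum(np.random.default_rng(0).standard_normal(365)) + 100)

ax = pd.plotting.autocorrelation_plot(values)
ax.set_title('Autocorrelation: slow decay indicates a trending series')
plt.show()

# Lag-1 autocorrelation is near 1 for a random walk
print(f"Lag-1 autocorrelation: {values.autocorr(lag=1):.3f}")
```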

Effective Visualization Principles

Good visualizations have principles worth following:

import numpy as np
import matplotlib.pyplot as plt

# Sample data for the before/after comparison
data_x = np.random.randn(200)
data_y = 0.5 * data_x + np.random.randn(200) * 0.5

# Bad visualization: too much information, poor defaults
fig, ax = plt.subplots()
ax.scatter(data_x, data_y, s=10, alpha=0.3, c=range(len(data_x)))
ax.set_title('Data')
ax.grid(True, which='both')
ax.set_xlabel('x')
ax.set_ylabel('y')

# Better visualization: clear purpose, good design
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(data_x, data_y, s=50, alpha=0.6, color='steelblue', edgecolors='navy', linewidth=0.5)
ax.set_title('Distribution of Feature X vs Target Y', fontsize=14, fontweight='bold')
ax.set_xlabel('Feature X Value', fontsize=12)
ax.set_ylabel('Target Y Value', fontsize=12)
ax.grid(True, alpha=0.3, linestyle='--')

# Add context: mean lines
ax.axhline(data_y.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean Y: {data_y.mean():.2f}')
ax.axvline(data_x.mean(), color='orange', linestyle='--', linewidth=2, label=f'Mean X: {data_x.mean():.2f}')
ax.legend()

plt.tight_layout()
plt.show()

Key principles:

  1. Clear purpose: Every plot should answer a specific question
  2. Appropriate type: Use scatter for relationships, bars for categories, histograms for distributions
  3. Readable labels: Always label axes and provide a descriptive title
  4. Color wisely: Use color to highlight important information, not for decoration
  5. Minimize clutter: Remove gridlines, legends, and annotations that don’t add value
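Principle 5 in practice: a common first step is to hide the spines and gridlines that carry no information. A minimal sketch:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(10)
fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(x, x ** 2, color='steelblue')

# Remove the top and right spines; keep only a light horizontal grid
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.grid(True, axis='y', alpha=0.3)
ax.set_xlabel('Category')
ax.set_ylabel('Value')
plt.show()
```

(Seaborn users can get the same spine removal with `sns.despine()`.)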

Practical Visualization for Model Debugging

When building ML models, specific plots reveal problems:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

# Create sample regression problem
X = np.random.randn(200, 1)
y = 2 * X.squeeze() + np.random.randn(200) * 0.5

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

model = LinearRegression()
model.fit(X_train, y_train)

y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)

# Prediction vs actual
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].scatter(y_train, y_pred_train, alpha=0.5, label='Train')
axes[0].scatter(y_test, y_pred_test, alpha=0.5, label='Test')
axes[0].plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=2, label='Perfect')
axes[0].set_xlabel('Actual')
axes[0].set_ylabel('Predicted')
axes[0].legend()
axes[0].set_title('Predictions vs Actual')

# Residuals (errors)
residuals_train = y_train - y_pred_train
residuals_test = y_test - y_pred_test

axes[1].scatter(y_pred_train, residuals_train, alpha=0.5, label='Train')
axes[1].scatter(y_pred_test, residuals_test, alpha=0.5, label='Test')
axes[1].axhline(0, color='k', linestyle='--', lw=2)
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('Residuals')
axes[1].legend()
axes[1].set_title('Residual Plot (should be random)')

plt.tight_layout()
plt.show()

# Check for patterns in residuals (bias, heteroscedasticity)
print(f"Mean residual (train): {residuals_train.mean():.4f}")
print(f"Mean residual (test): {residuals_test.mean():.4f}")
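Beyond the mean, the shape of the residual distribution matters: for a well-specified linear model the residuals should look roughly normal. A histogram, optionally backed by a SciPy normality test, makes this quick to check. A self-contained sketch using the same kind of setup as above:

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.linear_model import LinearRegression

# Rebuilt here so the snippet stands alone
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 1))
y = 2 * X.squeeze() + rng.standard_normal(200) * 0.5

model = LinearRegression().fit(X, y)
residuals = y - model.predict(X)

plt.hist(residuals, bins=30, alpha=0.7, edgecolor='black')
plt.title('Residual Distribution (should look roughly normal)')
plt.show()

# Shapiro-Wilk test: a large p-value is consistent with normality
stat, p = stats.shapiro(residuals)
print(f"Shapiro-Wilk p-value: {p:.3f}")
```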

Key Takeaway

Great visualizations aren't about making pretty pictures; they're about seeing your data clearly. Before you build your first model, spend time visualizing your data from every angle. This single habit will catch more bugs, suggest better features, and build more intuition for the problem than almost any other activity.

Practical Exercise

You have a dataset with customer transactions. Create a comprehensive visualization dashboard that:

  1. Shows the distribution of transaction amounts (histogram + box plot)
  2. Displays transaction volume over time (line plot with rolling average)
  3. Reveals patterns by customer segment (violin plot or box plot)
  4. Shows correlation between spending and customer age (scatter with regression line)
  5. Displays top product categories by revenue (bar chart)

Create a function that generates this dashboard as a 2x3 grid (the five required panels plus a sixth of your choice):

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_transaction_dashboard(df):
    """
    Create a 2x3 dashboard visualizing transaction data.

    df should contain columns:
    - transaction_amount: Amount of transaction
    - transaction_date: Date of transaction
    - customer_segment: Customer category
    - customer_age: Customer age
    - product_category: Product category
    """
    fig, axes = plt.subplots(2, 3, figsize=(16, 10))

    # Your implementation here
    # 1. axes[0, 0]: Distribution of transaction amounts
    # 2. axes[0, 1]: Transaction volume over time
    # 3. axes[0, 2]: Spending by customer segment
    # 4. axes[1, 0]: Age vs spending relationship
    # 5. axes[1, 1]: Top product categories
    # 6. axes[1, 2]: Your choice of insight

    plt.tight_layout()
    return fig

# Test with sample data
dates = pd.date_range('2023-01-01', periods=1000, freq='D')
df = pd.DataFrame({
    'transaction_date': np.random.choice(dates, 500),
    'transaction_amount': np.random.exponential(50, 500) + 10,
    'customer_segment': np.random.choice(['Premium', 'Standard', 'Budget'], 500),
    'customer_age': np.random.randint(18, 80, 500),
    'product_category': np.random.choice(['Electronics', 'Clothing', 'Home', 'Sports'], 500)
})

fig = visualize_transaction_dashboard(df)
plt.show()

Focus on clear, informative visualizations that would actually help a business stakeholder understand the data.