# Lecture 20: Data Visualization Fundamentals with Matplotlib

## Learning Objectives

By the end of this lecture, you will be able to:
- Understand visualization principles and choose appropriate chart types
- Create basic plots with Matplotlib using the pyplot interface
- Build line plots to show trends over time
- Create bar charts to compare quantities across categories
- Generate scatter plots to reveal relationships between variables
- Produce histograms to understand distributions
- Customize plots with titles, labels, colors, and legends
- Use Seaborn for statistical visualizations
- Save plots to files for reports and presentations

**Prerequisites**: Lectures 12-13 (NumPy), Lectures 16-19 (pandas)

## Setup and Imports

Data visualization requires several libraries working together. Matplotlib is Python's foundational plotting library that provides comprehensive control over every element of a visualization. Seaborn builds on matplotlib to provide a high-level interface for statistical graphics with attractive default styling. We also need pandas for data manipulation and NumPy for numerical operations.

In [None]:
# Import visualization libraries
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Import data manipulation libraries
import pandas as pd
import numpy as np

# Set display options
pd.set_option('display.precision', 2)

print("Libraries imported successfully!")
print(f"Matplotlib version: {matplotlib.__version__}")

## Part 1: Introduction to Matplotlib and Data Visualization

### Why Visualize Data?

Humans are visual creatures: we can usually spot a pattern in a picture far faster than in the numbers behind it. A table of 1,000 numbers is hard to take in at a glance, but a chart of those same numbers can reveal trends, clusters, and outliers almost immediately. Visualization is not decoration added after analysis; it is an integral part of the analysis itself.

Visualization serves multiple purposes: exploration (understanding structure and spotting outliers), analysis (revealing relationships that statistics might miss), and communication (telling a story that numbers alone cannot tell).

### Choosing the Right Chart Type

Different chart types answer different questions:
- **Line plots**: Show trends over time or continuous variables
- **Bar charts**: Compare quantities across categories
- **Scatter plots**: Reveal relationships between two numeric variables
- **Histograms**: Show the distribution of a single numeric variable

### The Basic Plot

Creating a plot in matplotlib is straightforward. The plt.plot() function takes x and y values and draws a line connecting the points. After defining your plot, plt.show() displays it. This is the simplest way to create a visualization in Python.

In [None]:
# Create simple data
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]

# Create a basic line plot
plt.plot(x, y)
plt.show()

### Adding Titles and Labels

Every plot needs context to be meaningful. A title tells the viewer what the plot shows. Axis labels explain what each dimension represents. Without these elements, a plot is essentially meaningless to anyone who wasn't watching you create it. Always include title and axis labels in your visualizations.

In [None]:
# Create plot with proper labels
plt.figure(figsize=(8, 6))

plt.plot(x, y)

plt.title('Simple Linear Relationship')
plt.xlabel('X Values')
plt.ylabel('Y Values')
plt.grid(True)

plt.show()

### Figure Size and Resolution

The plt.figure() function creates a new figure with specified dimensions. The figsize parameter takes a tuple (width, height) in inches. Larger figures suit detailed visualizations and presentation slides. The dpi parameter (dots per inch) controls resolution: the rendered image measures figsize times dpi pixels along each side.

In [None]:
# Create a larger, higher-resolution figure
plt.figure(figsize=(10, 6), dpi=120)

plt.plot(x, y, linewidth=2)

plt.title('Linear Relationship', fontsize=14)
plt.xlabel('X Values', fontsize=12)
plt.ylabel('Y Values', fontsize=12)
plt.grid(True, alpha=0.3)

plt.show()
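
To make the link between figsize and dpi concrete, here is a minimal sketch that computes a figure's pixel dimensions. The values 10 x 6 inches and 150 dpi are illustrative choices, not recommendations.

```python
import matplotlib.pyplot as plt

# figsize is in inches; dpi is dots (pixels) per inch
fig = plt.figure(figsize=(10, 6), dpi=150)

width_in, height_in = fig.get_size_inches()
width_px = int(round(width_in * fig.dpi))
height_px = int(round(height_in * fig.dpi))

print(f"{width_in:.0f} x {height_in:.0f} inches at {fig.dpi:.0f} dpi")
print(f"= {width_px} x {height_px} pixels")

plt.close(fig)  # nothing was drawn, so just release the figure
```

Doubling dpi doubles both pixel dimensions without changing the layout, while changing figsize changes how much room text and markers have relative to the plot area.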

### Exercise 1: Creating a Complete Plot

Create a line plot showing daily website traffic for a week:
1. Days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
2. Visitors = [1200, 1350, 1100, 1400, 1600, 2100, 1900]
3. Set figure size to (10, 6)
4. Use markers ('o') and line width of 2
5. Add title 'Daily Website Traffic'
6. Add axis labels for 'Day' and 'Number of Visitors'
7. Add a grid with alpha=0.3

In [None]:
# Your code here


In [None]:
# Solution
days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
visitors = [1200, 1350, 1100, 1400, 1600, 2100, 1900]

plt.figure(figsize=(10, 6))

plt.plot(days, visitors, marker='o', linewidth=2)

plt.title('Daily Website Traffic')
plt.xlabel('Day')
plt.ylabel('Number of Visitors')
plt.grid(True, alpha=0.3)

plt.show()

## Part 2: Line Plots for Trends

### Basic Line Plot with Real Data

Line plots are ideal for showing trends over time or changes along a continuous variable. They connect data points with lines, emphasizing the flow and direction of change. Use line plots when you have sequential data and want to show how values evolve.

In [None]:
# Monthly sales data
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
sales = [12000, 15000, 18000, 22000, 25000, 28000]

plt.figure(figsize=(10, 6))
plt.plot(months, sales, marker='o', linewidth=2)

plt.title('Monthly Sales Trend - 2024')
plt.xlabel('Month')
plt.ylabel('Sales ($)')

plt.show()

### Multiple Lines for Comparison

To compare multiple trends on the same axes, plot multiple lines. Each call to plt.plot() adds another line to the current figure. Use different colors and markers to distinguish between series, and always include a legend to explain what each line represents.

In [None]:
# Compare sales across regions
east_sales = [12000, 15000, 18000, 22000, 25000, 28000]
west_sales = [10000, 12000, 15000, 18000, 20000, 22000]
north_sales = [8000, 9000, 11000, 13000, 15000, 17000]

plt.figure(figsize=(10, 6))

plt.plot(months, east_sales, marker='o', label='East')
plt.plot(months, west_sales, marker='s', label='West')
plt.plot(months, north_sales, marker='^', label='North')

plt.title('Regional Sales Comparison - 2024')
plt.xlabel('Month')
plt.ylabel('Sales ($)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.show()

### Customizing Line Appearance

Matplotlib provides extensive customization options for lines. You can control color, line style (solid, dashed, dotted), line width, marker style (circle, square, triangle), and marker size. This customization helps distinguish between multiple series and emphasize important data.

In [None]:
plt.figure(figsize=(10, 6))

# Customize each line
plt.plot(months, east_sales, color='blue', linestyle='-',
         linewidth=2, marker='o', markersize=8, label='East')
plt.plot(months, west_sales, color='red', linestyle='--',
         linewidth=2, marker='s', markersize=8, label='West')
plt.plot(months, north_sales, color='green', linestyle=':',
         linewidth=2, marker='^', markersize=8, label='North')

plt.title('Customized Regional Sales')
plt.xlabel('Month')
plt.ylabel('Sales ($)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.show()
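
Later solutions in this lecture use a compact format string such as 'b-o' instead of separate keywords. The sketch below, using made-up data, draws the same style both ways so you can see how the shorthand maps to the keyword arguments.

```python
import matplotlib.pyplot as plt

x = [1, 2, 3, 4]
y = [10, 20, 15, 25]

fig, ax = plt.subplots(figsize=(8, 5))

# Keyword form: every property is spelled out
(line_kw,) = ax.plot(x, y, color='red', linestyle='--', marker='s',
                     label='keywords')

# Format-string form: 'r' = red, '--' = dashed, 's' = square marker
(line_fmt,) = ax.plot(x, [v + 5 for v in y], 'r--s',
                      label="format string 'r--s'")

ax.set_title('Two Ways to Style a Line')
ax.legend()
plt.show()
```

The format string is handy for quick plots; keywords are easier to read once you also set widths, sizes, or named colors such as 'steelblue', which the shorthand cannot express.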

### Exercise 2: Line Plot Comparison

Create quarterly revenue data for two products:
1. Product A = [45000, 52000, 48000, 61000] for Q1-Q4
2. Product B = [38000, 41000, 55000, 58000] for Q1-Q4
3. Plot both on the same figure with different colors, line styles, and markers
4. Add a legend, title, and axis labels

In [None]:
# Your code here


In [None]:
# Solution
quarters = ['Q1', 'Q2', 'Q3', 'Q4']
product_a = [45000, 52000, 48000, 61000]
product_b = [38000, 41000, 55000, 58000]

plt.figure(figsize=(10, 6))

plt.plot(quarters, product_a, marker='o', linestyle='-', 
         color='blue', linewidth=2, label='Product A')
plt.plot(quarters, product_b, marker='s', linestyle='--', 
         color='red', linewidth=2, label='Product B')

plt.title('Quarterly Revenue Comparison')
plt.xlabel('Quarter')
plt.ylabel('Revenue ($)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.show()

## Part 3: Bar Charts for Comparisons

### Basic Bar Chart

Bar charts display categorical data with rectangular bars whose lengths represent values. They are excellent for comparing quantities across different categories. The visual comparison of bar heights makes it easy to see which categories have larger or smaller values.

In [None]:
# Average grades by department
departments = ['CS', 'Math', 'Engineering', 'Physics']
avg_grades = [85.5, 82.3, 88.1, 79.8]

plt.figure(figsize=(10, 6))
plt.bar(departments, avg_grades, color='steelblue')

plt.title('Average Grade by Department')
plt.xlabel('Department')
plt.ylabel('Average Grade')

plt.show()

### Adding Value Labels to Bars

While bar heights show relative comparisons, exact values are often needed. Adding value labels on top of each bar provides both the visual comparison and precise numbers. This is a common practice in business reports and presentations.

In [None]:
plt.figure(figsize=(10, 6))
bars = plt.bar(departments, avg_grades, color='steelblue')

# Add value labels on top of each bar
for i, v in enumerate(avg_grades):
    plt.text(i, v + 0.5, f'{v:.1f}', ha='center', fontsize=10)

plt.title('Average Grade by Department')
plt.xlabel('Department')
plt.ylabel('Average Grade')
plt.ylim(0, 95)

plt.show()
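
As an alternative to positioning each label by hand, recent Matplotlib releases (3.4 and later, which is an assumption about your installed version) provide bar_label, which reads the bar heights from the container returned by bar(). A minimal sketch using the same department data:

```python
import matplotlib.pyplot as plt

departments = ['CS', 'Math', 'Engineering', 'Physics']
avg_grades = [85.5, 82.3, 88.1, 79.8]

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(departments, avg_grades, color='steelblue')

# bar_label reads each bar's height and places the text above it;
# padding is the gap between bar and label, in points
labels = ax.bar_label(bars, fmt='%.1f', padding=3, fontsize=10)

ax.set_title('Average Grade by Department')
ax.set_xlabel('Department')
ax.set_ylabel('Average Grade')
ax.set_ylim(0, 95)

plt.show()
```

The manual plt.text() loop remains useful when you need full control over label position or content.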

### Horizontal Bar Charts

When category names are long, horizontal bars are more readable. Use plt.barh() instead of plt.bar(). The horizontal orientation gives more space for category labels and is easier to read when you have many categories.

In [None]:
# Products with longer names
products = ['Premium Widget Pro', 'Basic Gadget Standard',
            'Enterprise Tool Suite', 'Starter Kit Basic']
sales_values = [45000, 32000, 58000, 28000]

plt.figure(figsize=(10, 6))
plt.barh(products, sales_values, color='teal')

plt.title('Sales by Product')
plt.xlabel('Sales ($)')
plt.ylabel('Product')

plt.show()

### Grouped Bar Charts

To compare multiple values per category, use grouped (clustered) bars. This requires calculating bar positions manually using numpy arrays. The key is to offset each group of bars so they appear side by side.

In [None]:
# Q1 and Q2 sales by region
regions = ['East', 'West', 'North', 'South']
q1_sales = [45000, 38000, 32000, 41000]
q2_sales = [52000, 45000, 38000, 48000]

x = np.arange(len(regions))
width = 0.35

plt.figure(figsize=(10, 6))

plt.bar(x - width/2, q1_sales, width, label='Q1', color='steelblue')
plt.bar(x + width/2, q2_sales, width, label='Q2', color='coral')

plt.title('Quarterly Sales Comparison')
plt.xlabel('Region')
plt.ylabel('Sales ($)')
plt.xticks(x, regions)
plt.legend()

plt.show()
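
The offset arithmetic generalizes to any number of series. The sketch below is one way to compute it; the names series, group_width, and offsets are illustrative, and 0.7 is an arbitrary choice for how much of each category slot the bars fill.

```python
import numpy as np
import matplotlib.pyplot as plt

regions = ['East', 'West', 'North', 'South']
series = {
    'Q1': [45000, 38000, 32000, 41000],
    'Q2': [52000, 45000, 38000, 48000],
}

x = np.arange(len(regions))
n = len(series)
group_width = 0.7          # total width taken by one group of bars
width = group_width / n    # width of a single bar

# Offsets are centred on zero: for n=2 they are -width/2 and +width/2
offsets = (np.arange(n) - (n - 1) / 2) * width

fig, ax = plt.subplots(figsize=(10, 6))
for offset, (label, values) in zip(offsets, series.items()):
    ax.bar(x + offset, values, width, label=label)

ax.set_title('Quarterly Sales Comparison')
ax.set_xticks(x)
ax.set_xticklabels(regions)
ax.set_ylabel('Sales ($)')
ax.legend()
plt.show()
```

Adding a third quarter to the dictionary is then enough to get three bars per region with no other changes.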

### Stacked Bar Charts

While grouped bar charts place bars side by side, stacked bar charts place bars on top of each other. This is useful when you want to show both the total and the composition of that total. The key technique is using the 'bottom' parameter to specify where each bar starts.

In [None]:
# Create stacked bar chart
regions = ['East', 'West', 'North', 'South']
q1_sales = [45000, 38000, 32000, 41000]
q2_sales = [52000, 45000, 38000, 48000]

plt.figure(figsize=(10, 6))

# First bar (Q1) - starts at bottom (0)
plt.bar(regions, q1_sales, label='Q1', color='steelblue')

# Second bar (Q2) - stacked on top of Q1
plt.bar(regions, q2_sales, bottom=q1_sales, label='Q2', color='coral')

plt.title('Stacked Quarterly Sales by Region')
plt.xlabel('Region')
plt.ylabel('Total Sales ($)')
plt.legend()

plt.show()

The stacked bar chart clearly shows both the total sales for each region (bar height) and how that total breaks down between Q1 and Q2. Notice how the 'bottom' parameter for Q2 is set to the Q1 values, making Q2 bars start where Q1 bars end.
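
With three or more layers, each layer must start at the sum of all layers beneath it. A common pattern is to keep a running total, sketched below. The Q3 figures are hypothetical and exist only to provide a third layer.

```python
import numpy as np
import matplotlib.pyplot as plt

regions = ['East', 'West', 'North', 'South']
quarters = {
    'Q1': [45000, 38000, 32000, 41000],
    'Q2': [52000, 45000, 38000, 48000],
    'Q3': [49000, 47000, 36000, 44000],  # hypothetical values for illustration
}

fig, ax = plt.subplots(figsize=(10, 6))

bottom = np.zeros(len(regions))  # running total: where the next layer starts
for label, values in quarters.items():
    ax.bar(regions, values, bottom=bottom, label=label)
    bottom = bottom + np.array(values)

ax.set_title('Stacked Quarterly Sales by Region')
ax.set_xlabel('Region')
ax.set_ylabel('Total Sales ($)')
ax.legend()
plt.show()
```

After the loop, bottom holds the total bar height for each region, which is convenient if you want to label the totals.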

### Exercise 3: Grouped Bar Chart

Create a grouped bar chart comparing Morning, Afternoon, and Evening sales for Monday-Friday:
1. Morning = [120, 135, 128, 142, 118]
2. Afternoon = [95, 102, 98, 110, 105]
3. Evening = [80, 88, 92, 85, 95]
4. Add a legend with appropriate labels

In [None]:
# Your code here


In [None]:
# Solution
days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri']
morning = [120, 135, 128, 142, 118]
afternoon = [95, 102, 98, 110, 105]
evening = [80, 88, 92, 85, 95]

x = np.arange(len(days))
width = 0.25

plt.figure(figsize=(10, 6))

plt.bar(x - width, morning, width, label='Morning')
plt.bar(x, afternoon, width, label='Afternoon')
plt.bar(x + width, evening, width, label='Evening')

plt.title('Daily Sales by Time Period')
plt.xlabel('Day')
plt.ylabel('Sales')
plt.xticks(x, days)
plt.legend()

plt.show()

## Part 4: Scatter Plots for Relationships

### Basic Scatter Plot

Scatter plots show individual data points as dots, revealing relationships between two numeric variables. They are excellent for identifying correlations (do variables increase together?), clusters (are there groups?), and outliers (are there unusual points?). Each point represents one observation.

In [None]:
# Generate study hours vs test scores data
np.random.seed(42)
study_hours = np.random.uniform(1, 10, 50)
test_scores = 50 + 4 * study_hours + np.random.normal(0, 5, 50)

plt.figure(figsize=(10, 6))
plt.scatter(study_hours, test_scores, alpha=0.7)

plt.title('Study Hours vs Test Scores')
plt.xlabel('Study Hours')
plt.ylabel('Test Score')

plt.show()

### Customizing Scatter Plots

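Scatter plots can encode additional information through marker color and size. Color can represent a third variable through a colormap (the c and cmap arguments), and size can represent a fourth (the s argument). The example below colors points by test score, which is already on the y-axis, purely to demonstrate the mechanics. In real analyses, map color to a variable that is not already shown.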

In [None]:
plt.figure(figsize=(10, 6))

# Color by test score using a colormap
scatter = plt.scatter(study_hours, test_scores,
                      c=test_scores,
                      cmap='viridis',
                      alpha=0.7,
                      s=100)

plt.colorbar(scatter, label='Test Score')

plt.title('Study Hours vs Test Scores')
plt.xlabel('Study Hours')
plt.ylabel('Test Score')

plt.show()
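
To show color and size carrying genuinely new information, the sketch below adds two synthetic variables. Both sleep_hours and classes_attended are hypothetical columns invented for this illustration, and the multiplier 8 is an arbitrary scaling to get readable marker sizes.

```python
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)
study_hours = np.random.uniform(1, 10, 50)
test_scores = 50 + 4 * study_hours + np.random.normal(0, 5, 50)
sleep_hours = np.random.uniform(4, 9, 50)        # hypothetical third variable
classes_attended = np.random.randint(5, 31, 50)  # hypothetical fourth variable

fig, ax = plt.subplots(figsize=(10, 6))
points = ax.scatter(study_hours, test_scores,
                    c=sleep_hours, cmap='viridis',  # color = third variable
                    s=classes_attended * 8,         # marker area = fourth variable
                    alpha=0.7)
fig.colorbar(points, ax=ax, label='Sleep Hours')

ax.set_title('Study Hours vs Test Scores')
ax.set_xlabel('Study Hours')
ax.set_ylabel('Test Score')
plt.show()
```

Note that s sets the marker area in points squared, not the diameter, so doubling s does not double the visible width of a marker.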

### Adding a Trend Line

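A trend line (regression line) summarizes the relationship you see in the scatter: its slope gives the direction of the relationship and the average change in y per unit of x. The line alone does not tell you how strong the relationship is; the correlation coefficient does that. We can fit a straight line with NumPy's polyfit function (degree 1) and plot the result.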

In [None]:
# Fit a linear regression
coefficients = np.polyfit(study_hours, test_scores, 1)
poly = np.poly1d(coefficients)

# Create line points
x_line = np.linspace(study_hours.min(), study_hours.max(), 100)
y_line = poly(x_line)

plt.figure(figsize=(10, 6))

plt.scatter(study_hours, test_scores, alpha=0.6, label='Students')
plt.plot(x_line, y_line, color='red', linewidth=2,
         label=f'Trend: y = {coefficients[0]:.1f}x + {coefficients[1]:.1f}')

plt.title('Study Hours vs Test Scores')
plt.xlabel('Study Hours')
plt.ylabel('Test Score')
plt.legend()

plt.show()
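
To put a number on the strength of the relationship as well as its slope, you can compute the Pearson correlation coefficient alongside the fit. A minimal sketch, regenerating the same synthetic data so it runs on its own:

```python
import numpy as np

np.random.seed(42)
study_hours = np.random.uniform(1, 10, 50)
test_scores = 50 + 4 * study_hours + np.random.normal(0, 5, 50)

# polyfit returns coefficients from the highest degree down: slope, then intercept
slope, intercept = np.polyfit(study_hours, test_scores, 1)

# corrcoef returns a 2x2 matrix; the off-diagonal entry is r
r = np.corrcoef(study_hours, test_scores)[0, 1]

print(f"Slope: {slope:.2f} points per extra study hour")
print(f"Intercept: {intercept:.2f}")
print(f"Correlation r: {r:.2f} (r squared: {r**2:.2f})")
```

Because the data were generated with a true slope of 4 and moderate noise, the fitted slope should land near 4 and r should be strongly positive.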

### Exercise 4: Scatter Plot with Regression

Generate data for 75 houses:
1. Square footage: random values between 1000-4000
2. Price: sqft * 150 + random noise
3. Create a scatter plot with alpha=0.6, color by price, add colorbar
4. Fit and plot a regression line

In [None]:
# Your code here


In [None]:
# Solution
np.random.seed(42)
sqft = np.random.uniform(1000, 4000, 75)
price = sqft * 150 + np.random.normal(0, 30000, 75)

# Fit regression
coeffs = np.polyfit(sqft, price, 1)
poly = np.poly1d(coeffs)
x_line = np.linspace(sqft.min(), sqft.max(), 100)

plt.figure(figsize=(10, 6))

scatter = plt.scatter(sqft, price, c=price, cmap='viridis', 
                      alpha=0.6, s=80)
plt.colorbar(scatter, label='Price ($)')

plt.plot(x_line, poly(x_line), 'r-', linewidth=2, 
         label=f'Trend: ${coeffs[0]:.0f}/sqft')

plt.title('House Price vs Square Footage')
plt.xlabel('Square Footage')
plt.ylabel('Price ($)')
plt.legend()

plt.show()

print(f"Price per square foot: ${coeffs[0]:.2f}")

## Part 5: Histograms for Distributions

### Basic Histogram

Histograms show how data is distributed across a range of values. They divide data into bins (intervals) and count how many values fall into each bin. The shape of a histogram reveals important characteristics: is the data symmetric or skewed? Are there multiple peaks? Are there outliers?

In [None]:
# Generate test score data
np.random.seed(42)
scores = np.random.normal(75, 10, 200)

plt.figure(figsize=(10, 6))
plt.hist(scores, bins=20, edgecolor='black')

plt.title('Distribution of Test Scores')
plt.xlabel('Score')
plt.ylabel('Number of Students')

plt.show()
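
To see what binning actually does, the sketch below computes the bin edges and counts with np.histogram and prints them as a text histogram. With the same data and bins=20, plt.hist() should report the same counts and edges.

```python
import numpy as np

np.random.seed(42)
scores = np.random.normal(75, 10, 200)

# 20 equal-width bins spanning the data's minimum to maximum
counts, edges = np.histogram(scores, bins=20)

print(f"{len(counts)} bins, {len(edges)} edges, bin width {edges[1] - edges[0]:.2f}")
for left, right, count in zip(edges[:-1], edges[1:], counts):
    print(f"{left:6.1f} to {right:6.1f}: {int(count):3d} {'#' * int(count)}")
```

There is always one more edge than there are bins, and the counts sum to the number of observations.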

### Adding Statistical Markers

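Vertical lines showing the mean and median help interpret the distribution. When the two are close, the distribution is roughly symmetric. A mean noticeably above the median usually indicates right skew (a long tail of high values), and a mean below the median usually indicates left skew. These are rules of thumb rather than guarantees, but they give quick statistical insight into the data.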

In [None]:
plt.figure(figsize=(10, 6))

plt.hist(scores, bins=25, edgecolor='black', color='steelblue', alpha=0.7)

# Add mean and median lines
mean_score = np.mean(scores)
median_score = np.median(scores)

plt.axvline(mean_score, color='red', linestyle='--',
            linewidth=2, label=f'Mean: {mean_score:.1f}')
plt.axvline(median_score, color='green', linestyle=':',
            linewidth=2, label=f'Median: {median_score:.1f}')

plt.title('Distribution of Test Scores')
plt.xlabel('Score')
plt.ylabel('Number of Students')
plt.legend()

plt.show()
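
The mean-versus-median rule of thumb is easy to check numerically. The sketch below compares a symmetric sample with a right-skewed one; both are synthetic, and the distribution parameters are arbitrary.

```python
import numpy as np

np.random.seed(0)
symmetric = np.random.normal(50, 10, 5000)       # bell-shaped
right_skewed = np.random.exponential(10, 5000)   # long tail of large values

for name, data in [('normal', symmetric), ('exponential', right_skewed)]:
    mean, median = np.mean(data), np.median(data)
    print(f"{name:12s} mean={mean:6.2f}  median={median:6.2f}  "
          f"mean - median = {mean - median:+.2f}")
```

For the normal sample the gap should be close to zero, while for the exponential sample the few very large values pull the mean well above the median.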

### Comparing Distributions

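Use multiple overlapping histograms to compare distributions across groups. Setting alpha (transparency) to 0.5 lets you see the overlapping regions. This is useful for comparing different populations or conditions. Note that each plt.hist() call with bins=20 derives its bin edges from its own data, so bar widths can differ between groups.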

In [None]:
# Compare three departments
np.random.seed(42)
cs_scores = np.random.normal(78, 8, 100)
math_scores = np.random.normal(75, 12, 100)
eng_scores = np.random.normal(82, 6, 100)

plt.figure(figsize=(10, 6))

plt.hist(cs_scores, bins=20, alpha=0.5, label='CS')
plt.hist(math_scores, bins=20, alpha=0.5, label='Math')
plt.hist(eng_scores, bins=20, alpha=0.5, label='Engineering')

plt.title('Score Distributions by Department')
plt.xlabel('Score')
plt.ylabel('Number of Students')
plt.legend()

plt.show()
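
When bar heights need to be directly comparable across groups, one option is to compute a single set of bin edges from the combined data and pass it to every hist() call. A minimal sketch of that approach:

```python
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)
cs_scores = np.random.normal(78, 8, 100)
math_scores = np.random.normal(75, 12, 100)
eng_scores = np.random.normal(82, 6, 100)

# One shared set of edges covering all three groups: 21 edges = 20 bins
all_scores = np.concatenate([cs_scores, math_scores, eng_scores])
bins = np.linspace(all_scores.min(), all_scores.max(), 21)

fig, ax = plt.subplots(figsize=(10, 6))
groups = [(cs_scores, 'CS'), (math_scores, 'Math'), (eng_scores, 'Engineering')]
for data, label in groups:
    _, edges, _ = ax.hist(data, bins=bins, alpha=0.5, label=label)

ax.set_title('Score Distributions by Department (Shared Bins)')
ax.set_xlabel('Score')
ax.set_ylabel('Number of Students')
ax.legend()
plt.show()
```

With shared edges, every bar covers the same score interval in all three groups, so a taller bar always means more students in that interval.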

### Exercise 5: Comparing Distributions

Generate wait times for two restaurants:
1. Restaurant A: exponential distribution with mean=15 (150 samples)
2. Restaurant B: normal distribution with mean=20, std=3 (150 samples)
3. Create overlapping histograms with alpha=0.5
4. Add mean and median lines for each restaurant

In [None]:
# Your code here


In [None]:
# Solution
np.random.seed(42)
rest_a = np.random.exponential(15, 150)
rest_b = np.random.normal(20, 3, 150)

plt.figure(figsize=(10, 6))

plt.hist(rest_a, bins=25, alpha=0.5, label='Restaurant A')
plt.hist(rest_b, bins=25, alpha=0.5, label='Restaurant B')

# Add statistics for A ('C0' is the first default color, matching A's histogram)
plt.axvline(np.mean(rest_a), color='C0', linestyle='--',
            label=f'A Mean: {np.mean(rest_a):.1f}')
plt.axvline(np.median(rest_a), color='C0', linestyle=':',
            label=f'A Median: {np.median(rest_a):.1f}')

# Add statistics for B ('C1' is the second default color, matching B's histogram)
plt.axvline(np.mean(rest_b), color='C1', linestyle='--',
            label=f'B Mean: {np.mean(rest_b):.1f}')
plt.axvline(np.median(rest_b), color='C1', linestyle=':',
            label=f'B Median: {np.median(rest_b):.1f}')

plt.title('Restaurant Wait Times')
plt.xlabel('Wait Time (min)')
plt.ylabel('Frequency')
plt.legend(loc='upper right', fontsize=8)

plt.show()

## Part 6: Plot Customization and Subplots

### Multiple Plots in One Figure

The plt.subplots() function creates a grid of axes within one figure. This is useful for comparing multiple visualizations side by side or creating dashboard-style displays. You can arrange plots in rows and columns.

In [None]:
# Create 2x2 grid of subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Top-left: Line plot
axes[0, 0].plot(months, east_sales, marker='o')
axes[0, 0].set_title('Sales Trend')
axes[0, 0].set_xlabel('Month')
axes[0, 0].set_ylabel('Sales ($)')

# Top-right: Bar chart
axes[0, 1].bar(departments, avg_grades)
axes[0, 1].set_title('Grades by Department')
axes[0, 1].set_xlabel('Department')
axes[0, 1].set_ylabel('Average Grade')

# Bottom-left: Histogram
axes[1, 0].hist(scores, bins=20, edgecolor='black')
axes[1, 0].set_title('Score Distribution')
axes[1, 0].set_xlabel('Score')
axes[1, 0].set_ylabel('Count')

# Bottom-right: Scatter
axes[1, 1].scatter(study_hours, test_scores, alpha=0.6)
axes[1, 1].set_title('Study Hours vs Scores')
axes[1, 1].set_xlabel('Study Hours')
axes[1, 1].set_ylabel('Test Score')

plt.tight_layout()
plt.show()
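
When every panel gets the same kind of plot, indexing each axes by hand becomes repetitive. The sketch below loops over the grid with axes.flat instead. The four synthetic samples and their names are made up for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)
samples = {
    'Normal': np.random.normal(0, 1, 300),
    'Uniform': np.random.uniform(-2, 2, 300),
    'Exponential': np.random.exponential(1, 300),
    'Bimodal': np.concatenate([np.random.normal(-2, 0.5, 150),
                               np.random.normal(2, 0.5, 150)]),
}

fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# axes is a 2-D NumPy array of Axes objects; .flat visits it row by row
for ax, (name, data) in zip(axes.flat, samples.items()):
    ax.hist(data, bins=20, edgecolor='black')
    ax.set_title(name)
    ax.set_ylabel('Count')

fig.suptitle('Four Synthetic Distributions')
fig.tight_layout()
plt.show()
```

Note that methods on an Axes object use the set_ prefix (set_title, set_xlabel) where pyplot uses plt.title and plt.xlabel.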

### Exercise 6: Multiple Subplots

Create a 1x3 subplot showing monthly sales data three ways:
1. Line plot with markers
2. Bar chart
3. Scatter plot
4. Use consistent colors and add annotation for peak month in line plot

In [None]:
# Your code here


In [None]:
# Solution
months_num = [1, 2, 3, 4, 5, 6]
month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
sales_vals = [12000, 15000, 18000, 22000, 19000, 28000]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Line plot
axes[0].plot(month_names, sales_vals, color='steelblue', marker='o',
             linewidth=2, markersize=8)
max_idx = sales_vals.index(max(sales_vals))
axes[0].annotate('Peak!', xy=(max_idx, sales_vals[max_idx]),
                 xytext=(max_idx-1, sales_vals[max_idx]+2000),
                 arrowprops=dict(arrowstyle='->', color='red'))
axes[0].set_title('Trend')
axes[0].set_xlabel('Month')
axes[0].set_ylabel('Sales ($)')

# Bar chart
axes[1].bar(month_names, sales_vals, color='steelblue')
axes[1].set_title('Comparison')
axes[1].set_xlabel('Month')
axes[1].set_ylabel('Sales ($)')

# Scatter plot
axes[2].scatter(months_num, sales_vals, s=100, c='steelblue')
axes[2].set_title('Distribution')
axes[2].set_xlabel('Month')
axes[2].set_ylabel('Sales ($)')

plt.tight_layout()
plt.show()

## Part 7: Introduction to Seaborn

### Why Seaborn?

Seaborn provides a high-level interface for statistical visualizations with attractive default styling. It integrates beautifully with pandas DataFrames - you can specify column names directly instead of extracting arrays. Seaborn is built on matplotlib, so you can still customize plots using matplotlib commands.

In [None]:
# Set seaborn style
sns.set_style('whitegrid')

# Create sample DataFrame
np.random.seed(42)
df = pd.DataFrame({
    'Department': np.random.choice(['CS', 'Math', 'Engineering'], 100),
    'Experience': np.random.randint(1, 20, 100),
    'Salary': np.random.normal(70000, 15000, 100)
})

print("Sample data:")
print(df.head())

### Seaborn Scatter with Hue

Pass the DataFrame as data and refer to its columns by name. The hue parameter colors points by category and adds a matching legend, which makes it easy to see how a third, categorical variable relates to the other two.

In [None]:
plt.figure(figsize=(10, 6))

sns.scatterplot(data=df, x='Experience', y='Salary', hue='Department')

plt.title('Salary vs Experience by Department')
plt.show()

### Box Plots for Distributions

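Box plots show distribution statistics compactly: the median (line inside the box), the first and third quartiles (box edges), whiskers that by default reach the most extreme points within 1.5 times the interquartile range of the box, and outliers beyond the whiskers (individual points). They're excellent for comparing distributions across categories and identifying outliers.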

In [None]:
plt.figure(figsize=(10, 6))

sns.boxplot(data=df, x='Department', y='Salary')

plt.title('Salary Distribution by Department')
plt.show()
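
The numbers behind a box plot can be reproduced with NumPy. The sketch below applies the common 1.5 x IQR rule to a small, hypothetical set of salaries (in thousands of dollars) so each step is easy to follow by hand.

```python
import numpy as np

# Hypothetical salaries in thousands of dollars
salaries = np.array([52, 55, 58, 60, 61, 63, 65, 68, 70, 120])

q1, median, q3 = np.percentile(salaries, [25, 50, 75])
iqr = q3 - q1                      # interquartile range: height of the box

lower_fence = q1 - 1.5 * iqr       # points below this are drawn as outliers
upper_fence = q3 + 1.5 * iqr       # points above this are drawn as outliers

inside = salaries[(salaries >= lower_fence) & (salaries <= upper_fence)]
outliers = salaries[(salaries < lower_fence) | (salaries > upper_fence)]

print(f"Q1={q1}, median={median}, Q3={q3}, IQR={iqr}")
print(f"Fences: {lower_fence} to {upper_fence}")
print(f"Whiskers reach {inside.min()} and {inside.max()}")
print(f"Outliers: {outliers.tolist()}")
```

The whiskers stop at the most extreme data points inside the fences, not at the fences themselves, which is why whisker lengths are usually unequal.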

### Violin Plots

Violin plots combine box plots with kernel density estimation, showing the full distribution shape. The width of the violin at each point shows the density of data at that value. This provides more information than a box plot alone.

In [None]:
plt.figure(figsize=(10, 6))

sns.violinplot(data=df, x='Department', y='Salary')

plt.title('Salary Distribution by Department (Violin)')
plt.show()

### Regression Plot

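Seaborn's regplot combines a scatter plot and a regression line in one function. It fits a linear regression and draws the line together with a shaded 95% confidence band. This is more convenient than fitting the line yourself with NumPy, although regplot does not return the fitted coefficients, so use polyfit when you need the actual slope and intercept.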

In [None]:
plt.figure(figsize=(10, 6))

sns.regplot(data=df, x='Experience', y='Salary',
            scatter_kws={'alpha': 0.5})

plt.title('Salary vs Experience with Regression')
plt.show()

### Pair Plots for Multi-Variable Relationships

When you have multiple numeric variables, pair plots show a scatter plot for every pair of variables. This is invaluable for exploratory data analysis, helping you quickly identify which variables are related. The diagonal shows each variable's own distribution (a density curve per group when hue is set, a histogram otherwise).

In [None]:
# Create a DataFrame with multiple numeric columns
np.random.seed(42)
multi_df = pd.DataFrame({
    'Experience': np.random.randint(1, 20, 50),
    'Salary': np.random.normal(70000, 15000, 50),
    'Performance': np.random.uniform(60, 100, 50),
    'Department': np.random.choice(['CS', 'Math', 'Eng'], 50)
})

# Create pair plot colored by department
sns.pairplot(multi_df, hue='Department')
plt.suptitle('Pair Plot of Employee Metrics', y=1.02)
plt.show()

### Heatmaps for Correlation Matrices

A heatmap visualizes a matrix of numbers using colors. This is particularly useful for showing correlation matrices, where you can quickly see which variables are positively or negatively correlated. The color intensity represents the strength of the relationship.

In [None]:
# Calculate correlation matrix
numeric_cols = multi_df.select_dtypes(include=[np.number])
correlation = numeric_cols.corr()

print("Correlation Matrix:")
print(correlation.round(2))

# Create heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(correlation, 
            annot=True,          # Show numbers in cells
            cmap='coolwarm',     # Color scheme (blue to red)
            center=0,            # Center color at 0
            fmt='.2f',           # Number format
            square=True)         # Square cells

plt.title('Correlation Heatmap of Employee Metrics')
plt.show()

The heatmap makes patterns immediately visible:
- Red cells indicate positive correlation (variables increase together)
- Blue cells indicate negative correlation (one increases as other decreases)
- White/pale cells indicate little or no correlation
- The diagonal is always 1.0 (each variable correlates perfectly with itself)

### Exercise 7: Seaborn Statistical Plots

Create a DataFrame with 80 employees:
1. Department (Sales/IT/HR)
2. Years (years of experience, 1-15)
3. Salary (50000 + Years*3000 + noise)
4. Create seaborn scatter, box, and violin plots

In [None]:
# Your code here


In [None]:
# Solution
np.random.seed(42)
emp_df = pd.DataFrame({
    'Department': np.random.choice(['Sales', 'IT', 'HR'], 80),
    'Years': np.random.randint(1, 16, 80)
})
emp_df['Salary'] = 50000 + emp_df['Years'] * 3000 + np.random.normal(0, 5000, 80)

# Scatter plot
plt.figure(figsize=(10, 6))
sns.scatterplot(data=emp_df, x='Years', y='Salary', hue='Department')
plt.title('Salary vs Experience by Department')
plt.show()

# Box plot
plt.figure(figsize=(10, 6))
sns.boxplot(data=emp_df, x='Department', y='Salary')
plt.title('Salary Distribution by Department')
plt.show()

# Violin plot
plt.figure(figsize=(10, 6))
sns.violinplot(data=emp_df, x='Department', y='Years')
plt.title('Experience Distribution by Department')
plt.show()

## Part 8: Saving Plots and Complete Workflow

### Saving to Files

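To use your plots in reports, presentations, or websites, you need to save them as image files. The plt.savefig() function saves the current figure. Call savefig() before show(): once show() returns, the figure has typically been closed (this is what happens in notebooks and after you close an interactive window), so a later savefig() would write a blank image. The dpi argument sets the output resolution, and bbox_inches='tight' trims excess whitespace around the plot.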

In [None]:
# Create a plot to save
plt.figure(figsize=(10, 6))
plt.plot(months, east_sales, marker='o', linewidth=2)
plt.title('Monthly Sales Trend')
plt.xlabel('Month')
plt.ylabel('Sales ($)')
plt.grid(True, alpha=0.3)

# Save before showing
plt.savefig('sales_trend.png', dpi=150, bbox_inches='tight')
print("Plot saved to sales_trend.png")

plt.show()
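
If you keep a reference to the Figure object, you can call its savefig method directly and export the same plot in several formats. The sketch below writes to a temporary scratch folder so it leaves your working directory untouched. The folder and the choice of PNG, PDF, and SVG are assumptions for illustration.

```python
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt

months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
east_sales = [12000, 15000, 18000, 22000, 25000, 28000]

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(months, east_sales, marker='o', linewidth=2)
ax.set_title('Monthly Sales Trend')
ax.set_xlabel('Month')
ax.set_ylabel('Sales ($)')
ax.grid(True, alpha=0.3)

out_dir = Path(tempfile.mkdtemp())  # scratch folder; use your own path in practice
saved = []
for ext in ('png', 'pdf', 'svg'):
    path = out_dir / f'sales_trend.{ext}'
    # The file extension selects the output format
    fig.savefig(path, dpi=150, bbox_inches='tight')
    saved.append(path)
    print(f"Saved {path.name}: {path.stat().st_size} bytes")

plt.close(fig)  # free memory once the files are written
```

PNG is a raster format suited to slides and web pages, while PDF and SVG are vector formats that stay sharp at any zoom level, which is often preferable for printed reports.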

### Exercise 8: Dashboard Creation

Create a 2x2 dashboard with:
1. Line plot of monthly trend
2. Bar chart by category
3. Histogram of values
4. Scatter with regression
5. Save as dashboard.png with dpi=150

In [None]:
# Your code here


In [None]:
# Solution
# Generate sample data
np.random.seed(42)
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
trend = [100, 120, 115, 140, 160, 180]
categories = ['A', 'B', 'C', 'D']
cat_vals = [45, 32, 58, 41]
dist_data = np.random.normal(50, 10, 200)
x_scatter = np.random.uniform(0, 10, 50)
y_scatter = 2*x_scatter + np.random.normal(0, 2, 50)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Line plot
axes[0, 0].plot(months, trend, 'b-o', linewidth=2)
axes[0, 0].set_title('Monthly Trend')
axes[0, 0].set_xlabel('Month')
axes[0, 0].set_ylabel('Value')

# Bar chart
axes[0, 1].bar(categories, cat_vals, color='teal')
axes[0, 1].set_title('By Category')
axes[0, 1].set_xlabel('Category')
axes[0, 1].set_ylabel('Value')

# Histogram
axes[1, 0].hist(dist_data, bins=20, edgecolor='black')
axes[1, 0].set_title('Distribution')
axes[1, 0].set_xlabel('Value')
axes[1, 0].set_ylabel('Frequency')

# Scatter with regression
axes[1, 1].scatter(x_scatter, y_scatter, alpha=0.6)
coeffs = np.polyfit(x_scatter, y_scatter, 1)
x_line = np.linspace(0, 10, 100)
axes[1, 1].plot(x_line, np.poly1d(coeffs)(x_line), 'r-')
axes[1, 1].set_title('Relationship')
axes[1, 1].set_xlabel('X')
axes[1, 1].set_ylabel('Y')

plt.tight_layout()
plt.savefig('dashboard.png', dpi=150, bbox_inches='tight')
plt.show()

print("Dashboard saved to dashboard.png")

## Summary

In this lecture, you learned data visualization fundamentals:

1. **Introduction to Matplotlib and visualization** - Understanding why we visualize, choosing chart types, and matplotlib basics
2. **Line plots** - Showing trends and comparing multiple series
3. **Bar charts** - Comparing categories with vertical, horizontal, grouped, and stacked bars
4. **Scatter plots** - Revealing relationships between variables with trend lines
5. **Histograms** - Understanding distributions and comparing them
6. **Plot customization and subplots** - Multiple plots in one figure
7. **Seaborn** - Statistical plots including box plots, violin plots, pair plots, and heatmaps
8. **Saving and workflow** - Exporting plots for reports

Remember: Choose the chart type based on your data and question. Always include titles and labels. Use color and style purposefully to enhance understanding.