Build and train the CNN model
Training effective neural networks requires careful architecture design, configurable hyperparameters, and robust training loops. Our CNN implementation uses modern best practices including batch normalization, dropout regularization, and adaptive learning rate scheduling to achieve reliable digit classification performance.
Configurable model architecture
The DigitCNN class implements a convolutional neural network with three convolutional blocks followed by three fully connected layers, designed specifically for MNIST's 28×28 grayscale images. The architecture follows the principle of progressive feature abstraction, with early layers detecting edges and simple patterns and deeper layers combining these into complex shapes for final classification:
class DigitCNN(nn.Module):
    """Improved CNN for MNIST digit classification based on research."""
    def __init__(self, config: ModelConfig = None):
        super().__init__()
        if config is None:
            config = ModelConfig()
        self.config = config
        # First convolutional block
        self.conv1 = nn.Conv2d(
            1, config.conv1_channels, kernel_size=5, padding=2
        )  # 5x5 kernel, maintain size
        self.bn1 = nn.BatchNorm2d(config.conv1_channels) if config.use_batch_norm else nn.Identity()
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 -> 14x14
        # Second convolutional block
        self.conv2 = nn.Conv2d(
            config.conv1_channels, config.conv2_channels, kernel_size=5, padding=2
        )
        self.bn2 = nn.BatchNorm2d(config.conv2_channels) if config.use_batch_norm else nn.Identity()
        self.pool2 = nn.MaxPool2d(2, 2)  # 14x14 -> 7x7
        # Third convolutional block
        self.conv3 = nn.Conv2d(
            config.conv2_channels, config.conv3_channels, kernel_size=3, padding=1
        )
        self.bn3 = nn.BatchNorm2d(config.conv3_channels) if config.use_batch_norm else nn.Identity()
        self.pool3 = nn.AdaptiveAvgPool2d((3, 3))  # Adaptive pooling to 3x3
        # Dropout layers
        self.dropout1 = nn.Dropout2d(config.dropout1_rate)
        self.dropout2 = nn.Dropout(config.dropout2_rate)
        # Calculate the flattened size: 3x3 * conv3_channels
        conv_output_size = 3 * 3 * config.conv3_channels
        # Fully connected layers
        self.fc1 = nn.Linear(conv_output_size, config.hidden_size)
        self.fc2 = nn.Linear(config.hidden_size, config.hidden_size // 2)  # Additional FC layer
        self.fc3 = nn.Linear(config.hidden_size // 2, 10)
    def _conv_block(self, x, conv, bn, pool, dropout=None):
        """Apply a convolutional block: conv -> bn -> relu -> pool -> dropout (optional)."""
        x = conv(x)
        x = bn(x)
        x = F.relu(x)
        x = pool(x)
        if dropout is not None:
            x = dropout(x)
        return x
    def _fc_block(self, x, fc, dropout=None):
        """Apply a fully connected block: linear -> relu -> dropout (optional)."""
        x = fc(x)
        x = F.relu(x)
        if dropout is not None:
            x = dropout(x)
        return x
    def forward(self, x):
        """Forward pass through the CNN architecture.
        Input: (batch_size, 1, 28, 28) - MNIST digit images
        Output: (batch_size, 10) - Raw logits for 10 digit classes
        Architecture flow:
        1. Conv1: 28x28 -> 14x14 (5x5 kernel, 32 channels)
        2. Conv2: 14x14 -> 7x7 (5x5 kernel, 64 channels) + spatial dropout
        3. Conv3: 7x7 -> 3x3 (3x3 kernel, 128 channels, adaptive pooling)
        4. Flatten: 3x3*128 = 1152 features
        5. FC layers: 1152 -> 256 -> 128 -> 10 (with dropout)
        """
        # Convolutional layers with progressive downsampling
        x = self._conv_block(x, self.conv1, self.bn1, self.pool1)
        x = self._conv_block(x, self.conv2, self.bn2, self.pool2, self.dropout1)
        x = self._conv_block(x, self.conv3, self.bn3, self.pool3)
        # Flatten spatial dimensions for fully connected layers
        x = torch.flatten(x, 1)  # Keep batch dimension
        # Fully connected layers with progressive feature reduction
        x = self._fc_block(x, self.fc1, self.dropout2)
        x = self._fc_block(x, self.fc2)
        x = self.fc3(x)  # Final layer - no activation (raw logits)
        return x  # Return raw logits for CrossEntropyLoss
The architecture demonstrates key design principles: progressive downsampling reduces spatial dimensions while increasing feature depth (28×28 → 14×14 → 7×7 → 3×3), batch normalization after each convolution stabilizes training and enables higher learning rates, and strategic dropout prevents overfitting on spatial patterns. The configurable design allows easy experimentation with different channel sizes, dropout rates, and architectural components through the ModelConfig system.
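Before wiring the model into a pipeline, a quick forward pass with a dummy batch is a convenient sanity check of the tensor shapes described in the docstring. This is a minimal sketch using the default configuration:

import torch

model = DigitCNN()  # Falls back to the default ModelConfig shown in the next section
model.eval()  # Disable dropout and batch-norm updates for the check
dummy_batch = torch.randn(4, 1, 28, 28)  # Four fake grayscale MNIST images
with torch.no_grad():
    logits = model(dummy_batch)
print(logits.shape)  # torch.Size([4, 10]) - one raw logit per digit class
print(sum(p.numel() for p in model.parameters()))  # Total trainable parameter count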
Training configuration system
Rather than hardcoding training parameters, the system uses Dagster's configuration framework to enable experimentation without code modifications. The ModelConfig class centralizes all training hyperparameters, from model architecture to optimization strategies:
class ModelConfig(dg.Config):
    """Configuration for model architecture and training."""
    # Architecture parameters
    conv1_channels: int = 32
    conv2_channels: int = 64
    conv3_channels: int = 128
    dropout1_rate: float = 0.1  # Spatial dropout applied after the second conv block
    dropout2_rate: float = 0.2  # Dropout applied after the first fully connected layer
    hidden_size: int = 256
    use_batch_norm: bool = True
    # Training parameters
    batch_size: int = DEFAULT_BATCH_SIZE  # Smaller batches tend to generalize better
    learning_rate: float = DEFAULT_LEARNING_RATE
    epochs: int = DEFAULT_EPOCHS
    optimizer_type: str = "adam"  # "adam" or "sgd"
    momentum: float = 0.9  # Used only when optimizer_type is "sgd"
    weight_decay: float = 1e-5
    # Learning rate scheduling
    use_lr_scheduler: bool = True
    lr_step_size: int = LR_STEP_SIZE
    lr_gamma: float = LR_GAMMA
    # Early stopping
    use_early_stopping: bool = True
    patience: int = EARLY_STOPPING_PATIENCE
    min_delta: float = MIN_DELTA
    # Data augmentation
    use_data_augmentation: bool = True
    rotation_degrees: float = 15.0  # Maximum random rotation in degrees
    translation_pixels: float = 0.1  # Maximum random translation
    scale_range_min: float = 0.9  # Lower bound of the random scaling range
    scale_range_max: float = 1.1  # Upper bound of the random scaling range
    # Model saving
    save_model: bool = True
    model_save_dir: str = str(MODELS_DIR)
    model_name_prefix: str = "mnist_cnn"
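The configuration references module-level constants (DEFAULT_BATCH_SIZE, LR_STEP_SIZE, MODELS_DIR, and so on) that are defined elsewhere in the project and not shown in this section. Purely for illustration, a constants module could look like the following; the values here are assumptions, not the tutorial's actual defaults:

from pathlib import Path

# Hypothetical values - the project defines its own constants elsewhere
DEFAULT_BATCH_SIZE = 64
DEFAULT_LEARNING_RATE = 1e-3
DEFAULT_EPOCHS = 10
LR_STEP_SIZE = 5  # Epochs between StepLR decay steps
LR_GAMMA = 0.5  # Multiplicative learning rate decay factor
EARLY_STOPPING_PATIENCE = 3  # Epochs without improvement before stopping
MIN_DELTA = 0.001  # Minimum validation-accuracy gain that counts as improvement
MODELS_DIR = Path("models")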
This configuration approach separates model architecture from training strategy, enabling data scientists to experiment with different hyperparameters through configuration files while keeping the underlying training logic stable. The configuration includes advanced features like learning rate scheduling (StepLR), early stopping with patience, multiple optimizer support, and automatic model persistence with descriptive filenames.
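Because ModelConfig extends dg.Config (a Pydantic model), hyperparameters can also be overridden directly in Python for quick experiments, and the same field names can be supplied as run config through the Dagster UI. A minimal sketch; the values below are illustrative, not recommended settings:

# Override selected hyperparameters without touching the training code
experiment_config = ModelConfig(
    learning_rate=5e-4,
    epochs=5,
    optimizer_type="sgd",
    use_data_augmentation=False,
)
print(experiment_config.model_dump())  # Inspect the fully resolved configuration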
Training asset orchestration
The digit_classifier asset coordinates the entire training process, from data loading through model persistence. This asset demonstrates how Dagster assets can orchestrate complex ML workflows while providing comprehensive logging and metadata generation:
@dg.asset(
    description="Train CNN digit classifier with configurable parameters",
    group_name="model_pipeline",
    required_resource_keys={"model_storage"},
)
def digit_classifier(
    context,
    processed_mnist_data: dict[str, torch.Tensor],
    config: ModelConfig,
) -> DigitCNN:
    """Train a CNN to classify handwritten digits 0-9 with flexible configuration."""
    context.log.info(f"Training with config: {config.model_dump()}")
    train_data = processed_mnist_data["train_data"]
    val_data = processed_mnist_data["val_data"]
    train_labels = processed_mnist_data["train_labels"]
    val_labels = processed_mnist_data["val_labels"]
    # Create data loaders with configurable batch size
    train_dataset = TensorDataset(train_data, train_labels)
    val_dataset = TensorDataset(val_data, val_labels)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    # Initialize model with configuration
    model = DigitCNN(config)
    # Train the model - pass context to train_model
    trained_model, train_losses, val_accuracies = train_model(
        context, model, train_loader, val_loader, config
    )
    final_val_accuracy = val_accuracies[-1]
    # Add metadata
    context.add_output_metadata(
        {
            "final_val_accuracy": final_val_accuracy,
            "training_epochs": len(train_losses),
            "configured_epochs": config.epochs,
            "model_parameters": sum(p.numel() for p in trained_model.parameters()),
            "final_train_loss": train_losses[-1],
            "learning_rate": config.learning_rate,
            "batch_size": config.batch_size,
            "optimizer": config.optimizer_type,
            "early_stopping_used": config.use_early_stopping,
        }
    )
    context.log.info(
        f"Model training completed. Final validation accuracy: {final_val_accuracy:.2f}%"
    )
    # Save model as pickle file if requested
    if config.save_model:
        # Create models directory if it doesn't exist
        model_dir = Path(config.model_save_dir)
        model_dir.mkdir(parents=True, exist_ok=True)  # Ensure the directory (and parents) exist
        # Create filename with timestamp and accuracy
        from datetime import datetime
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        accuracy_str = f"{final_val_accuracy:.2f}".replace(".", "p")
        filename = f"{config.model_name_prefix}_{timestamp}_acc{accuracy_str}.pkl"
        # Save the trained model
        context.log.info(f"Saving model as {filename}")
        model_store = context.resources.model_storage
        model_store.save_model(trained_model, filename)
        context.add_output_metadata(
            {"model_name": filename, "final_accuracy": final_val_accuracy},
            output_name="result",
        )
    return trained_model
The training asset integrates seamlessly with upstream data processing through Dagster's dependency system, ensuring training only begins after data preprocessing completes. It accepts configuration parameters that control all aspects of training behavior, enabling different strategies across development and production environments. The asset generates rich metadata including training metrics, model statistics, and configuration parameters that appear in Dagster's UI for experiment tracking and comparison.
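As a sketch of how a run might be launched programmatically with config overrides, dg.materialize can materialize the training asset together with its upstream data asset. Here, processed_mnist_data and the model_storage resource instance come from earlier parts of the tutorial and are assumed to be importable:

import dagster as dg

result = dg.materialize(
    [processed_mnist_data, digit_classifier],
    resources={"model_storage": model_storage},  # Resource instance defined earlier in the tutorial
    run_config={
        "ops": {
            "digit_classifier": {
                "config": {"learning_rate": 5e-4, "epochs": 5, "optimizer_type": "sgd"}
            }
        }
    },
)
assert result.success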
Advanced training features and monitoring
The training system includes sophisticated features for production ML workflows: early stopping monitors validation accuracy and halts training when improvement stagnates (with configurable patience), StepLR learning rate scheduling decays the rate at fixed epoch intervals for more stable convergence, and comprehensive logging tracks both epoch-level progress and batch-level details for debugging.
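The train_model function called by the asset is not reproduced in this section. The following is a condensed sketch of what such a loop looks like given the features described above (device selection, optimizer choice, StepLR scheduling, early stopping); it illustrates the mechanics rather than the tutorial's exact implementation:

import torch
import torch.nn as nn


def train_model(context, model, train_loader, val_loader, config):
    """Condensed sketch of the training loop: optimize, validate, schedule, early-stop."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()

    # Optimizer selection driven by config.optimizer_type
    if config.optimizer_type.lower() == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(), lr=config.learning_rate,
            momentum=config.momentum, weight_decay=config.weight_decay,
        )
    else:
        optimizer = torch.optim.Adam(
            model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay
        )
    scheduler = (
        torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.lr_step_size, gamma=config.lr_gamma)
        if config.use_lr_scheduler
        else None
    )

    train_losses, val_accuracies = [], []
    best_accuracy, epochs_without_improvement = 0.0, 0

    for epoch in range(config.epochs):
        # Training pass
        model.train()
        epoch_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        train_losses.append(epoch_loss / len(train_loader))

        # Validation pass - accuracy as a percentage, matching the asset's logging
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                predictions = model(images).argmax(dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
        accuracy = 100.0 * correct / total
        val_accuracies.append(accuracy)
        context.log.info(
            f"Epoch {epoch + 1}/{config.epochs}: loss={train_losses[-1]:.4f}, val_acc={accuracy:.2f}%"
        )

        if scheduler is not None:
            scheduler.step()

        # Early stopping: stop once accuracy stops improving by at least min_delta
        if config.use_early_stopping:
            if accuracy > best_accuracy + config.min_delta:
                best_accuracy, epochs_without_improvement = accuracy, 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= config.patience:
                    context.log.info(f"Early stopping triggered at epoch {epoch + 1}")
                    break

    return model, train_losses, val_accuracies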
Multiple optimizer support (Adam for fast convergence, SGD with momentum for potentially better final performance) provides flexibility for different training scenarios. The system automatically handles GPU/CPU device selection and includes robust error handling for production deployment scenarios.
Model persistence uses descriptive filenames including timestamps and performance metrics, enabling easy model identification and version management. The integration with Dagster's resource system abstracts storage details, supporting both local development and cloud production environments seamlessly.
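The model_storage resource itself is defined elsewhere in the project. To illustrate the save_model contract the training asset relies on, here is a hypothetical filesystem-backed implementation; the class name and details are assumptions, not the tutorial's actual resource:

import pickle
from pathlib import Path

import dagster as dg


class LocalModelStorage(dg.ConfigurableResource):
    """Hypothetical filesystem-backed implementation of the model_storage resource."""

    base_dir: str = "models"

    def save_model(self, model, filename: str) -> Path:
        # Persist the trained model with pickle under base_dir
        path = Path(self.base_dir)
        path.mkdir(parents=True, exist_ok=True)
        target = path / filename
        with open(target, "wb") as f:
            pickle.dump(model, f)
        return target

    def load_model(self, filename: str):
        # Reload a previously saved model by filename
        with open(Path(self.base_dir) / filename, "rb") as f:
            return pickle.load(f)

Swapping a resource like this for an S3- or GCS-backed implementation would change only the resource definition, not the training asset.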
Next steps
With trained models available through our asset pipeline, the next phase focuses on comprehensive evaluation to assess model performance and determine readiness for production deployment.
- Continue this tutorial with model evaluation and deployment