Getting Started with Trainwave

This guide will walk you through setting up your first machine learning training job on Trainwave. We’ll use a simple PyTorch example, but you can adapt these steps for any ML framework.

Prerequisites

  • Python 3.8 or later
  • pip package manager
  • A Trainwave account (sign up on the Trainwave website)

Step 1: Install the CLI

The Trainwave CLI is your primary tool for managing training jobs. Install it using pip:

pip install trainwave-cli

Step 2: Authenticate

Log in to your Trainwave account through the CLI:

wave auth login

This will open your browser for authentication. Alternatively, create an API key in the dashboard and set it:

wave auth set-token <your-api-key>

Verify your authentication:

wave auth whoami

Step 3: Create an Organization

Organizations help you manage projects and billing. Create one through our web interface.

Example organization structure:

MyOrg
├── Project A (ML Research)
├── Project B (Production Models)
└── Project C (Experiments)

Step 4: Create a Project

Create a new project in your organization:

  1. Open the Projects dashboard
  2. Click “New Project”
  3. Note your project ID (format: p-xxxxxxxx)

Step 5: Configure Your Training Job

Create a trainwave.toml file in your project directory. Here’s a complete example for a PyTorch training job:

# Basic Configuration
name = "pytorch-mnist"
project = "p-your-project-id"
description = "Training MNIST classifier using PyTorch"

# Resource Configuration
gpu_type = "RTX A5000"
gpus = 1
cpu_cores = 4
memory_gb = 16
hdd_size_mb = 51200

# Runtime Configuration
image = "trainwave/pytorch:2.3.1"
setup_command = """
pip install -r requirements.txt
wandb login ${WANDB_API_KEY}
"""
run_command = "python train.py"

# Optional Settings (top-level keys must appear before any [table] header)
expires = "4h"                    # Auto-terminate after 4 hours
compliance_soc2 = true            # Enable SOC2 compliance

# Environment Variables (every key below this header belongs to env_vars)
[env_vars]
WANDB_API_KEY = "${WANDB_API_KEY}"
PYTORCH_CUDA_ALLOC_CONF = "max_split_size_mb:512"

Step 6: Prepare Your Code

Here’s a minimal example of a PyTorch training script (train.py):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import wandb

# Initialize Weights & Biases logging
wandb.init(project="mnist-example")

# Define model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)    # 28x28 -> 26x26
        self.conv2 = nn.Conv2d(32, 64, 3)   # 26x26 -> 24x24
        self.fc1 = nn.Linear(9216, 128)     # 64 channels * 12 * 12 after pooling
        self.fc2 = nn.Linear(128, 10)       # one output per digit class

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)              # 24x24 -> 12x12
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)                  # raw logits, as CrossEntropyLoss expects

# Set up training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Load data (mean/std are the standard MNIST normalization constants)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

# Training loop
model.train()
for epoch in range(10):
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            wandb.log({
                "loss": loss.item(),
                "epoch": epoch
            })

# Save model and close the wandb run
torch.save(model.state_dict(), "model.pt")
wandb.save("model.pt")
wandb.finish()
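The in_features of fc1 must equal the number of values left after torch.flatten, otherwise the first forward pass fails with a shape mismatch. The standard output-size formula for a convolution or pooling layer, floor((n + 2 * padding - kernel) / stride) + 1, lets you check this without a GPU. The stdlib-only sketch below applies it to the Net class in train.py with a 28x28 MNIST input (max_pool2d uses a stride equal to its kernel size by default; dilation is ignored). The helper names are our own, not PyTorch functions.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d or MaxPool2d layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(height: int = 28, channels_out: int = 64) -> int:
    """Number of features reaching fc1 for a square input of the given size."""
    h = conv_out(height, kernel=3)          # conv1: 28 -> 26
    h = conv_out(h, kernel=3)               # conv2: 26 -> 24
    h = conv_out(h, kernel=2, stride=2)     # max_pool2d(x, 2): 24 -> 12
    return channels_out * h * h             # square input, so width == height


if __name__ == "__main__":
    print(flattened_features())  # 9216
```

For this architecture the result is 64 * 12 * 12 = 9216, so fc1 should be nn.Linear(9216, 128). A value such as 1600 (64 * 5 * 5) only fits a network that pools twice.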

And the corresponding requirements.txt:

torch>=2.0.0
torchvision>=0.15.0
wandb>=0.15.0

Step 7: Launch Your Job

Launch your training job:

wave jobs launch

Monitor your job:

# View job status
wave jobs status
 
# Stream logs
wave jobs logs -f
 
# View GPU metrics
wave jobs metrics

Common Issues and Solutions

GPU Not Detected

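If your code can’t detect the GPU, first confirm that gpus is set to 1 or more in trainwave.toml, then run this check inside the job (for example, at the top of train.py):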

import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
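If that check reports False or a count of 0, a slightly broader diagnostic helps narrow down whether the NVIDIA driver is visible, whether CUDA_VISIBLE_DEVICES hides the devices, or whether a CPU-only PyTorch build was installed. The sketch below uses only the standard library plus an optional torch import, so it also runs where PyTorch is missing. The function name gpu_diagnostics is hypothetical, not a Trainwave or PyTorch API.

```python
import os
import shutil


def gpu_diagnostics() -> dict:
    """Collect basic facts that explain why PyTorch might not see a GPU."""
    report = {
        # nvidia-smi normally ships with the NVIDIA driver; if it is absent,
        # the driver is usually not visible in this environment.
        "nvidia_smi_on_path": shutil.which("nvidia-smi") is not None,
        # An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA programs.
        "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"),
        "torch_installed": False,
        "torch_cuda_build": None,
        "torch_cuda_available": False,
    }
    try:
        import torch
    except ImportError:
        return report
    report["torch_installed"] = True
    # torch.version.cuda is None for CPU-only builds of PyTorch.
    report["torch_cuda_build"] = torch.version.cuda
    report["torch_cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in gpu_diagnostics().items():
        print(f"{key}: {value}")
```

A torch_cuda_build of None points to the installed PyTorch wheel rather than the hardware, which can happen if setup_command reinstalls torch from a CPU-only index.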

Out of Memory

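If training fails with a CUDA out-of-memory error, set PYTORCH_CUDA_ALLOC_CONF under the [env_vars] table of trainwave.toml to reduce allocator fragmentation:

[env_vars]
PYTORCH_CUDA_ALLOC_CONF = "max_split_size_mb:512"

If the error persists, reduce batch_size in train.py or choose a gpu_type with more memory.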
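PYTORCH_CUDA_ALLOC_CONF can reduce fragmentation, but it cannot help when a batch genuinely needs more memory than the GPU has. One common workaround is to halve the batch size until a single training step fits. The sketch below is a hypothetical helper, not part of Trainwave or PyTorch; it takes the exception type as a parameter so the demo can run without a GPU using a fake step function.

```python
from typing import Callable, Type


def largest_fitting_batch(
    try_batch: Callable[[int], None],
    start: int = 64,
    oom_error: Type[BaseException] = MemoryError,
) -> int:
    """Halve the batch size until try_batch(batch_size) runs without an OOM error."""
    batch_size = start
    while batch_size >= 1:
        try:
            try_batch(batch_size)
            return batch_size
        except oom_error:
            batch_size //= 2  # retry with half the batch
    raise RuntimeError("even a batch size of 1 does not fit in memory")


def fake_step(batch_size: int) -> None:
    """Stand-in for one training step that only fits up to 16 samples."""
    if batch_size > 16:
        raise MemoryError


if __name__ == "__main__":
    print(largest_fitting_batch(fake_step))  # 16
```

With PyTorch 2.x you would pass oom_error=torch.cuda.OutOfMemoryError and a try_batch function (hypothetical, written by you) that runs one forward and backward pass at the given batch size; calling torch.cuda.empty_cache() after each failure releases cached blocks before the retry.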

Job Timeout

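The job is terminated once its expires duration elapses. To allow more time, raise expires in trainwave.toml (it is a top-level key, so keep it above the [env_vars] table):

expires = "12h"  # Set to your expected training time plus a safety margin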
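Because the job stops when expires is reached, long runs should save a checkpoint before the deadline rather than rely on finishing in time. The sketch below assumes train.py knows the same duration string you set in trainwave.toml, and it only accepts an integer followed by s, m, h, or d; that format is our assumption, not Trainwave’s documented syntax. It measures time from when the script starts, which is slightly after the job starts, so it subtracts a safety margin. Deadline and parse_duration are hypothetical helpers.

```python
import re
import time

# Seconds per unit; the accepted units are an assumption, not Trainwave's spec.
_UNITS = {"s": 1, "m": 60, "h": 3600, "d": 86400}


def parse_duration(text: str) -> int:
    """Convert a duration such as "4h" or "90m" to seconds (assumed format)."""
    match = re.fullmatch(r"\s*(\d+)\s*([smhd])\s*", text)
    if match is None:
        raise ValueError(f"unsupported duration: {text!r}")
    value, unit = match.groups()
    return int(value) * _UNITS[unit]


class Deadline:
    """Tracks whether a time budget (minus a safety margin) has run out."""

    def __init__(self, budget: str, margin_s: int = 600, clock=time.monotonic):
        self._clock = clock
        self._end = clock() + parse_duration(budget) - margin_s

    def expired(self) -> bool:
        return self._clock() >= self._end


if __name__ == "__main__":
    print(parse_duration("12h"))  # 43200
```

In train.py you might create deadline = Deadline(os.environ.get("JOB_EXPIRES", "4h")) before the loop, where JOB_EXPIRES is a hypothetical variable you would add to [env_vars] yourself, then call torch.save at the end of each epoch and break out of the loop once deadline.expired() returns True.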

Need help? Join our Discord community or email support@trainwave.ai.