Managing Training Artifacts

Learn how to handle training outputs, model files, and other artifacts in Trainwave. This guide covers storage, retrieval, and best practices for artifact management.

Quick Start

# Download artifacts from a job
wave storage download j-xyz789 --output ./results
 
# List artifacts in a job
wave storage list j-xyz789
 
# Clean up old artifacts
wave storage cleanup --older-than 30d

Artifact Storage

Storage Structure

Trainwave automatically manages artifacts in the following directory structure:

/workspace/
├── artifacts/             # Main artifacts directory
│   ├── models/            # Trained models
│   ├── checkpoints/       # Training checkpoints
│   ├── logs/              # Training logs
│   └── results/           # Evaluation results
├── data/                  # Input data
└── src/                   # Your source code

Saving Artifacts

Save your training outputs to the appropriate directories:

# PyTorch example
import torch
 
# Save model
torch.save(model.state_dict(), '/workspace/artifacts/models/model.pt')
 
# Save checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, '/workspace/artifacts/checkpoints/checkpoint.pt')

# TensorFlow example
import tensorflow as tf
 
# Save model (SavedModel format; Keras 3 expects a .keras suffix)
model.save('/workspace/artifacts/models/model')
 
# Save checkpoint
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint.save('/workspace/artifacts/checkpoints/ckpt')

Artifact Management

CLI Commands

# List artifacts
wave storage list j-xyz789
 
# Download specific artifacts
wave storage download j-xyz789 \
  --include "*.pt" \
  --output ./models
 
# Download all artifacts
wave storage download j-xyz789 \
  --output ./results
 
# Stream logs in real-time
wave storage logs j-xyz789 -f
 
# Clean up old artifacts
wave storage cleanup \
  --older-than 30d \
  --exclude "*.pt"

Automatic Artifact Collection

Trainwave automatically collects:

  1. Training logs (/workspace/artifacts/logs/)
  2. Model files (/workspace/artifacts/models/)
  3. Metrics and results (/workspace/artifacts/results/)
  4. Environment information
  5. Resource usage statistics
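
Anything your code writes under these paths is picked up with no extra configuration. A minimal sketch (the metric values are hypothetical):

import json
from pathlib import Path

# Write evaluation metrics where Trainwave collects results automatically
results_dir = Path('/workspace/artifacts/results')
results_dir.mkdir(parents=True, exist_ok=True)

metrics = {'accuracy': 0.94, 'val_loss': 0.21}  # hypothetical values
(results_dir / 'metrics.json').write_text(json.dumps(metrics, indent=2))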

Integration with ML Frameworks

PyTorch

import torch
from pathlib import Path
 
class ModelCheckpoint:
    def __init__(self, model, optimizer, save_dir):
        self.model = model
        self.optimizer = optimizer
        self.save_dir = Path('/workspace/artifacts/checkpoints') / save_dir
        self.save_dir.mkdir(parents=True, exist_ok=True)
 
    def save(self, epoch, loss):
        checkpoint_path = self.save_dir / f'checkpoint_epoch_{epoch}.pt'
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
        }, checkpoint_path)
 
    def load(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['loss']
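
A hypothetical loop using this helper (train_one_epoch stands in for your own training step):

ckpt = ModelCheckpoint(model, optimizer, save_dir='run-1')

for epoch in range(10):
    loss = train_one_epoch(model, optimizer)  # placeholder training step
    ckpt.save(epoch, loss)

# Resume later from a saved checkpoint
epoch, loss = ckpt.load(ckpt.save_dir / 'checkpoint_epoch_9.pt')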

TensorFlow

import tensorflow as tf
import os
 
class TrainingCallback(tf.keras.callbacks.Callback):
    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = os.path.join('/workspace/artifacts/checkpoints', checkpoint_dir)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
 
    def on_epoch_end(self, epoch, logs=None):
        # Save checkpoint (TF checkpoint format; Keras 3 expects a .weights.h5 suffix)
        checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}')
        self.model.save_weights(checkpoint_path)
 
        # Save metrics
        with open(os.path.join(self.checkpoint_dir, 'metrics.txt'), 'a') as f:
            f.write(f'Epoch {epoch}: {logs}\n')
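
Attach it through model.fit; Keras sets self.model on the callback automatically (x_train and y_train are placeholders):

model.fit(
    x_train, y_train,
    epochs=10,
    callbacks=[TrainingCallback('my-run')],
)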

Cloud Storage Integration

AWS S3

import boto3
from pathlib import Path
 
def upload_artifacts(job_id, local_dir='/workspace/artifacts'):
    s3 = boto3.client('s3')
    bucket = 'your-artifact-bucket'
 
    for path in Path(local_dir).rglob('*'):
        if path.is_file():
            s3_key = f'jobs/{job_id}/{path.relative_to(local_dir)}'
            s3.upload_file(str(path), bucket, s3_key)
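
Credentials are resolved the usual boto3 way (environment variables or an attached role); a hypothetical call after training finishes:

# Mirrors /workspace/artifacts to s3://your-artifact-bucket/jobs/j-xyz789/
upload_artifacts('j-xyz789')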

Google Cloud Storage

from google.cloud import storage
from pathlib import Path
 
def upload_artifacts(job_id, local_dir='/workspace/artifacts'):
    client = storage.Client()
    bucket = client.get_bucket('your-artifact-bucket')
 
    for path in Path(local_dir).rglob('*'):
        if path.is_file():
            blob_name = f'jobs/{job_id}/{path.relative_to(local_dir)}'
            blob = bucket.blob(blob_name)
            blob.upload_from_filename(str(path))
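
Retrieval follows the same pattern in reverse; a sketch under the same bucket layout (download_artifacts is a hypothetical helper, not part of Trainwave):

from pathlib import Path
from google.cloud import storage

def download_artifacts(job_id, local_dir='./results'):
    client = storage.Client()
    bucket = client.get_bucket('your-artifact-bucket')
    prefix = f'jobs/{job_id}/'

    for blob in bucket.list_blobs(prefix=prefix):
        if blob.name.endswith('/'):  # skip directory placeholders
            continue
        target = Path(local_dir) / blob.name[len(prefix):]
        target.parent.mkdir(parents=True, exist_ok=True)
        blob.download_to_filename(str(target))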

Best Practices

1. Organization

  • Use consistent directory structure
  • Follow clear naming conventions
  • Separate different types of artifacts
  • Document artifact formats

2. Storage Efficiency

  • Compress large files (see the sketch after this list)
  • Clean up old artifacts
  • Use appropriate file formats
  • Implement retention policies
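
A minimal compression sketch (compress_checkpoints is a hypothetical helper; adjust the glob to your file types):

import tarfile
from pathlib import Path

def compress_checkpoints(checkpoint_dir='/workspace/artifacts/checkpoints'):
    # Bundle checkpoint files into one gzipped archive before cleanup
    src = Path(checkpoint_dir)
    archive = src.parent / 'checkpoints.tar.gz'
    with tarfile.open(archive, 'w:gz') as tar:
        for path in src.glob('*.pt'):
            tar.add(path, arcname=path.name)
    return archive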

3. Versioning

  • Include version information
  • Track dependencies
  • Document model architecture
  • Save training configuration (see the sketch below)
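
One way to capture all four at once (save_run_manifest is a hypothetical helper; it assumes the job runs from a git checkout):

import json
import subprocess
from pathlib import Path

def save_run_manifest(config, version='1.0.0'):
    # Record version, code revision, and training config with the results
    commit = subprocess.check_output(
        ['git', 'rev-parse', 'HEAD'], text=True).strip()
    manifest = {'version': version, 'git_commit': commit, 'config': config}

    out = Path('/workspace/artifacts/results/manifest.json')
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(manifest, indent=2))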

Common Patterns

1. Model Training

import json
import torch
from pathlib import Path
 
def save_training_artifacts(model, metrics, config, run_dir):
    # Build artifact paths
    artifact_dir = Path('/workspace/artifacts') / run_dir
    model_dir = artifact_dir / 'models'
    metrics_dir = artifact_dir / 'metrics'
    config_dir = artifact_dir / 'config'
 
    # Create directories
    for d in [model_dir, metrics_dir, config_dir]:
        d.mkdir(parents=True, exist_ok=True)
 
    # Save model
    torch.save(model.state_dict(), model_dir / 'model.pt')
 
    # Save metrics
    with open(metrics_dir / 'metrics.json', 'w') as f:
        json.dump(metrics, f)
 
    # Save config
    with open(config_dir / 'config.json', 'w') as f:
        json.dump(config, f)
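
A hypothetical call with a trained model (metric and config values are placeholders):

save_training_artifacts(
    model,
    metrics={'train_loss': 0.18, 'val_loss': 0.21},
    config={'lr': 3e-4, 'batch_size': 64},
    run_dir='run-001',
)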

2. Experiment Tracking

import wandb
from pathlib import Path
 
def track_experiment(config, metrics, artifacts_dir):
    run = wandb.init(project="my-project", config=config)

    # Log model artifacts
    model_path = Path(artifacts_dir) / 'models' / 'model.pt'
    wandb.save(str(model_path))

    # Log metrics, e.g. {"train_loss": ..., "val_loss": ...}
    wandb.log({**metrics, "learning_rate": config["lr"]})

    run.finish()
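
A hypothetical call, with artifacts_dir pointing at the standard layout:

track_experiment(
    config={'lr': 3e-4},
    metrics={'train_loss': 0.18, 'val_loss': 0.21},
    artifacts_dir='/workspace/artifacts',
)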

3. Distributed Training

import torch.distributed as dist
from pathlib import Path
 
def save_distributed_checkpoint(model, optimizer, epoch, rank):
    if rank == 0:  # Save only on main process
        checkpoint_dir = Path('/workspace/artifacts/checkpoints')
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
 
        # If the model is wrapped in DistributedDataParallel, save
        # model.module.state_dict() instead to strip the "module." prefix.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_dir / f'checkpoint_epoch_{epoch}.pt')
 
    dist.barrier()  # All ranks wait until rank 0 has finished saving
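
Called on every rank inside the training loop (a sketch assuming the process group is already initialized; train_one_epoch is a placeholder):

rank = dist.get_rank()

for epoch in range(10):
    train_one_epoch(model, optimizer)  # placeholder training step
    save_distributed_checkpoint(model, optimizer, epoch, rank)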

Troubleshooting

Common Issues

  1. Storage Space

    # Check storage usage
    wave storage usage j-xyz789
     
    # Clean up old artifacts
    wave storage cleanup --older-than 7d

  2. Missing Artifacts

    # Verify artifact paths
    wave storage list j-xyz789 --verbose
     
    # Check job logs
    wave jobs logs j-xyz789 | grep "saving"

  3. Download Issues

    # Retry with verbose output
    wave storage download j-xyz789 \
      --output ./results \
      --verbose

Support