Managing Training Artifacts
Learn how to handle training outputs, model files, and other artifacts in Trainwave. This guide covers storage, retrieval, and best practices for artifact management.
Quick Start
# Download artifacts from a job
wave storage download j-xyz789 --output ./results
# List artifacts in a job
wave storage list j-xyz789
# Clean up old artifacts
wave storage cleanup --older-than 30d
Artifact Storage
Storage Structure
Trainwave automatically manages artifacts in the following directory structure:
/workspace/
├── artifacts/           # Main artifacts directory
│   ├── models/          # Trained models
│   ├── checkpoints/     # Training checkpoints
│   ├── logs/            # Training logs
│   └── results/         # Evaluation results
├── data/                # Input data
└── src/                 # Your source code
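The examples below write directly into these paths. Trainwave manages /workspace/artifacts for you, but if you prefer to be defensive about missing subdirectories, here is a minimal sketch (only the paths from the layout above are assumed):
from pathlib import Path

# Make sure the standard artifact subdirectories exist before training starts,
# so later writes such as torch.save(...) never fail on a missing folder.
artifacts = Path('/workspace/artifacts')
for name in ('models', 'checkpoints', 'logs', 'results'):
    (artifacts / name).mkdir(parents=True, exist_ok=True)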
Saving Artifacts
Save your training outputs to the appropriate directories:
# PyTorch example
import torch
# Save model
torch.save(model.state_dict(), '/workspace/artifacts/models/model.pt')
# Save checkpoint
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, '/workspace/artifacts/checkpoints/checkpoint.pt')
# TensorFlow example
import tensorflow as tf
# Save model
model.save('/workspace/artifacts/models/model')
# Save checkpoint
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint.save('/workspace/artifacts/checkpoints/ckpt')
Artifact Management
CLI Commands
# List artifacts
wave storage list j-xyz789
# Download specific artifacts
wave storage download j-xyz789 \
--include "*.pt" \
--output ./models
# Download all artifacts
wave storage download j-xyz789 \
--output ./results
# Stream logs in real-time
wave storage logs j-xyz789 -f
# Clean up old artifacts
wave storage cleanup \
--older-than 30d \
--exclude "*.pt"
Automatic Artifact Collection
Trainwave automatically collects:
- Training logs (/workspace/artifacts/logs/)
- Model files (/workspace/artifacts/models/)
- Metrics and results (/workspace/artifacts/results/)
) - Environment information
- Resource usage statistics
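Anything your job writes under these directories is collected without extra configuration. For example, evaluation metrics dropped into /workspace/artifacts/results/ are picked up automatically; a minimal sketch (the metric names and values are illustrative):
import json
from pathlib import Path

# Write evaluation results where Trainwave collects them automatically.
results_dir = Path('/workspace/artifacts/results')
results_dir.mkdir(parents=True, exist_ok=True)
with open(results_dir / 'eval_metrics.json', 'w') as f:
    json.dump({'accuracy': 0.93, 'val_loss': 0.41}, f, indent=2)  # illustrative values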
Integration with ML Frameworks
PyTorch
import torch
from pathlib import Path
class ModelCheckpoint:
def __init__(self, model, optimizer, save_dir):
self.model = model
self.optimizer = optimizer
self.save_dir = Path('/workspace/artifacts/checkpoints') / save_dir
self.save_dir.mkdir(parents=True, exist_ok=True)
def save(self, epoch, loss):
checkpoint_path = self.save_dir / f'checkpoint_epoch_{epoch}.pt'
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'loss': loss,
}, checkpoint_path)
def load(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch'], checkpoint['loss']
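A typical training loop could use this helper as follows; num_epochs and train_one_epoch are placeholders for your own loop, and the save_dir name is illustrative:
checkpointer = ModelCheckpoint(model, optimizer, save_dir='resnet50-baseline')
for epoch in range(num_epochs):
    loss = train_one_epoch(model, optimizer)  # your training step
    checkpointer.save(epoch, loss)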
TensorFlow
import tensorflow as tf
import os
class TrainingCallback(tf.keras.callbacks.Callback):
def __init__(self, checkpoint_dir):
super().__init__()
self.checkpoint_dir = os.path.join('/workspace/artifacts/checkpoints', checkpoint_dir)
os.makedirs(self.checkpoint_dir, exist_ok=True)
def on_epoch_end(self, epoch, logs=None):
# Save checkpoint
checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}')
self.model.save_weights(checkpoint_path)
# Save metrics
with open(os.path.join(self.checkpoint_dir, 'metrics.txt'), 'a') as f:
f.write(f'Epoch {epoch}: {logs}\n')
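The callback slots into a standard Keras training call; a short usage sketch (the model, datasets, and checkpoint directory name are placeholders):
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    callbacks=[TrainingCallback('bert-finetune')],
)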
Cloud Storage Integration
AWS S3
import boto3
from pathlib import Path
def upload_artifacts(job_id, local_dir='/workspace/artifacts'):
s3 = boto3.client('s3')
bucket = 'your-artifact-bucket'
for path in Path(local_dir).rglob('*'):
if path.is_file():
s3_key = f'jobs/{job_id}/{path.relative_to(local_dir)}'
s3.upload_file(str(path), bucket, s3_key)
Google Cloud Storage
from google.cloud import storage
from pathlib import Path
def upload_artifacts(job_id, local_dir='/workspace/artifacts'):
client = storage.Client()
bucket = client.get_bucket('your-artifact-bucket')
for path in Path(local_dir).rglob('*'):
if path.is_file():
blob_name = f'jobs/{job_id}/{path.relative_to(local_dir)}'
blob = bucket.blob(blob_name)
blob.upload_from_filename(str(path))
Best Practices
1. Organization
- Use consistent directory structure
- Follow clear naming conventions (see the sketch after this list)
- Separate different types of artifacts
- Document artifact formats
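One way to keep naming consistent is to derive each run's artifact directory from a timestamp and an experiment name; a minimal sketch (the naming scheme itself is only a suggestion):
from datetime import datetime
from pathlib import Path

def make_run_dir(experiment: str) -> Path:
    # e.g. /workspace/artifacts/2024-05-01_12-30-00_baseline
    stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    run_dir = Path('/workspace/artifacts') / f'{stamp}_{experiment}'
    for sub in ('models', 'checkpoints', 'logs', 'results'):
        (run_dir / sub).mkdir(parents=True, exist_ok=True)
    return run_dir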
2. Storage Efficiency
- Compress large files (see the sketch after this list)
- Clean up old artifacts
- Use appropriate file formats
- Implement retention policies
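Large checkpoints that you want to keep but rarely touch can be compressed in place; a minimal sketch using only the standard library (the path handling is generic, nothing Trainwave-specific is assumed):
import gzip
import shutil
from pathlib import Path

def compress_file(path: Path) -> Path:
    # Write a gzip copy next to the original, then remove the uncompressed file.
    compressed = path.with_suffix(path.suffix + '.gz')
    with open(path, 'rb') as src, gzip.open(compressed, 'wb') as dst:
        shutil.copyfileobj(src, dst)
    path.unlink()
    return compressed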
3. Versioning
- Include version information (see the sketch after this list)
- Track dependencies
- Document model architecture
- Save training configuration
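Recording the training configuration next to version information makes a run reproducible later; a minimal sketch (the config/ subdirectory and the use of git are assumptions about your setup):
import json
import subprocess
import sys
from pathlib import Path

import torch

def save_run_metadata(config: dict, out_dir='/workspace/artifacts/config'):
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    try:
        commit = subprocess.run(
            ['git', 'rev-parse', 'HEAD'], capture_output=True, text=True
        ).stdout.strip() or None
    except FileNotFoundError:
        commit = None  # git is not available in the image
    metadata = {
        'config': config,              # hyperparameters, model architecture, etc.
        'python': sys.version,
        'torch': torch.__version__,
        'git_commit': commit,
    }
    with open(out / 'run_metadata.json', 'w') as f:
        json.dump(metadata, f, indent=2)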
Common Patterns
1. Model Training
import json
import torch
from pathlib import Path
def save_training_artifacts(model, metrics, config, run_dir):
# Create artifact directories
artifact_dir = Path('/workspace/artifacts') / run_dir
model_dir = artifact_dir / 'models'
metrics_dir = artifact_dir / 'metrics'
config_dir = artifact_dir / 'config'
# Create directories
for d in [model_dir, metrics_dir, config_dir]:
d.mkdir(parents=True, exist_ok=True)
# Save model
torch.save(model.state_dict(), model_dir / 'model.pt')
# Save metrics
with open(metrics_dir / 'metrics.json', 'w') as f:
json.dump(metrics, f)
# Save config
with open(config_dir / 'config.json', 'w') as f:
json.dump(config, f)
2. Experiment Tracking
import wandb
from pathlib import Path

def track_experiment(model, config, metrics, artifacts_dir):
    run = wandb.init(project="my-project", config=config)
    # Optionally track gradients and parameters for PyTorch models
    wandb.watch(model)
    # Log model artifacts
    model_path = Path(artifacts_dir) / 'models' / 'model.pt'
    wandb.save(str(model_path))
    # Log metrics
    wandb.log({
        "train_loss": metrics["train_loss"],
        "val_loss": metrics["val_loss"],
        "learning_rate": config["lr"]
    })
    run.finish()
3. Distributed Training
import torch.distributed as dist
from pathlib import Path
def save_distributed_checkpoint(model, optimizer, epoch, rank):
if rank == 0: # Save only on main process
checkpoint_dir = Path('/workspace/artifacts/checkpoints')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, checkpoint_dir / f'checkpoint_epoch_{epoch}.pt')
dist.barrier() # Wait for save to complete
Troubleshooting
Common Issues
- Storage Space
# Check storage usage
wave storage usage j-xyz789
# Clean up old artifacts
wave storage cleanup --older-than 7d
- Missing Artifacts
# Verify artifact paths
wave storage list j-xyz789 --verbose
# Check job logs
wave jobs logs j-xyz789 | grep "saving"
- Download Issues
# Retry with verbose output
wave storage download j-xyz789 \
  --output ./results \
  --verbose
Support
- Storage issues: storage@trainwave.ai
- Technical support: support@trainwave.ai
- Documentation: docs@trainwave.ai