import torch
import torch_memory_saver

memory_saver = torch_memory_saver.memory_saver

# 1. Create tensors that should be pauseable inside `region`.
with memory_saver.region():
    pauseable_tensor = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')

# 2. After `pause`, the CUDA memory backing those tensors is released.
#    For example, check `nvidia-smi`'s memory usage to verify.
memory_saver.pause()

# 3. After `resume`, CUDA memory is re-occupied for those tensors.
memory_saver.resume()
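Besides eyeballing `nvidia-smi`, the release can be checked programmatically. A small sketch using `torch.cuda.mem_get_info`; the helper name `gpu_memory_used_mb` is hypothetical and not part of torch_memory_saver:

import torch

def gpu_memory_used_mb() -> float:
    # Hypothetical helper: device-wide used memory as seen by the CUDA
    # driver, i.e. the same number `nvidia-smi` reports for this GPU.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    return (total_bytes - free_bytes) / 1e6

# Sample this before and after `memory_saver.pause()`: the ~1 GB tensor
# created above should disappear from the reading while paused.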
cudaError_t malloc(void **ptr, size_t size, const std::string &tag) {
    CUdevice device;
    CURESULT_CHECK(cuCtxGetDevice(&device));

    // Create a physical allocation, reserve a virtual address range,
    // and map the physical pages into it.
    CUmemGenericAllocationHandle allocHandle;
    CUDAUtils::cu_mem_create(&allocHandle, size, device);
    // `ptr` is the reserved virtual address that is kept and reused.
    CURESULT_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, 0, 0, 0));
    CURESULT_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, allocHandle, 0));
    CUDAUtils::cu_mem_set_access(*ptr, size, device);

    {
        const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
        allocation_metadata_.emplace(*ptr, _AllocationMetadata{size, device, allocHandle, tag});
    }
    return cudaSuccess;
}

cudaError_t free(void *ptr) {
    _AllocationMetadata metadata;
    {
        const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
        SIMPLE_CHECK(allocation_metadata_.count(ptr), "Trying to free a pointer not allocated here");
        metadata = allocation_metadata_[ptr];
        allocation_metadata_.erase(ptr);
    }

    CURESULT_CHECK(cuMemUnmap((CUdeviceptr)ptr, metadata.size));
    CURESULT_CHECK(cuMemRelease(metadata.allocHandle));
    CURESULT_CHECK(cuMemAddressFree((CUdeviceptr)ptr, metadata.size));
    return cudaSuccess;
}
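The `malloc`/`free` pair above only covers ordinary allocation. The actual saving happens in `pause` and `resume`: `pause` unmaps and releases the physical pages while deliberately keeping the virtual address reservation alive, so existing tensor pointers stay valid; `resume` creates fresh physical memory and maps it back at the same addresses. A minimal sketch under that reading, reusing the names from the snippet above (`allocation_metadata_`, `CUDAUtils`, `CURESULT_CHECK`); it illustrates the technique rather than reproducing the library's exact code:

void pause() {
    const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
    for (auto &[ptr, metadata] : allocation_metadata_) {
        // Drop the physical backing; the range from cuMemAddressReserve
        // is kept, so `ptr` remains a valid (unmapped) virtual address.
        CURESULT_CHECK(cuMemUnmap((CUdeviceptr)ptr, metadata.size));
        CURESULT_CHECK(cuMemRelease(metadata.allocHandle));
    }
}

void resume() {
    const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
    for (auto &[ptr, metadata] : allocation_metadata_) {
        // Allocate new physical pages and map them at the old address.
        CUmemGenericAllocationHandle newHandle;
        CUDAUtils::cu_mem_create(&newHandle, metadata.size, metadata.device);
        CURESULT_CHECK(cuMemMap((CUdeviceptr)ptr, metadata.size, 0, newHandle, 0));
        CUDAUtils::cu_mem_set_access(ptr, metadata.size, metadata.device);
        metadata.allocHandle = newHandle;
    }
}

Note that in this sketch the pages mapped on `resume` are freshly created, so the contents of paused tensors are not preserved across a pause/resume cycle.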