This repository was archived by the owner on Apr 28, 2023. It is now read-only.

[wip] Grid synchronizations and mapping to blocks. #370

Open
wants to merge 16 commits into base: master
32 changes: 22 additions & 10 deletions tc/autotuner/parameters.cc
@@ -347,30 +347,42 @@ TuningConfiguration::TuningConfiguration()
useReadOnlyCache("use readonly cache (i.e. emit __ldg loads)"),
matchLibraryCalls("match library calls") {
addValidator([](const TuningConfiguration& conf) {
auto b0v = conf.blockParams.dims.at(0).value();
auto b1v = conf.blockParams.dims.at(1).value();
auto b2v = conf.blockParams.dims.at(2).value();
auto b = conf.blockParams;
auto b0v = b.dims.at(0).value();
auto b1v = b.dims.at(1).value();
auto b2v = b.dims.at(2).value();
auto g = conf.gridParams;
auto g0v = g.dims.at(0).value();
auto g1v = g.dims.at(1).value();
auto g2v = g.dims.at(2).value();
if (b0v <= 0 or b0v > 1024 or b1v <= 0 or b1v > 1024 or b2v <= 0 or
b2v > 64) {
return false;
}
auto blockProduct = [&]() {
switch (conf.blockParams.numberDims.value()) {
auto computeProduct = [&](const CudaDimParameters& p) {
switch (p.numberDims.value()) {
case 3:
return b0v * b1v * b2v;
return p.dims.at(0).value() * p.dims.at(1).value() *
p.dims.at(2).value();
case 2:
return b0v * b1v;
return p.dims.at(0).value() * p.dims.at(1).value();
case 1:
return b0v;
return p.dims.at(0).value();
default:
TC_CHECK(false) << "Must have (1-3) block dims, got: "
<< conf.blockParams.numberDims.value();
}
return b0v;
}();
return p.dims.at(0).value();
};
auto blockProduct = computeProduct(b);
auto gridProduct = computeProduct(g);
if (blockProduct < 32 or blockProduct > 512) {
return false;
}
if (FLAGS_reduce_launch_size and
(gridProduct > 128 or blockProduct > 256)) {
return false;
}
return true;
});
}
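To make the new bounds concrete (illustrative numbers, not taken from this PR): a block of (128, 2, 1) has a product of 256 and a grid of (16, 8, 1) has a product of 128, so the configuration passes both the default check (32 <= block product <= 512) and, when FLAGS_reduce_launch_size is set, the stricter check (grid product <= 128 and block product <= 256). Widening the grid to (32, 8, 1) raises the grid product to 256, and the configuration is rejected whenever the flag is set.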
7 changes: 7 additions & 0 deletions tc/core/CMakeLists.txt
@@ -48,9 +48,15 @@ target_link_libraries(
tc_version
tc_proto
)

if (WITH_BINDINGS)
add_dependencies(tc_core generate_isl_cpp_h)
endif()

if(WITH_CUDA)
target_link_libraries(tc_cuda_version)
endif()

install(
TARGETS
tc_core
@@ -176,6 +182,7 @@ if (WITH_CUDA)

tc_lang
tc_version
tc_cuda_version
tc_proto
tc_core
)
1 change: 1 addition & 0 deletions tc/core/constants.h
@@ -29,6 +29,7 @@ constexpr auto kReadIdName = "read";
constexpr auto kWriteIdName = "write";
constexpr auto kSyncIdPrefix = "_sync_";
constexpr auto kWarpSyncIdPrefix = "_warpSync_";
constexpr auto kGridSyncIdPrefix = "_gridSync_";

} // namespace polyhedral
} // namespace tc
71 changes: 67 additions & 4 deletions tc/core/cuda/cuda.cc
@@ -30,7 +30,14 @@ DEFINE_bool(use_nvprof, false, "Start / stop nvprof");

namespace {

std::tuple<std::vector<std::string>, std::vector<size_t>> init() {
std::tuple<
std::vector<std::string>,
std::vector<size_t>,
std::vector<size_t>,
std::vector<size_t>,
std::vector<size_t>,
std::vector<size_t>>
init() {
int deviceCount = 0;
auto err_id = cudaGetDeviceCount(&deviceCount);
if (err_id == 35 or err_id == 30) {
@@ -44,14 +51,36 @@ std::tuple<std::vector<std::string>, std::vector<size_t>> init() {
}
std::vector<std::string> gpuNames;
std::vector<size_t> sharedMemSizes;
std::vector<size_t> sharedMemSizesPerSM;
std::vector<size_t> blocksPerSM;
std::vector<size_t> threadsPerSM;
std::vector<size_t> nbOfSM;
gpuNames.reserve(deviceCount);
for (int i = 0; i < deviceCount; ++i) {
cudaDeviceProp deviceProp;
TC_CUDA_RUNTIMEAPI_ENFORCE(cudaGetDeviceProperties(&deviceProp, i));
gpuNames.emplace_back(deviceProp.name);
sharedMemSizes.emplace_back(deviceProp.sharedMemPerBlock);
sharedMemSizesPerSM.emplace_back(deviceProp.sharedMemPerMultiprocessor);

// There is currently no way to query the maximum number of resident
// blocks per SM through the CUDA API, so it is derived from the
// compute capability instead. The formula stays valid as long as the
// number of blocks per SM does not decrease for compute capabilities
// beyond 6.0.
auto major = deviceProp.major;
blocksPerSM.emplace_back(major < 3 ? 8 : (major < 4 ? 16 : 32));

threadsPerSM.emplace_back(deviceProp.maxThreadsPerMultiProcessor);
nbOfSM.emplace_back(deviceProp.multiProcessorCount);
}
return std::make_tuple(gpuNames, sharedMemSizes);
return std::make_tuple(
gpuNames,
sharedMemSizes,
sharedMemSizesPerSM,
blocksPerSM,
threadsPerSM,
nbOfSM);
}

} // namespace
@@ -61,8 +90,13 @@ CudaGPUInfo& CudaGPUInfo::GPUInfo() {
static thread_local bool inited = false;
if (!inited) {
auto infos = init();
pInfo = std::unique_ptr<CudaGPUInfo>(
new CudaGPUInfo(std::get<0>(infos), std::get<1>(infos)));
pInfo = std::unique_ptr<CudaGPUInfo>(new CudaGPUInfo(
std::get<0>(infos),
std::get<1>(infos),
std::get<2>(infos),
std::get<3>(infos),
std::get<4>(infos),
std::get<5>(infos)));
inited = true;
}
return *pInfo;
@@ -102,4 +136,33 @@ size_t CudaGPUInfo::SharedMemorySize() const {
}
return sharedMemSizes_.at(CurrentGPUId());
}

size_t CudaGPUInfo::SharedMemorySizePerSM() const {
if (NumberGPUs() == 0) {
return 0; // no shared memory per sm if no GPUs
}
return sharedMemSizesPerSM_.at(CurrentGPUId());
}

size_t CudaGPUInfo::BlocksPerSM() const {
if (NumberGPUs() == 0) {
return 0; // no blocks per sm if no GPUs
}
return blocksPerSM_.at(CurrentGPUId());
}

size_t CudaGPUInfo::ThreadsPerSM() const {
if (NumberGPUs() == 0) {
return 0; // no threads per sm if no GPUs
}
return threadsPerSM_.at(CurrentGPUId());
}

size_t CudaGPUInfo::NbOfSM() const {
if (NumberGPUs() == 0) {
return 0; // no sm if no GPUs
}
return nbOfSM_.at(CurrentGPUId());
}

} // namespace tc
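The point of the new per-SM queries is that a grid-synchronizing kernel needs all of its blocks resident on the device at once. A minimal sketch of the hardware-level bound these accessors imply, with a helper name of our own invention (maxResidentBlocks) and ignoring per-kernel register and shared-memory pressure, which can only lower the real limit:

#include <algorithm>

#include "tc/core/cuda/cuda.h"

// Hardware upper bound on co-resident blocks for a given block size
// (blockSize > 0); the limit for a specific kernel should be taken from
// the CUDA occupancy API instead once that kernel is known.
inline size_t maxResidentBlocks(size_t blockSize) {
  const auto& info = tc::CudaGPUInfo::GPUInfo();
  size_t byBlockSlots = info.NbOfSM() * info.BlocksPerSM();
  size_t byThreadSlots = info.NbOfSM() * (info.ThreadsPerSM() / blockSize);
  return std::min(byBlockSlots, byThreadSlots);
}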
21 changes: 19 additions & 2 deletions tc/core/cuda/cuda.h
@@ -96,8 +96,17 @@ struct WithCudaDevice {
class CudaGPUInfo {
CudaGPUInfo(
const std::vector<std::string>& gpuNames,
const std::vector<size_t>& sharedMemSizes)
: gpuNames_(gpuNames), sharedMemSizes_(sharedMemSizes) {}
const std::vector<size_t>& sharedMemSizes,
const std::vector<size_t>& sharedMemSizesPerSM,
const std::vector<size_t>& blocksPerSM,
const std::vector<size_t>& threadsPerSM,
const std::vector<size_t>& nbOfSM)
: gpuNames_(gpuNames),
sharedMemSizes_(sharedMemSizes),
sharedMemSizesPerSM_(sharedMemSizesPerSM),
blocksPerSM_(blocksPerSM),
threadsPerSM_(threadsPerSM),
nbOfSM_(nbOfSM) {}

public:
static CudaGPUInfo& GPUInfo();
@@ -110,9 +119,17 @@ class CudaGPUInfo {
std::string GetGPUName(int id = -1) const;
std::string getCudaDeviceStr() const;
size_t SharedMemorySize() const;
size_t SharedMemorySizePerSM() const;
size_t BlocksPerSM() const;
size_t ThreadsPerSM() const;
size_t NbOfSM() const;

std::vector<std::string> gpuNames_;
std::vector<size_t> sharedMemSizes_;
std::vector<size_t> sharedMemSizesPerSM_;
std::vector<size_t> blocksPerSM_;
std::vector<size_t> threadsPerSM_;
std::vector<size_t> nbOfSM_;
};

struct CudaProfiler {
1 change: 1 addition & 0 deletions tc/core/cuda/cuda_backend.h
@@ -37,6 +37,7 @@ struct CudaCompilationResult {
std::vector<long> parameters;
Grid grid;
Block block;
bool useGridSync;
};

/**
6 changes: 6 additions & 0 deletions tc/core/cuda/cuda_libraries.h
@@ -61,6 +61,12 @@ __device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {}
#endif
)C";

constexpr auto gridSyncFunctions = R"C(
__device__ void __syncgrid() {
cudaCGSynchronize(cudaCGGetIntrinsicHandle(cudaCGScopeGrid),0);
}
)C";

constexpr auto mathFunctionDecl = R"C(

// BEGIN MATH FUNCTIONS FROM CUDA
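The __syncgrid() helper above calls the device-runtime intrinsics directly so it can be injected as plain source into NVRTC-compiled kernels. For comparison, the same grid-wide barrier is usually spelled with the cooperative groups header when compiling offline; the kernel below is only an illustrative sketch, not part of this PR:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void stencilStep(const float* a, float* b, float* c, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) b[i] = a[i] + 1.0f;        // phase 1: every block writes b
  cg::this_grid().sync();               // grid-wide barrier, same effect as __syncgrid()
  if (i > 0 && i < n) c[i] = b[i - 1];  // phase 2: reads b produced by another block
}

Either spelling is only valid when the kernel is launched cooperatively (cuLaunchCooperativeKernel or cudaLaunchCooperativeKernel), as done in cuda_rtc.cc below.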
70 changes: 55 additions & 15 deletions tc/core/cuda/cuda_rtc.cc
@@ -25,6 +25,7 @@
#include "tc/core/cuda/cuda_rtc.h"
#include "tc/core/flags.h"
#include "tc/core/scope_guard.h"
#include "tc/version/cuda_version.h"

namespace tc {
std::mutex nvrtc_mutex;
@@ -50,7 +51,8 @@ void CudaRTCFunction::clear() {

std::unique_ptr<CudaRTCFunction> CudaRTCFunction::Compile(
const std::string& name,
const std::string& source) {
const std::string& source,
bool useGridSync) {
std::unique_ptr<CudaRTCFunction> res(new CudaRTCFunction());
res->specializedName = name;
res->cleared_ = false;
@@ -88,6 +90,9 @@ std::unique_ptr<CudaRTCFunction> CudaRTCFunction::Compile(
"-DNVRTC_CUB=1",
cudaHome.c_str(),
cubHome.c_str()};
if (useGridSync) {
nvrtcts.push_back("--relocatable-device-code=true");
}
if (FLAGS_debug_cuda) {
nvrtcts.push_back(nvrtc_debug_opts[0]);
nvrtcts.push_back(nvrtc_debug_opts[1]);
@@ -132,6 +137,7 @@ std::ostream& operator<<(std::ostream& os, const std::array<T, 3>& a) {
Duration CudaRTCFunction::Launch(
const std::array<size_t, 3>& grid,
const std::array<size_t, 3>& block,
bool useGridSync,
unsigned int shared_mem,
cudaStream_t stream,
std::vector<long> params,
@@ -143,8 +149,28 @@ Duration CudaRTCFunction::Launch(
if (perGpuModule_.count(dev) == 0) {
CUmodule module;
CUfunction function;
TC_CUDA_DRIVERAPI_ENFORCE(
cuModuleLoadDataEx(&module, nvrtc_ptx.data(), 0, 0, 0));
if (useGridSync) {
CUlinkState linkState;
TC_CUDA_DRIVERAPI_ENFORCE(cuLinkCreate(0, 0, 0, &linkState));
TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddFile(
linkState, CU_JIT_INPUT_LIBRARY, cuda_libdevrt_path, 0, 0, 0));
TC_CUDA_DRIVERAPI_ENFORCE(cuLinkAddData(
linkState,
CU_JIT_INPUT_PTX,
(void*)nvrtc_ptx.data(),
nvrtc_ptx.size(),
"device_code.ptx",
0,
0,
0));
size_t cubinSize;
void* cubin;
TC_CUDA_DRIVERAPI_ENFORCE(cuLinkComplete(linkState, &cubin, &cubinSize));
TC_CUDA_DRIVERAPI_ENFORCE(cuModuleLoadData(&module, cubin));
} else {
TC_CUDA_DRIVERAPI_ENFORCE(
cuModuleLoadDataEx(&module, nvrtc_ptx.data(), 0, 0, 0));
}
perGpuModule_.emplace(dev, module);
TC_CUDA_DRIVERAPI_ENFORCE(
cuModuleGetFunction(&function, module, specializedName.c_str()));
@@ -174,18 +200,32 @@ Duration CudaRTCFunction::Launch(
unsigned int by = block[1];
unsigned int bz = block[2];
auto launch = [&]() {
TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
perGpuKernel_.at(dev),
gx,
gy,
gz,
bx,
by,
bz,
shared_mem,
stream,
args_voidp.data(),
0));
if (useGridSync) {
TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchCooperativeKernel(
perGpuKernel_.at(dev),
gx,
gy,
gz,
bx,
by,
bz,
shared_mem,
stream,
args_voidp.data()));
} else {
TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
perGpuKernel_.at(dev),
gx,
gy,
gz,
bx,
by,
bz,
shared_mem,
stream,
args_voidp.data(),
0));
}
};

if (not profile) {
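Two things change on the driver-API path when useGridSync is set: the NVRTC output is compiled as relocatable device code and linked against the device runtime (the cuLinkCreate / cuLinkAddFile step with cuda_libdevrt_path above, which supplies the device-side implementation of cudaCGSynchronize), and the kernel is launched with cuLaunchCooperativeKernel, which fails when the grid cannot be fully resident on the device. A small sketch of a pre-launch check a caller might add using the driver occupancy API; the helper name and its placement are assumptions, not code from this PR:

// Returns true if gridBlocks blocks of blockSize threads can all be resident
// at once, which a cooperative (grid-synchronizing) launch requires.
bool cooperativeGridFits(
    CUfunction f,
    int blockSize,
    size_t gridBlocks,
    size_t smCount,
    size_t dynamicSharedMem) {
  int blocksPerSM = 0;
  TC_CUDA_DRIVERAPI_ENFORCE(cuOccupancyMaxActiveBlocksPerMultiprocessor(
      &blocksPerSM, f, blockSize, dynamicSharedMem));
  return gridBlocks <= static_cast<size_t>(blocksPerSM) * smCount;
}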
6 changes: 3 additions & 3 deletions tc/core/cuda/cuda_rtc.h
@@ -41,14 +41,14 @@ class CudaRTCFunction {
public:
~CudaRTCFunction();

static std::unique_ptr<CudaRTCFunction> Compile(
const std::string& name,
const std::string& source);
static std::unique_ptr<CudaRTCFunction>
Compile(const std::string& name, const std::string& source, bool useGridSync);

// if profile is set it returns the kernel runtime
Duration Launch(
const std::array<size_t, 3>& grid,
const std::array<size_t, 3>& block,
bool useGridSync,
unsigned int shared_mem,
cudaStream_t stream,
// by copy because we take an address to element when calling the kernel