Skip to main content

code-style

代码风格

// X-macro table pairing each buffer data type with its torch scalar type.
//
// The caller supplies a two-argument macro F; it is expanded once per row as
// F(<TYPE_* value>, <torch scalar type>), in the order listed below. Rows are
// not separated by commas or semicolons, so F must expand to a complete
// construct on its own (e.g. a full `case ...: {...}` statement).
//
// NOTE(review): TORCH_FP8_E4M3_TYPE is not defined in this snippet; it is
// presumably an alias for the torch FP8 E4M3 scalar type -- confirm where it
// is defined.
#define FOREACH_BUFFER_TORCH_TYPE_MAP(F)                                                                               \
F(TYPE_UINT8, torch::kByte) \
F(TYPE_INT8, torch::kChar) \
F(TYPE_INT16, torch::kShort) \
F(TYPE_INT32, torch::kInt) \
F(TYPE_INT64, torch::kLong) \
F(TYPE_FP16, torch::kHalf) \
F(TYPE_FP32, torch::kFloat) \
F(TYPE_FP64, torch::kDouble) \
F(TYPE_BOOL, torch::kBool) \
F(TYPE_BF16, torch::kBFloat16) \
F(TYPE_FP8_E4M3, TORCH_FP8_E4M3_TYPE)

/// Translates a torch dtype descriptor into the matching DataType value.
///
/// The supported pairs are exactly the rows of FOREACH_BUFFER_TORCH_TYPE_MAP;
/// each row is expanded into one `case` of the switch below.
///
/// @param dtype  torch type descriptor to translate.
/// @return the DataType paired with `dtype`'s scalar type in the map.
/// @throws std::runtime_error if the scalar type has no row in the map. Before
///         throwing, a stack trace is printed and an error is logged.
inline DataType torchDTypeToDataType(caffe2::TypeMeta dtype) {
// Expands one (DataType, torch scalar type) row into a returning case label.
#define TYPE_CASE(type, torch_type) \
    case torch_type: { \
        return type; \
    }

    // Resolve the scalar type once; it is needed for dispatch and diagnostics.
    const auto scalar_type = dtype.toScalarType();

    switch (scalar_type) {
        FOREACH_BUFFER_TORCH_TYPE_MAP(TYPE_CASE);
        default: {
            // Convert explicitly to int: a scoped enum must not be handed to a
            // printf-style "%d" as-is, and int is what std::to_string formats.
            const int scalar_type_id = static_cast<int>(scalar_type);
            printStackTrace();
            RTP_LLM_LOG_ERROR("Unsupported data type: [%d]%s", scalar_type_id, dtype.name().data());
            throw std::runtime_error("Unsupported data type " + std::to_string(scalar_type_id));
        }
    }

#undef TYPE_CASE
}