//===- Tensor.cpp - C API for SparseTensor dialect ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/SparseTensor.h" #include "mlir-c/IR.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Support/LLVM.h" using namespace llvm; using namespace mlir::sparse_tensor; MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, mlir::sparse_tensor::SparseTensorDialect) // Ensure the C-API enums are int-castable to C++ equivalents. static_assert(static_cast(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == static_cast(LevelType::Dense) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == static_cast(LevelType::Compressed) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) == static_cast(LevelType::CompressedNu) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) == static_cast(LevelType::CompressedNo) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) == static_cast(LevelType::CompressedNuNo) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == static_cast(LevelType::Singleton) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) == static_cast(LevelType::SingletonNu) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) == static_cast(LevelType::SingletonNo) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) == static_cast(LevelType::SingletonNuNo), "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch"); bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { return isa(unwrap(attr)); } MlirAttribute mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank, MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, MlirAffineMap lvlToDim, int posWidth, int crdWidth) { SmallVector cppLvlTypes; cppLvlTypes.reserve(lvlRank); for (intptr_t l = 0; l < lvlRank; ++l) cppLvlTypes.push_back(static_cast(lvlTypes[l])); return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth, crdWidth)); } MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) { return wrap(cast(unwrap(attr)).getDimToLvl()); } MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) { return wrap(cast(unwrap(attr)).getLvlToDim()); } intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { return cast(unwrap(attr)).getLvlRank(); } MlirSparseTensorLevelType mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) { return static_cast( cast(unwrap(attr)).getLvlType(lvl)); } int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { return cast(unwrap(attr)).getPosWidth(); } int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { return cast(unwrap(attr)).getCrdWidth(); }