//===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Support.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" using namespace mlir; MlirAttribute mlirAttributeGetNull() { return {nullptr}; } //===----------------------------------------------------------------------===// // Location attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsALocation(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// // Affine map attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAAffineMap(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { return wrap(AffineMapAttr::get(unwrap(map))); } MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValue()); } MlirTypeID mlirAffineMapAttrGetTypeID(void) { return wrap(AffineMapAttr::getTypeID()); } //===----------------------------------------------------------------------===// // Array attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsAArray(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, MlirAttribute const *elements) { SmallVector attrs; return wrap( ArrayAttr::get(unwrap(ctx), unwrapList(static_cast(numElements), elements, attrs))); } intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { return static_cast(llvm::cast(unwrap(attr)).size()); } MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { return wrap(llvm::cast(unwrap(attr)).getValue()[pos]); } MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); } //===----------------------------------------------------------------------===// // Dictionary attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsADictionary(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, MlirNamedAttribute const *elements) { SmallVector attributes; attributes.reserve(numElements); for (intptr_t i = 0; i < numElements; ++i) attributes.emplace_back(unwrap(elements[i].name), unwrap(elements[i].attribute)); return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); } intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { return static_cast(llvm::cast(unwrap(attr)).size()); } MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos) { NamedAttribute attribute = llvm::cast(unwrap(attr)).getValue()[pos]; return {wrap(attribute.getName()), wrap(attribute.getValue())}; } MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name) { return wrap(llvm::cast(unwrap(attr)).get(unwrap(name))); } MlirTypeID mlirDictionaryAttrGetTypeID(void) { return wrap(DictionaryAttr::getTypeID()); } 
//===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAFloat(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, double value) { return wrap(FloatAttr::get(unwrap(type), value)); } MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, double value) { return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value)); } double mlirFloatAttrGetValueDouble(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getValueAsDouble(); } MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); } //===----------------------------------------------------------------------===// // Integer attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAInteger(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { return wrap(IntegerAttr::get(unwrap(type), value)); } int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getInt(); } int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getSInt(); } uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getUInt(); } MlirTypeID mlirIntegerAttrGetTypeID(void) { return wrap(IntegerAttr::getTypeID()); } //===----------------------------------------------------------------------===// // Bool attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsABool(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { return wrap(BoolAttr::get(unwrap(ctx), value)); } bool mlirBoolAttrGetValue(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getValue(); } //===----------------------------------------------------------------------===// // Integer set attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirTypeID mlirIntegerSetAttrGetTypeID(void) { return wrap(IntegerSetAttr::getTypeID()); } //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAOpaque(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type) { return wrap( OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)), StringRef(data, dataLength), unwrap(type))); } MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { return wrap( llvm::cast(unwrap(attr)).getDialectNamespace().strref()); } MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getAttrData()); } MlirTypeID mlirOpaqueAttrGetTypeID(void) { return wrap(OpaqueAttr::getTypeID()); } //===----------------------------------------------------------------------===// // String attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsAString(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str))); } MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type))); } MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValue()); } MlirTypeID mlirStringAttrGetTypeID(void) { return wrap(StringAttr::getTypeID()); } //===----------------------------------------------------------------------===// // SymbolRef attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsASymbolRef(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, intptr_t numReferences, MlirAttribute const *references) { SmallVector refs; refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) refs.push_back(llvm::cast(unwrap(references[i]))); auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol)); return wrap(SymbolRefAttr::get(symbolAttr, refs)); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { return wrap( llvm::cast(unwrap(attr)).getRootReference().getValue()); } MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { return wrap( llvm::cast(unwrap(attr)).getLeafReference().getValue()); } intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { return static_cast( llvm::cast(unwrap(attr)).getNestedReferences().size()); } MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos) { return wrap( llvm::cast(unwrap(attr)).getNestedReferences()[pos]); } MlirTypeID mlirSymbolRefAttrGetTypeID(void) { return wrap(SymbolRefAttr::getTypeID()); } 
//===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol))); } MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAType(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirTypeAttrGet(MlirType type) { return wrap(TypeAttr::get(unwrap(type))); } MlirType mlirTypeAttrGetValue(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValue()); } MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); } //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAUnit(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirUnitAttrGet(MlirContext ctx) { return wrap(UnitAttr::get(unwrap(ctx))); } MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); } //===----------------------------------------------------------------------===// // Elements attributes. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsAElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { return wrap(llvm::cast(unwrap(attr)) .getValues()[llvm::ArrayRef(idxs, rank)]); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { return llvm::cast(unwrap(attr)) .isValidIndex(llvm::ArrayRef(idxs, rank)); } int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getNumElements(); } //===----------------------------------------------------------------------===// // Dense array attribute. //===----------------------------------------------------------------------===// MlirTypeID mlirDenseArrayAttrGetTypeID() { return wrap(DenseArrayAttr::getTypeID()); } //===----------------------------------------------------------------------===// // IsA support. //===----------------------------------------------------------------------===// bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI8Array(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI16Array(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI32Array(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI64Array(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseF32Array(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseF64Array(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// // Constructors. 
//===----------------------------------------------------------------------===// MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, int const *values) { SmallVector elements(values, values + size); return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements)); } MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size, int8_t const *values) { return wrap( DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); } MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size, int16_t const *values) { return wrap( DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); } MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, int32_t const *values) { return wrap( DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); } MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size, int64_t const *values) { return wrap( DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); } MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size, float const *values) { return wrap( DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); } MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, double const *values) { return wrap( DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); } //===----------------------------------------------------------------------===// // Accessors. //===----------------------------------------------------------------------===// intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { return llvm::cast(unwrap(attr)).size(); } //===----------------------------------------------------------------------===// // Indexed accessors. 
//===----------------------------------------------------------------------===// bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; } int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; } int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; } int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; } int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; } float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; } double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; } //===----------------------------------------------------------------------===// // Dense elements attribute. //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // IsA support. //===----------------------------------------------------------------------===// bool mlirAttributeIsADenseElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) { return wrap(DenseIntOrFPElementsAttr::getTypeID()); } //===----------------------------------------------------------------------===// // Constructors. 
//===----------------------------------------------------------------------===//

/// Creates a DenseElementsAttr of `shapedType` from the `numElements`
/// attributes at `elements`.
MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
                                       intptr_t numElements,
                                       MlirAttribute const *elements) {
  SmallVector<Attribute, 8> attributes;
  // Make the signed-to-unsigned count conversion explicit.
  return wrap(DenseElementsAttr::get(
      llvm::cast<ShapedType>(unwrap(shapedType)),
      unwrapList(static_cast<size_t>(numElements), elements, attributes)));
}

/// Creates a DenseElementsAttr of `shapedType` by interpreting the
/// `rawBufferSize` bytes at `rawBuffer`. Returns a null attribute if
/// DenseElementsAttr::isValidRawBuffer rejects the buffer for `shapedType`.
MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
                                                size_t rawBufferSize,
                                                const void *rawBuffer) {
  auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType));
  ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
                              rawBufferSize);
  // Out-parameter required by isValidRawBuffer; its value is not used here.
  bool isSplat = false;
  if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
                                           isSplat))
    return mlirAttributeGetNull();
  return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
}

/// Creates a splat DenseElementsAttr of `shapedType` whose every element is
/// the attribute `element`.
MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
                                            MlirAttribute element) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     unwrap(element)));
}

// Typed splat constructors: every element of `shapedType` is `element`.

MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
                                                bool element) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     element));
}
MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
                                                 uint8_t element) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     element));
}
MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
                                                int8_t element) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     element));
}
MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
                                                  uint32_t element) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     element));
}
MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
                                                 int32_t element) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     element));
}
MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
                                                  uint64_t element) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     element));
}
MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
                                                 int64_t element) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     element));
}
MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
                                                 float element) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     element));
}
MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
                                                  double element) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     element));
}

/// Creates a boolean DenseElementsAttr from `numElements` C booleans. Each
/// `int` is converted to `bool` first.
MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
                                           intptr_t numElements,
                                           const int *elements) {
  SmallVector<bool, 8> values(elements, elements + numElements);
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     values));
}

/// Creates a dense attribute with elements of the type deduced by templates.
template <typename T>
static MlirAttribute getDenseAttribute(MlirType shapedType,
                                       intptr_t numElements,
                                       const T *elements) {
  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     llvm::ArrayRef(elements, numElements)));
}

MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType,
                                            intptr_t numElements,
                                            const uint8_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
                                           intptr_t numElements,
                                           const int8_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
                                             intptr_t numElements,
                                             const uint16_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
                                            intptr_t numElements,
                                            const int16_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
                                             intptr_t numElements,
                                             const uint32_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
                                            intptr_t numElements,
                                            const int32_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
                                             intptr_t numElements,
                                             const uint64_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
                                            intptr_t numElements,
                                            const int64_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
                                            intptr_t numElements,
                                            const float *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
                                             intptr_t numElements,
                                             const double *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}

/// Builds a dense attribute of `shapedType` from `numElements` raw 16-bit
/// bit patterns by passing them to mlirDenseElementsAttrRawBufferGet.
/// Shared by the bfloat16 and float16 entry points.
static MlirAttribute getDenseAttributeFrom16BitPatterns(
    MlirType shapedType, intptr_t numElements, const uint16_t *elements) {
  size_t bufferSize = static_cast<size_t>(numElements) * sizeof(uint16_t);
  return mlirDenseElementsAttrRawBufferGet(
      shapedType, bufferSize, static_cast<const void *>(elements));
}

MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
                                               intptr_t numElements,
                                               const uint16_t *elements) {
  return getDenseAttributeFrom16BitPatterns(shapedType, numElements, elements);
}

MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType,
                                              intptr_t numElements,
                                              const uint16_t *elements) {
  return getDenseAttributeFrom16BitPatterns(shapedType, numElements, elements);
}

/// Creates a string DenseElementsAttr of `shapedType` from the `numElements`
/// strings at `strs`.
MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
                                             intptr_t numElements,
                                             MlirStringRef *strs) {
  SmallVector<StringRef, 8> values;
  values.reserve(numElements);
  for (intptr_t i = 0; i < numElements; ++i)
    values.push_back(unwrap(strs[i]));

  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
                                     values));
}

/// Returns the DenseElementsAttr `attr` reshaped to `shapedType`.
MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
                                              MlirType shapedType) {
  return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
                  .reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
}

//===----------------------------------------------------------------------===//
// Splat accessors.
//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { return llvm::cast(unwrap(attr)).isSplat(); } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { return wrap( llvm::cast(unwrap(attr)).getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getSplatValue(); } int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getSplatValue(); } uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getSplatValue(); } int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getSplatValue(); } uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getSplatValue(); } int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getSplatValue(); } uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getSplatValue(); } float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getSplatValue(); } double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getSplatValue(); } MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { return wrap( llvm::cast(unwrap(attr)).getSplatValue()); } //===----------------------------------------------------------------------===// // Indexed accessors. 
//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( llvm::cast(unwrap(attr)).getValues()[pos]); } //===----------------------------------------------------------------------===// // Raw data accessors. 
//===----------------------------------------------------------------------===// const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { return static_cast( llvm::cast(unwrap(attr)).getRawData().data()); } //===----------------------------------------------------------------------===// // Resource blob attributes. //===----------------------------------------------------------------------===// bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, size_t dataAlignment, bool dataIsMutable, void (*deleter)(void *userData, const void *data, size_t size, size_t align), void *userData) { AsmResourceBlob::DeleterFn cppDeleter = {}; if (deleter) { cppDeleter = [deleter, userData](void *data, size_t size, size_t align) { deleter(userData, data, size, align); }; } AsmResourceBlob blob( llvm::ArrayRef(static_cast(data), dataLength), dataAlignment, std::move(cppDeleter), dataIsMutable); return wrap( DenseResourceElementsAttr::get(llvm::cast(unwrap(shapedType)), unwrap(name), std::move(blob))); } template static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, intptr_t numElements, const T *elements) { return wrap(U::get(llvm::cast(unwrap(shapedType)), unwrap(name), UnmanagedAsmResourceBlob::allocateInferAlign( llvm::ArrayRef(elements, numElements)))); } MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int *elements) { return getDenseResource(shapedType, name, numElements, elements); } MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const uint8_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet( MlirType shapedType, 
MlirStringRef name, intptr_t numElements, const uint16_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const uint32_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const uint64_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int8_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int16_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int32_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int64_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const float *elements) { return getDenseResource(shapedType, name, numElements, elements); } MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const double *elements) { return getDenseResource(shapedType, name, numElements, elements); } template static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { return 
(*llvm::cast(unwrap(attr)).tryGetAsArrayRef())[pos]; } bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { return getDenseResourceVal(attr, pos); } //===----------------------------------------------------------------------===// // Sparse elements attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsASparseElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, MlirAttribute denseIndices, MlirAttribute denseValues) { return wrap(SparseElementsAttr::get( llvm::cast(unwrap(shapedType)), llvm::cast(unwrap(denseIndices)), llvm::cast(unwrap(denseValues)))); } MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getIndices()); } MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValues()); } MlirTypeID mlirSparseElementsAttrGetTypeID(void) { return wrap(SparseElementsAttr::getTypeID()); } //===----------------------------------------------------------------------===// // Strided layout attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAStridedLayout(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides, const int64_t *strides) { return wrap(StridedLayoutAttr::get(unwrap(ctx), offset, ArrayRef(strides, numStrides))); } int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getOffset(); } intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) { return static_cast( llvm::cast(unwrap(attr)).getStrides().size()); } int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getStrides()[pos]; } MlirTypeID mlirStridedLayoutAttrGetTypeID(void) { return wrap(StridedLayoutAttr::getTypeID()); }