//===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "llvm/ADT/SmallVector.h" #include "gmock/gmock.h" using namespace ::mlir; using namespace ::testing; TEST(DynamicMemRef, rankZero) { int data = 57; StridedMemRefType memRef; memRef.basePtr = &data; memRef.data = &data; memRef.offset = 0; DynamicMemRefType dynamicMemRef(memRef); llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); EXPECT_THAT(values, ElementsAre(57)); } TEST(DynamicMemRef, rankOne) { std::array data; for (size_t i = 0; i < data.size(); ++i) { data[i] = i; } StridedMemRefType memRef; memRef.basePtr = data.data(); memRef.data = data.data(); memRef.offset = 0; memRef.sizes[0] = 3; memRef.strides[0] = 1; DynamicMemRefType dynamicMemRef(memRef); llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); EXPECT_THAT(values, ElementsAreArray(data)); for (int64_t i = 0; i < 3; ++i) { EXPECT_EQ(*dynamicMemRef[i], data[i]); } } TEST(DynamicMemRef, rankTwo) { std::array data; for (size_t i = 0; i < data.size(); ++i) { data[i] = i; } StridedMemRefType memRef; memRef.basePtr = data.data(); memRef.data = data.data(); memRef.offset = 0; memRef.sizes[0] = 2; memRef.sizes[1] = 3; memRef.strides[0] = 3; memRef.strides[1] = 1; DynamicMemRefType dynamicMemRef(memRef); llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); EXPECT_THAT(values, ElementsAreArray(data)); } TEST(DynamicMemRef, rankThree) { std::array data; for (size_t i = 0; i < data.size(); ++i) { data[i] = i; } StridedMemRefType memRef; memRef.basePtr = data.data(); memRef.data = data.data(); memRef.offset = 0; memRef.sizes[0] = 2; memRef.sizes[1] = 3; memRef.sizes[2] = 4; memRef.strides[0] = 12; memRef.strides[1] = 4; memRef.strides[2] = 1; DynamicMemRefType dynamicMemRef(memRef); llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); EXPECT_THAT(values, ElementsAreArray(data)); } TEST(DynamicMemRef, rankOneWithOffset) { constexpr int offset = 4; std::array buffer; for (size_t i = 0; i < buffer.size(); ++i) { buffer[i] = i; } StridedMemRefType memRef; memRef.basePtr = buffer.data(); memRef.data = buffer.data(); memRef.offset = offset; memRef.sizes[0] = 3; memRef.strides[0] = 1; DynamicMemRefType dynamicMemRef(memRef); llvm::SmallVector values(dynamicMemRef.begin(), dynamicMemRef.end()); for (int64_t i = 0; i < 3; ++i) { EXPECT_EQ(values[i], buffer[offset + i]); EXPECT_EQ(*dynamicMemRef[i], buffer[offset + i]); } }