Line data Source code
1 : //===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 :
10 : // ATen: modified from llvm::ArrayRef.
11 : // removed llvm-specific functionality
12 : // removed some implicit const -> non-const conversions that rely on
13 : // complicated std::enable_if meta-programming
14 : // removed a bunch of slice variants for simplicity...
15 :
16 : #pragma once
17 :
18 : #include <c10/util/C++17.h>
19 : #include <c10/util/Deprecated.h>
20 : #include <c10/util/Exception.h>
21 : #include <c10/util/SmallVector.h>
22 :
23 : #include <array>
24 : #include <iterator>
25 : #include <vector>
26 :
27 : namespace c10 {
28 : /// ArrayRef - Represent a constant reference to an array (0 or more elements
29 : /// consecutively in memory), i.e. a start pointer and a length. It allows
30 : /// various APIs to take consecutive elements easily and conveniently.
31 : ///
32 : /// This class does not own the underlying data, it is expected to be used in
33 : /// situations where the data resides in some other buffer, whose lifetime
34 : /// extends past that of the ArrayRef. For this reason, it is not in general
35 : /// safe to store an ArrayRef.
36 : ///
37 : /// This is intended to be trivially copyable, so it should be passed by
38 : /// value.
39 : template <typename T>
40 : class ArrayRef final {
41 : public:
42 : using iterator = const T*;
43 : using const_iterator = const T*;
44 : using size_type = size_t;
45 : using value_type = T;
46 :
47 : using reverse_iterator = std::reverse_iterator<iterator>;
48 :
49 : private:
50 : /// The start of the array, in an external buffer.
51 : const T* Data;
52 :
53 : /// The number of elements.
54 : size_type Length;
55 :
56 106536 : void debugCheckNullptrInvariant() {
57 106536 : TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
58 : Data != nullptr || Length == 0,
59 : "created ArrayRef with nullptr and non-zero length! c10::optional relies on this being illegal");
60 106536 : }
61 :
62 : public:
63 : /// @name Constructors
64 : /// @{
65 :
66 : /// Construct an empty ArrayRef.
67 : /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
68 :
69 : /// Construct an ArrayRef from a single element.
70 : // TODO Make this explicit
71 90643162 : constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
72 :
73 : /// Construct an ArrayRef from a pointer and length.
74 106536 : C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* data, size_t length)
75 106536 : : Data(data), Length(length) {
76 106536 : debugCheckNullptrInvariant();
77 106536 : }
78 :
79 : /// Construct an ArrayRef from a range.
80 : C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* begin, const T* end)
81 : : Data(begin), Length(end - begin) {
82 : debugCheckNullptrInvariant();
83 : }
84 :
85 : /// Construct an ArrayRef from a SmallVector. This is templated in order to
86 : /// avoid instantiating SmallVectorTemplateCommon<T> whenever we
87 : /// copy-construct an ArrayRef.
88 : template <typename U>
89 : /* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
90 : : Data(Vec.data()), Length(Vec.size()) {
91 : debugCheckNullptrInvariant();
92 : }
93 :
94 : template <
95 : typename Container,
96 : typename = std::enable_if_t<std::is_same<
97 : std::remove_const_t<decltype(std::declval<Container>().data())>,
98 : T*>::value>>
99 : /* implicit */ ArrayRef(const Container& container)
100 : : Data(container.data()), Length(container.size()) {
101 : debugCheckNullptrInvariant();
102 : }
103 :
104 : /// Construct an ArrayRef from a std::vector.
105 : // The enable_if stuff here makes sure that this isn't used for
106 : // std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
107 : // bitfield.
108 : template <typename A>
109 415512 : /* implicit */ ArrayRef(const std::vector<T, A>& Vec)
110 415512 : : Data(Vec.data()), Length(Vec.size()) {
111 : static_assert(
112 : !std::is_same<T, bool>::value,
113 : "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
114 415512 : }
115 :
116 : /// Construct an ArrayRef from a std::array
117 : template <size_t N>
118 : /* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
119 : : Data(Arr.data()), Length(N) {}
120 :
121 : /// Construct an ArrayRef from a C array.
122 : template <size_t N>
123 : /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
124 :
125 : /// Construct an ArrayRef from a std::initializer_list.
126 3934 : /* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
127 3934 : : Data(
128 3934 : std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
129 3934 : : std::begin(Vec)),
130 3934 : Length(Vec.size()) {}
131 :
132 : /// @}
133 : /// @name Simple Operations
134 : /// @{
135 :
136 17786 : constexpr iterator begin() const {
137 17786 : return Data;
138 : }
139 17786 : constexpr iterator end() const {
140 17786 : return Data + Length;
141 : }
142 :
143 : // These are actually the same as iterator, since ArrayRef only
144 : // gives you const iterators.
145 : constexpr const_iterator cbegin() const {
146 : return Data;
147 : }
148 : constexpr const_iterator cend() const {
149 : return Data + Length;
150 : }
151 :
152 : constexpr reverse_iterator rbegin() const {
153 : return reverse_iterator(end());
154 : }
155 : constexpr reverse_iterator rend() const {
156 : return reverse_iterator(begin());
157 : }
158 :
159 : /// empty - Check if the array is empty.
160 : constexpr bool empty() const {
161 : return Length == 0;
162 : }
163 :
164 17910 : constexpr const T* data() const {
165 17910 : return Data;
166 : }
167 :
168 : /// size - Get the array size.
169 419068 : constexpr size_t size() const {
170 419068 : return Length;
171 : }
172 :
173 : /// front - Get the first element.
174 : C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const {
175 : TORCH_CHECK(
176 : !empty(), "ArrayRef: attempted to access front() of empty list");
177 : return Data[0];
178 : }
179 :
180 : /// back - Get the last element.
181 : C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& back() const {
182 : TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
183 : return Data[Length - 1];
184 : }
185 :
186 : /// equals - Check for element-wise equality.
187 : constexpr bool equals(ArrayRef RHS) const {
188 : return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
189 : }
190 :
191 : /// slice(n, m) - Take M elements of the array starting at element N
192 : C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N, size_t M)
193 : const {
194 : TORCH_CHECK(
195 : N + M <= size(),
196 : "ArrayRef: invalid slice, N = ",
197 : N,
198 : "; M = ",
199 : M,
200 : "; size = ",
201 : size());
202 : return ArrayRef<T>(data() + N, M);
203 : }
204 :
205 : /// slice(n) - Chop off the first N elements of the array.
206 : C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N) const {
207 : TORCH_CHECK(
208 : N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size());
209 : return slice(N, size() - N);
210 : }
211 :
212 : /// @}
213 : /// @name Operator Overloads
214 : /// @{
215 88590 : constexpr const T& operator[](size_t Index) const {
216 88590 : return Data[Index];
217 : }
218 :
219 : /// Vector compatibility
220 : C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& at(size_t Index) const {
221 : TORCH_CHECK(
222 : Index < Length,
223 : "ArrayRef: invalid index Index = ",
224 : Index,
225 : "; Length = ",
226 : Length);
227 : return Data[Index];
228 : }
229 :
230 : /// Disallow accidental assignment from a temporary.
231 : ///
232 : /// The declaration here is extra complicated so that "arrayRef = {}"
233 : /// continues to select the move assignment operator.
234 : template <typename U>
235 : typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type&
236 : operator=(U&& Temporary) = delete;
237 :
238 : /// Disallow accidental assignment from a temporary.
239 : ///
240 : /// The declaration here is extra complicated so that "arrayRef = {}"
241 : /// continues to select the move assignment operator.
242 : template <typename U>
243 : typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type&
244 : operator=(std::initializer_list<U>) = delete;
245 :
246 : /// @}
247 : /// @name Expensive Operations
248 : /// @{
249 : std::vector<T> vec() const {
250 : return std::vector<T>(Data, Data + Length);
251 : }
252 :
253 : /// @}
254 : };
255 :
256 : template <typename T>
257 36 : std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
258 36 : int i = 0;
259 36 : out << "[";
260 108 : for (const auto& e : list) {
261 72 : if (i++ > 0)
262 36 : out << ", ";
263 72 : out << e;
264 : }
265 36 : out << "]";
266 36 : return out;
267 : }
268 :
269 : /// @name ArrayRef Convenience constructors
270 : /// @{
271 :
272 : /// Construct an ArrayRef from a single element.
273 : template <typename T>
274 : ArrayRef<T> makeArrayRef(const T& OneElt) {
275 : return OneElt;
276 : }
277 :
278 : /// Construct an ArrayRef from a pointer and length.
279 : template <typename T>
280 : ArrayRef<T> makeArrayRef(const T* data, size_t length) {
281 : return ArrayRef<T>(data, length);
282 : }
283 :
284 : /// Construct an ArrayRef from a range.
285 : template <typename T>
286 : ArrayRef<T> makeArrayRef(const T* begin, const T* end) {
287 : return ArrayRef<T>(begin, end);
288 : }
289 :
290 : /// Construct an ArrayRef from a SmallVector.
291 : template <typename T>
292 : ArrayRef<T> makeArrayRef(const SmallVectorImpl<T>& Vec) {
293 : return Vec;
294 : }
295 :
296 : /// Construct an ArrayRef from a SmallVector.
297 : template <typename T, unsigned N>
298 : ArrayRef<T> makeArrayRef(const SmallVector<T, N>& Vec) {
299 : return Vec;
300 : }
301 :
302 : /// Construct an ArrayRef from a std::vector.
303 : template <typename T>
304 : ArrayRef<T> makeArrayRef(const std::vector<T>& Vec) {
305 : return Vec;
306 : }
307 :
308 : /// Construct an ArrayRef from a std::array.
309 : template <typename T, std::size_t N>
310 : ArrayRef<T> makeArrayRef(const std::array<T, N>& Arr) {
311 : return Arr;
312 : }
313 :
314 : /// Construct an ArrayRef from an ArrayRef (no-op) (const)
315 : template <typename T>
316 : ArrayRef<T> makeArrayRef(const ArrayRef<T>& Vec) {
317 : return Vec;
318 : }
319 :
320 : /// Construct an ArrayRef from an ArrayRef (no-op)
321 : template <typename T>
322 : ArrayRef<T>& makeArrayRef(ArrayRef<T>& Vec) {
323 : return Vec;
324 : }
325 :
326 : /// Construct an ArrayRef from a C array.
327 : template <typename T, size_t N>
328 : ArrayRef<T> makeArrayRef(const T (&Arr)[N]) {
329 : return ArrayRef<T>(Arr);
330 : }
331 :
332 : // WARNING: Template argument deduction will NOT perform implicit
333 : // conversions to get you to a c10::ArrayRef, which is why we need so
334 : // many overloads.
335 :
336 : template <typename T>
337 : bool operator==(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
338 : return a1.equals(a2);
339 : }
340 :
341 : template <typename T>
342 : bool operator!=(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
343 : return !a1.equals(a2);
344 : }
345 :
346 : template <typename T>
347 : bool operator==(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
348 : return c10::ArrayRef<T>(a1).equals(a2);
349 : }
350 :
351 : template <typename T>
352 : bool operator!=(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
353 : return !c10::ArrayRef<T>(a1).equals(a2);
354 : }
355 :
356 : template <typename T>
357 : bool operator==(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
358 : return a1.equals(c10::ArrayRef<T>(a2));
359 : }
360 :
361 : template <typename T>
362 : bool operator!=(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
363 : return !a1.equals(c10::ArrayRef<T>(a2));
364 : }
365 :
// IntArrayRef: shorthand for an ArrayRef whose elements are int64_t.
366 : using IntArrayRef = ArrayRef<int64_t>;
367 :
368 : // This alias is deprecated because it doesn't make ownership
369 : // semantics obvious. Use IntArrayRef instead!
// NOTE(review): the expansion of C10_DEFINE_DEPRECATED_USING is not visible
// here; presumably it declares IntList as a deprecated alias of
// ArrayRef<int64_t> -- confirm against c10/util/Deprecated.h.
370 : C10_DEFINE_DEPRECATED_USING(IntList, ArrayRef<int64_t>)
371 :
372 : } // namespace c10
|