Line data Source code
1 : // This file defines OptionalArrayRef<T>, a class that has almost the same
2 : // exact functionality as c10::optional<ArrayRef<T>>, except that its
3 : // converting constructor fixes a dangling pointer issue.
4 : //
5 : // The implicit converting constructor of both c10::optional<ArrayRef<T>> and
6 : // std::optional<ArrayRef<T>> can cause the underlying ArrayRef<T> to store
7 : // a dangling pointer. OptionalArrayRef<T> prevents this by wrapping
8 : // a c10::optional<ArrayRef<T>> and fixing the constructor implementation.
9 : //
10 : // See https://github.com/pytorch/pytorch/issues/63645 for more on this.
11 :
12 : #pragma once
13 :
14 : #include <c10/util/ArrayRef.h>
15 : #include <c10/util/Optional.h>
16 :
17 : namespace c10 {
18 :
19 : template <typename T>
20 : class OptionalArrayRef final {
21 : public:
22 : // Constructors
23 :
24 : constexpr OptionalArrayRef() noexcept = default;
25 :
26 16480 : constexpr OptionalArrayRef(nullopt_t) noexcept {}
27 :
28 : OptionalArrayRef(const OptionalArrayRef& other) = default;
29 :
30 : OptionalArrayRef(OptionalArrayRef&& other) = default;
31 :
32 : constexpr OptionalArrayRef(const optional<ArrayRef<T>>& other) noexcept
33 : : wrapped_opt_array_ref(other) {}
34 :
35 : constexpr OptionalArrayRef(optional<ArrayRef<T>>&& other) noexcept
36 : : wrapped_opt_array_ref(other) {}
37 :
38 14194 : constexpr OptionalArrayRef(const T& value) noexcept
39 14194 : : wrapped_opt_array_ref(value) {}
40 :
41 : template <
42 : typename U = ArrayRef<T>,
43 : std::enable_if_t<
44 : !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
45 : !std::is_same<std::decay_t<U>, in_place_t>::value &&
46 : std::is_constructible<ArrayRef<T>, U&&>::value &&
47 : std::is_convertible<U&&, ArrayRef<T>>::value &&
48 : !std::is_convertible<U&&, T>::value,
49 : bool> = false>
50 : constexpr OptionalArrayRef(U&& value) noexcept(
51 : std::is_nothrow_constructible<ArrayRef<T>, U&&>::value)
52 : : wrapped_opt_array_ref(value) {}
53 :
54 : template <
55 : typename U = ArrayRef<T>,
56 : std::enable_if_t<
57 : !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
58 : !std::is_same<std::decay_t<U>, in_place_t>::value &&
59 : std::is_constructible<ArrayRef<T>, U&&>::value &&
60 : !std::is_convertible<U&&, ArrayRef<T>>::value,
61 : bool> = false>
62 : constexpr explicit OptionalArrayRef(U&& value) noexcept(
63 : std::is_nothrow_constructible<ArrayRef<T>, U&&>::value)
64 : : wrapped_opt_array_ref(value) {}
65 :
66 : template <typename... Args>
67 : constexpr explicit OptionalArrayRef(in_place_t ip, Args&&... args) noexcept
68 : : wrapped_opt_array_ref(ip, args...) {}
69 :
70 : template <typename U, typename... Args>
71 : constexpr explicit OptionalArrayRef(
72 : in_place_t ip,
73 : std::initializer_list<U> il,
74 : Args&&... args)
75 : : wrapped_opt_array_ref(ip, il, args...) {}
76 :
77 : constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
78 : : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
79 :
80 : // Destructor
81 :
82 : ~OptionalArrayRef() = default;
83 :
84 : // Assignment
85 :
86 : constexpr OptionalArrayRef& operator=(nullopt_t) noexcept {
87 : wrapped_opt_array_ref = c10::nullopt;
88 : return *this;
89 : }
90 :
91 : OptionalArrayRef& operator=(const OptionalArrayRef& other) = default;
92 :
93 : OptionalArrayRef& operator=(OptionalArrayRef&& other) = default;
94 :
95 : constexpr OptionalArrayRef& operator=(
96 : const optional<ArrayRef<T>>& other) noexcept {
97 : wrapped_opt_array_ref = other;
98 : return *this;
99 : }
100 :
101 : constexpr OptionalArrayRef& operator=(
102 : optional<ArrayRef<T>>&& other) noexcept {
103 : wrapped_opt_array_ref = other;
104 : return *this;
105 : }
106 :
107 : template <typename U = ArrayRef<T>>
108 : constexpr std::enable_if_t<
109 : !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
110 : std::is_constructible<ArrayRef<T>, U&&>::value &&
111 : std::is_assignable<ArrayRef<T>&, U&&>::value,
112 : OptionalArrayRef&>
113 : operator=(U&& value) noexcept(
114 : std::is_nothrow_constructible<ArrayRef<T>, U&&>::value&&
115 : std::is_nothrow_assignable<ArrayRef<T>&, U&&>::value) {
116 : wrapped_opt_array_ref = value;
117 : return *this;
118 : }
119 :
120 : // Observers
121 :
122 : constexpr ArrayRef<T>* operator->() noexcept {
123 : return &wrapped_opt_array_ref.value();
124 : }
125 :
126 : constexpr const ArrayRef<T>* operator->() const noexcept {
127 : return &wrapped_opt_array_ref.value();
128 : }
129 :
130 : constexpr ArrayRef<T>& operator*() & noexcept {
131 : return wrapped_opt_array_ref.value();
132 : }
133 :
134 : constexpr const ArrayRef<T>& operator*() const& noexcept {
135 : return wrapped_opt_array_ref.value();
136 : }
137 :
138 : constexpr ArrayRef<T>&& operator*() && noexcept {
139 : return std::move(wrapped_opt_array_ref.value());
140 : }
141 :
142 : constexpr const ArrayRef<T>&& operator*() const&& noexcept {
143 : return std::move(wrapped_opt_array_ref.value());
144 : }
145 :
146 : constexpr explicit operator bool() const noexcept {
147 : return wrapped_opt_array_ref.has_value();
148 : }
149 :
150 : constexpr bool has_value() const noexcept {
151 : return wrapped_opt_array_ref.has_value();
152 : }
153 :
154 : constexpr ArrayRef<T>& value() & {
155 : return wrapped_opt_array_ref.value();
156 : }
157 :
158 : constexpr const ArrayRef<T>& value() const& {
159 : return wrapped_opt_array_ref.value();
160 : }
161 :
162 : constexpr ArrayRef<T>&& value() && {
163 : return std::move(wrapped_opt_array_ref.value());
164 : }
165 :
166 : constexpr const ArrayRef<T>&& value() const&& {
167 : return std::move(wrapped_opt_array_ref.value());
168 : }
169 :
170 : template <typename U>
171 : constexpr std::
172 : enable_if_t<std::is_convertible<U&&, ArrayRef<T>>::value, ArrayRef<T>>
173 : value_or(U&& default_value) const& {
174 : return wrapped_opt_array_ref.value_or(default_value);
175 : }
176 :
177 : template <typename U>
178 : constexpr std::
179 : enable_if_t<std::is_convertible<U&&, ArrayRef<T>>::value, ArrayRef<T>>
180 : value_or(U&& default_value) && {
181 : return wrapped_opt_array_ref.value_or(default_value);
182 : }
183 :
184 : // Modifiers
185 :
186 : constexpr void swap(OptionalArrayRef& other) noexcept {
187 : std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref);
188 : }
189 :
190 : constexpr void reset() noexcept {
191 : wrapped_opt_array_ref.reset();
192 : }
193 :
194 : template <typename... Args>
195 : constexpr std::enable_if_t<
196 : std::is_constructible<ArrayRef<T>, Args&&...>::value,
197 : ArrayRef<T>&>
198 : emplace(Args&&... args) noexcept(
199 : std::is_nothrow_constructible<ArrayRef<T>, Args&&...>::value) {
200 : return wrapped_opt_array_ref.emplace(args...);
201 : }
202 :
203 : template <typename U, typename... Args>
204 : constexpr ArrayRef<T>& emplace(
205 : std::initializer_list<U> il,
206 : Args&&... args) noexcept {
207 : return wrapped_opt_array_ref.emplace(il, args...);
208 : }
209 :
210 : private:
211 : optional<ArrayRef<T>> wrapped_opt_array_ref;
212 : };
213 :
214 : using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
215 :
216 : inline bool operator==(
217 : const OptionalIntArrayRef& a1,
218 : const IntArrayRef& other) {
219 : if (!a1.has_value()) {
220 : return false;
221 : }
222 : return a1.value() == other;
223 : }
224 :
225 : inline bool operator==(
226 : const c10::IntArrayRef& a1,
227 : const c10::OptionalIntArrayRef& a2) {
228 : return a2 == a1;
229 : }
230 :
231 : } // namespace c10
|