Line data Source code
1 : #pragma once
2 :
#include <c10/util/Exception.h>
#include <c10/util/MaybeOwned.h>
#include <atomic>
#include <climits>
#include <functional>
#include <memory>
8 :
9 : namespace pybind11 {
10 : template <typename, typename...>
11 : class class_;
12 : }
13 :
14 : namespace c10 {
15 : class intrusive_ptr_target;
namespace raw {
// Forward declarations of the raw-pointer refcounting entry points; they are
// friends of intrusive_ptr_target so they can touch the counters directly.
namespace weak_intrusive_ptr {
inline void incref(intrusive_ptr_target* self);
}
namespace intrusive_ptr {
inline void incref(intrusive_ptr_target* self);
}

// constructor tag used by intrusive_ptr constructors
struct DontIncreaseRefcount {};
} // namespace raw
27 : /**
28 : * intrusive_ptr<T> is an alternative to shared_ptr<T> that has better
29 : * performance because it does the refcounting intrusively
30 : * (i.e. in a member of the object itself).
31 : * Your class T needs to inherit from intrusive_ptr_target to allow it to be
32 : * used in an intrusive_ptr<T>. Your class's constructor should not allow
33 : *`this` to escape to other threads or create an intrusive_ptr from `this`.
34 : */
35 :
36 : // Note [Stack allocated intrusive_ptr_target safety]
37 : // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
38 : // A well known problem with std::enable_shared_from_this is that it
39 : // allows you to create a std::shared_ptr from a stack allocated object,
40 : // which is totally bogus because the object will die once you return
41 : // from the stack. In intrusive_ptr, we can detect that this has occurred,
42 : // because we set the refcount/weakcount of objects which inherit from
43 : // intrusive_ptr_target to zero, *unless* we can prove that the object
44 : // was dynamically allocated (e.g., via make_intrusive).
45 : //
46 : // Thus, whenever you transmute a T* into a intrusive_ptr<T>, we check
47 : // and make sure that the refcount isn't zero (or, a more subtle
48 : // test for weak_intrusive_ptr<T>, for which the refcount may validly
49 : // be zero, but the weak refcount better not be zero), because that
50 : // tells us if the object was allocated by us. If it wasn't, no
51 : // intrusive_ptr for you!
52 :
// NOLINTNEXTLINE(cppcoreguidelines-virtual-class-destructor)
class C10_API intrusive_ptr_target {
  // Note [Weak references for intrusive refcounting]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Here's the scheme:
  //
  // - refcount == number of strong references to the object
  //   weakcount == number of weak references to the object,
  //     plus one more if refcount > 0
  //   An invariant: refcount > 0  =>  weakcount > 0
  //
  // - c10::StorageImpl stays live as long as there are any strong
  //   or weak pointers to it (weakcount > 0, since strong
  //   references count as a +1 to weakcount)
  //
  // - finalizers are called and data_ptr is deallocated when refcount == 0
  //
  // - Once refcount == 0, it can never again be > 0 (the transition
  //   from > 0 to == 0 is monotonic)
  //
  // - When you access c10::StorageImpl via a weak pointer, you must
  //   atomically increment the use count, if it is greater than 0.
  //   If it is not, you must report that the storage is dead.
  //
  // Both counters are mutable so const objects can still be ref-counted.
  mutable std::atomic<size_t> refcount_;
  mutable std::atomic<size_t> weakcount_;

  template <typename T, typename NullType>
  friend class intrusive_ptr;
  friend inline void raw::intrusive_ptr::incref(intrusive_ptr_target* self);

  template <typename T, typename NullType>
  friend class weak_intrusive_ptr;
  friend inline void raw::weak_intrusive_ptr::incref(
      intrusive_ptr_target* self);

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;

 protected:
  // protected destructor. We never want to destruct intrusive_ptr_target*
  // directly.
  virtual ~intrusive_ptr_target() {
// Disable -Wterminate and -Wexceptions so we're allowed to use assertions
// (i.e. throw exceptions) in a destructor.
// We also have to disable -Wunknown-warning-option and -Wpragmas, because
// some other compilers don't know about -Wterminate or -Wexceptions and
// will show a warning about unknown warning options otherwise.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning( \
    disable : 4297) // function assumed not to throw an exception but does
#else
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Wunknown-warning-option"
#pragma GCC diagnostic ignored "-Wterminate"
#pragma GCC diagnostic ignored "-Wexceptions"
#endif
    // Debug-only sanity check: destructing with live strong references is a
    // use-after-free in waiting (>= INT_MAX is the sentinel range used by
    // unsafe_adapt_non_heap_allocated, see below).
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // Second condition is there to accommodate
        // unsafe_adapt_non_heap_allocated: since we are doing our own
        // deallocation in that case, it is correct for each
        // expected_decref to have happened (some user code tried to
        // decref and thus free the object, but it didn't happen right
        // away) or not (no user code tried to free the object, and
        // now it's getting destroyed through whatever mechanism the
        // caller of unsafe_adapt_non_heap_allocated wanted to
        // use). We choose our reference count such that the count
        // will not dip below INT_MAX regardless.
        refcount_.load() == 0 || refcount_.load() >= INT_MAX,
        "Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it; refcount was ",
        refcount_.load());
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // See ~intrusive_ptr for optimization that will frequently result in 1
        // at destruction time.
        weakcount_.load() == 1 || weakcount_.load() == 0 ||
            weakcount_.load() == INT_MAX - 1 || weakcount_.load() == INT_MAX,
        "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#else
#pragma GCC diagnostic pop
#endif
  }

  // Counts start at zero: a zero refcount/weakcount marks an object that has
  // not (yet) been handed to an intrusive_ptr. See
  // Note [Stack allocated intrusive_ptr_target safety] above.
  constexpr intrusive_ptr_target() noexcept : refcount_(0), weakcount_(0) {}

  // intrusive_ptr_target supports copy and move: but refcount and weakcount
  // don't participate (since they are intrinsic properties of the memory
  // location)
  intrusive_ptr_target(intrusive_ptr_target&& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(intrusive_ptr_target&& /*other*/) noexcept {
    return *this;
  }

  intrusive_ptr_target(const intrusive_ptr_target& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(
      const intrusive_ptr_target& /*other*/) noexcept {
    return *this;
  }

 private:
  /**
   * This is called when refcount reaches zero.
   * You can override this to release expensive resources.
   * There might still be weak references, so your object might not get
   * destructed yet, but you can assume the object isn't used anymore,
   * i.e. no more calls to methods or accesses to members (we just can't
   * destruct it yet because we need the weakcount accessible).
   *
   * If there are no weak references (i.e. your class is about to be
   * destructed), this function WILL NOT be called.
   */
  virtual void release_resources() {}
};
173 :
namespace detail {
// Default null-pointer policy for intrusive_ptr: "null" is plain nullptr.
template <class TTarget>
struct intrusive_target_default_null_type final {
  static constexpr TTarget* singleton() noexcept {
    return nullptr;
  }
};

// Translate a pointer between two null-type conventions: the source
// convention's null value maps to the destination's null value; any other
// pointer passes through unchanged.
template <class TTarget, class ToNullType, class FromNullType>
TTarget* assign_ptr_(TTarget* rhs) {
  return (rhs == FromNullType::singleton()) ? ToNullType::singleton() : rhs;
}

// Increment needs to be acquire-release to make use_count() and
// unique() reliable.
inline size_t atomic_refcount_increment(std::atomic<size_t>& refcount) {
  const size_t previous = refcount.fetch_add(1, std::memory_order_acq_rel);
  return previous + 1;
}

// weak_use_count() is only used for testing, so we don't need it to
// be reliable. Relaxed should be fine.
inline size_t atomic_weakcount_increment(std::atomic<size_t>& weakcount) {
  const size_t previous = weakcount.fetch_add(1, std::memory_order_relaxed);
  return previous + 1;
}

// Both decrements need to be acquire-release for correctness. See
// e.g. std::shared_ptr implementation.
inline size_t atomic_refcount_decrement(std::atomic<size_t>& refcount) {
  const size_t previous = refcount.fetch_sub(1, std::memory_order_acq_rel);
  return previous - 1;
}

inline size_t atomic_weakcount_decrement(std::atomic<size_t>& weakcount) {
  const size_t previous = weakcount.fetch_sub(1, std::memory_order_acq_rel);
  return previous - 1;
}

} // namespace detail
214 :
215 : template <class TTarget, class NullType>
216 : class weak_intrusive_ptr;
217 :
template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>>
class intrusive_ptr final {
 private:
  // the following static assert would be nice to have but it requires
  // the target class T to be fully defined when intrusive_ptr<T> is instantiated
  // this is a problem for classes that contain pointers to themselves
  // static_assert(
  //     std::is_base_of<intrusive_ptr_target, TTarget>::value,
  //     "intrusive_ptr can only be used for classes that inherit from
  //     intrusive_ptr_target.");
#ifndef _WIN32
  // This static_assert triggers on MSVC
  //  error C2131: expression did not evaluate to a constant
  static_assert(
      NullType::singleton() == NullType::singleton(),
      "NullType must have a constexpr singleton() method");
#endif
  static_assert(
      std::is_base_of<
          TTarget,
          typename std::remove_pointer<decltype(NullType::singleton())>::type>::
          value,
      "NullType::singleton() must return a element_type* pointer");

  // The managed object; equal to NullType::singleton() when "empty".
  TTarget* target_;

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;
  template <class TTarget2, class NullType2>
  friend class intrusive_ptr;
  friend class weak_intrusive_ptr<TTarget, NullType>;

  // Make pybind11::class_ be a friend class of intrusive_ptr, so that custom
  // smart holder in pybind11 could access the private constructor of
  // intrusive_ptr(T*) which took the ownership of the object. This is required
  // by customer holder macro PYBIND11_DECLARE_HOLDER_TYPE, where it uses
  // intrusive_ptr(TTarget*) to initialize and take ownership of the object. For
  // details, see
  // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers
  template <typename, typename...>
  friend class pybind11::class_;

  // Bump the strong refcount (no-op when holding the null sentinel).
  // Debug builds assert we never resurrect an object whose refcount
  // already hit zero.
  void retain_() {
    if (target_ != NullType::singleton()) {
      size_t new_refcount =
          detail::atomic_refcount_increment(target_->refcount_);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          new_refcount != 1,
          "intrusive_ptr: Cannot increase refcount after it reached zero.");
    }
  }

  // Drop our strong reference. If it was the last one, run
  // release_resources() (unless we can delete outright) and delete the
  // target once no weak references remain. Note: target_ is deliberately
  // NOT nulled out here; callers that outlive the call (reset()) must do
  // that themselves.
  void reset_() noexcept {
    if (target_ != NullType::singleton() &&
        detail::atomic_refcount_decrement(target_->refcount_) == 0) {
      // See comment above about weakcount. As long as refcount>0,
      // weakcount is one larger than the actual number of weak references.
      // So we need to decrement it here.
      bool should_delete =
          target_->weakcount_.load(std::memory_order_acquire) == 1;
      if (!should_delete) {
        // justification for const_cast: release_resources is basically a
        // destructor and a destructor always mutates the object, even for const
        // objects. NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        const_cast<std::remove_const_t<TTarget>*>(target_)->release_resources();
        should_delete =
            detail::atomic_weakcount_decrement(target_->weakcount_) == 0;
      }
      if (should_delete) {
        delete target_;
      }
    }
  }

  // raw pointer constructors are not public because we shouldn't make
  // intrusive_ptr out of raw pointers except from inside the make_intrusive(),
  // reclaim() and weak_intrusive_ptr::lock() implementations.

  // This constructor will increase the ref counter for you.
  // This constructor will be used by the make_intrusive(), and also pybind11,
  // which wrap the intrusive_ptr holder around the raw pointer and incref
  // correspondingly (pybind11 requires raw pointer constructor to incref by
  // default).
  explicit intrusive_ptr(TTarget* target)
      : intrusive_ptr(target, raw::DontIncreaseRefcount{}) {
    if (target_ != NullType::singleton()) {
      // We just created result.target_, so we know no other thread has
      // access to it, so we know we needn't care about memory ordering.
      // (On x86_64, a store with memory_order_relaxed generates a plain old
      // `mov`, whereas an atomic increment does a lock-prefixed `add`, which is
      // much more expensive: https://godbolt.org/z/eKPzj8.)
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          target_->refcount_ == 0 && target_->weakcount_ == 0,
          "intrusive_ptr: Newly-created target had non-zero refcounts. Does its "
          "constructor do something strange like incref or create an "
          "intrusive_ptr from `this`?");
      target_->refcount_.store(1, std::memory_order_relaxed);
      target_->weakcount_.store(1, std::memory_order_relaxed);
    }
  }

 public:
  using element_type = TTarget;

  // Default-construct to the null sentinel; no refcount traffic.
  intrusive_ptr() noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  intrusive_ptr(std::nullptr_t) noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  // This constructor will not increase the ref counter for you.
  // We use the tagged dispatch mechanism to explicitly mark this constructor
  // to not increase the refcount
  explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept
      : target_(target) {}

  // Take over a unique_ptr's ownership (refcount goes 0 -> 1).
  explicit intrusive_ptr(std::unique_ptr<TTarget> rhs) noexcept
      : intrusive_ptr(rhs.release()) {}

  // Move: steal the pointer, leave rhs holding the null sentinel.
  intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
    rhs.target_ = NullType::singleton();
  }

  template <class From, class FromNullType>
  /* implicit */ intrusive_ptr(intrusive_ptr<From, FromNullType>&& rhs) noexcept
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move constructor got pointer of wrong type.");
    rhs.target_ = FromNullType::singleton();
  }

  // Copy: share the pointer and incref.
  intrusive_ptr(const intrusive_ptr& rhs) : target_(rhs.target_) {
    retain_();
  }

  template <class From, class FromNullType>
  /* implicit */ intrusive_ptr(const intrusive_ptr<From, FromNullType>& rhs)
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy constructor got pointer of wrong type.");
    retain_();
  }

  ~intrusive_ptr() noexcept {
    reset_();
  }

  intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept {
    return operator=<TTarget, NullType>(std::move(rhs));
  }

  // Move-assign via the copy-and-swap idiom: tmp absorbs rhs, swap puts it
  // here, tmp's destructor releases our old target.
  template <class From, class FromNullType>
  intrusive_ptr& operator=(intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move assignment got pointer of wrong type.");
    intrusive_ptr tmp = std::move(rhs);
    swap(tmp);
    return *this;
  }

  intrusive_ptr& operator=(const intrusive_ptr& rhs) & noexcept {
    return operator=<TTarget, NullType>(rhs);
  }

  // NOTE(review): the parameter type uses NullType (this class's null
  // policy), not FromNullType, even though FromNullType is declared —
  // presumably intentional so only same-policy copies assign; confirm
  // against upstream before changing.
  template <class From, class FromNullType>
  intrusive_ptr& operator=(const intrusive_ptr<From, NullType>& rhs) & {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy assignment got pointer of wrong type.");
    intrusive_ptr tmp = rhs;
    swap(tmp);
    return *this;
  }

  // Non-owning observer; may be the null sentinel.
  TTarget* get() const noexcept {
    return target_;
  }

  TTarget& operator*() const noexcept {
    return *target_;
  }

  TTarget* operator->() const noexcept {
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
    return target_;
  }

  operator bool() const noexcept {
    return target_ != NullType::singleton();
  }

  // Release our strong reference and become the null sentinel.
  void reset() noexcept {
    reset_();
    target_ = NullType::singleton();
  }

  // Pointer-only swap; no refcount traffic.
  void swap(intrusive_ptr& rhs) noexcept {
    TTarget* tmp = target_;
    target_ = rhs.target_;
    rhs.target_ = tmp;
  }

  // We do a lot of null-pointer checks in our code, good to have this be cheap.
  bool defined() const noexcept {
    return target_ != NullType::singleton();
  }

  // Current strong refcount (0 when holding the null sentinel).
  size_t use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->refcount_.load(std::memory_order_acquire);
  }

  size_t weak_use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->weakcount_.load(std::memory_order_acquire);
  }

  bool unique() const noexcept {
    return use_count() == 1;
  }

  /**
   * Returns an owning (!) pointer to the underlying object and makes the
   * intrusive_ptr instance invalid. That means the refcount is not decreased.
   * You *must* put the returned pointer back into a intrusive_ptr using
   * intrusive_ptr::reclaim(ptr) to properly destruct it.
   * This is helpful for C APIs.
   */
  TTarget* release() noexcept {
    // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
    TTarget* result = target_;
    target_ = NullType::singleton();
    return result;
  }

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr that takes
   * over ownership. That means the refcount is not increased.
   * This is the counter-part to intrusive_ptr::release() and the pointer
   * passed in *must* have been created using intrusive_ptr::release().
   */
  static intrusive_ptr reclaim(TTarget* owning_ptr) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        owning_ptr == NullType::singleton() ||
            owning_ptr->refcount_.load() == 0 || owning_ptr->weakcount_.load(),
        "TTarget violates the invariant that refcount > 0  =>  weakcount > 0");
    return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{});
  }

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr
   * representing a new reference, i.e. the raw pointer retains
   * ownership.
   */
  static intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
    auto ret = reclaim(owning_ptr);
    ret.retain_();
    return ret;
  }

  /**
   * Allocate a heap object with args and wrap it inside a intrusive_ptr and
   * incref. This is a helper function to let make_intrusive() access private
   * intrusive_ptr constructors.
   */
  template <class... Args>
  static intrusive_ptr make(Args&&... args) {
    return intrusive_ptr(new TTarget(std::forward<Args>(args)...));
  }

  /**
   * Turn a new instance of TTarget (e.g., literally allocated
   * using new TTarget(...) into an intrusive_ptr. If possible,
   * use intrusive_ptr::make instead which statically guarantees
   * that the allocation was done properly.
   *
   * At the moment, the only reason this method exists is because
   * pybind11 holder types expect to be able to allocate in
   * this way (because pybind11 handles the new allocation itself).
   */
  static intrusive_ptr unsafe_steal_from_new(TTarget* raw_ptr) {
    return intrusive_ptr(raw_ptr);
  }

  /**
   * Turn an instance of TTarget that should not be reference counted
   * (e.g., allocated into an arena with placement new) into an
   * intrusive_ptr. This is gratuitously unsafe and should only be
   * used if you can guarantee that the pointer will not escape and be
   * refcounted as normal.
   *
   * `expected_decrefs` is a debugging parameter: it indicates the
   * number of strong owners the intrusive_ptr_target in question is
   * expected to get. In most use cases, this will likely be 1.
   *
   * The reason this method exists is for manually sharing
   * StorageImpls across Tensors in the static runtime. It needs
   * access to private intrusive_ptr members so that the refcounts can
   * be initialized to custom values.
   */
  static intrusive_ptr unsafe_adapt_non_heap_allocated(
      TTarget* raw_ptr,
      size_t expected_decrefs) {
    intrusive_ptr result(raw_ptr, raw::DontIncreaseRefcount{});
    // INT_MAX is impractically huge for a reference count, while
    // being in no danger of overflowing size_t. We actually only need to
    // initialize the refcount to 2 -- we are just doing an unbalanced
    // incref to prevent the non-heap-allocated target from being
    // freed, and we are optimizing that incref by directly
    // initializing the refcounts rather than doing an expensive
    // atomic increment. The reason to use INT_MAX is to accommodate
    // the debug assertions in ~intrusive_ptr_target.
#ifdef NDEBUG
    expected_decrefs = 0;
#endif
    result.target_->refcount_.store(
        INT_MAX + expected_decrefs, std::memory_order_relaxed);
    result.target_->weakcount_.store(INT_MAX, std::memory_order_relaxed);
    return result;
  }

  /**
   * Turn a **non-owning raw pointer** to an intrusive_ptr. It is
   * the moral equivalent of enable_shared_from_this on a shared pointer.
   *
   * This method is only valid for objects that are already live. If
   * you are looking for the moral equivalent of unique_ptr<T>(T*)
   * constructor, see steal_from_new.
   *
   * TODO: https://github.com/pytorch/pytorch/issues/56482
   */
  static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) {
    // See Note [Stack allocated intrusive_ptr_target safety]
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        raw_ptr == NullType::singleton() || raw_ptr->refcount_.load() > 0,
        "intrusive_ptr: Can only reclaim pointers that are owned by someone");
    auto ptr = reclaim(raw_ptr); // doesn't increase refcount
    ptr.retain_();
    return ptr;
  }
};
570 :
/**
 * Convenience factory: heap-allocate a TTarget from `args` and wrap it in an
 * intrusive_ptr (refcount and weakcount both start at 1 via the private
 * incref'ing constructor).
 */
template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>,
    class... Args>
inline intrusive_ptr<TTarget, NullType> make_intrusive(Args&&... args) {
  return intrusive_ptr<TTarget, NullType>::make(std::forward<Args>(args)...);
}
578 :
// ADL-visible free swap, delegating to the member swap (pointer exchange
// only, no refcount traffic).
template <class TTarget, class NullType>
inline void swap(
    intrusive_ptr<TTarget, NullType>& lhs,
    intrusive_ptr<TTarget, NullType>& rhs) noexcept {
  lhs.swap(rhs);
}
585 :
586 : // To allow intrusive_ptr inside std::map or std::set, we need operator<
587 : template <class TTarget1, class NullType1, class TTarget2, class NullType2>
588 : inline bool operator<(
589 : const intrusive_ptr<TTarget1, NullType1>& lhs,
590 : const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
591 : return lhs.get() < rhs.get();
592 : }
593 :
594 : template <class TTarget1, class NullType1, class TTarget2, class NullType2>
595 : inline bool operator==(
596 : const intrusive_ptr<TTarget1, NullType1>& lhs,
597 : const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
598 : return lhs.get() == rhs.get();
599 : }
600 :
601 : template <class TTarget1, class NullType1>
602 : inline bool operator==(
603 : const intrusive_ptr<TTarget1, NullType1>& lhs,
604 : std::nullptr_t) noexcept {
605 : return lhs.get() == nullptr;
606 : }
607 :
608 : template <class TTarget2, class NullType2>
609 : inline bool operator==(
610 : std::nullptr_t,
611 : const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
612 : return nullptr == rhs.get();
613 : }
614 :
615 : template <class TTarget1, class NullType1, class TTarget2, class NullType2>
616 : inline bool operator!=(
617 : const intrusive_ptr<TTarget1, NullType1>& lhs,
618 : const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
619 : return !operator==(lhs, rhs);
620 : }
621 :
622 : template <class TTarget1, class NullType1>
623 : inline bool operator!=(
624 : const intrusive_ptr<TTarget1, NullType1>& lhs,
625 : std::nullptr_t) noexcept {
626 : return !operator==(lhs, nullptr);
627 : }
628 :
629 : template <class TTarget2, class NullType2>
630 : inline bool operator!=(
631 : std::nullptr_t,
632 : const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
633 : return !operator==(nullptr, rhs);
634 : }
// MaybeOwned support: a "borrow" of an intrusive_ptr is another
// intrusive_ptr aliasing the same target WITHOUT owning a reference.
// reclaim() wraps the raw pointer with no incref, and release() drops it
// with no decref, so borrows are refcount-neutral by construction.
template <typename T>
struct MaybeOwnedTraits<c10::intrusive_ptr<T>> {
  using owned_type = c10::intrusive_ptr<T>;
  using borrow_type = c10::intrusive_ptr<T>;

  // Alias `from`'s target without bumping its refcount.
  static borrow_type createBorrow(const owned_type& from) {
    return borrow_type::reclaim(from.get());
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    // release() first: lhs never owned its target, so plain assignment
    // (which would decref the old target) must be avoided.
    lhs.release();
    lhs = borrow_type::reclaim(rhs.get());
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    // Detach the pointer so ~intrusive_ptr doesn't decref a reference this
    // borrow never took.
    toDestroy.release();
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};
665 :
666 : template <
667 : typename TTarget,
668 : class NullType = detail::intrusive_target_default_null_type<TTarget>>
669 : class weak_intrusive_ptr final {
670 : private:
671 : static_assert(
672 : std::is_base_of<intrusive_ptr_target, TTarget>::value,
673 : "intrusive_ptr can only be used for classes that inherit from intrusive_ptr_target.");
674 : #ifndef _WIN32
675 : // This static_assert triggers on MSVC
676 : // error C2131: expression did not evaluate to a constant
677 : static_assert(
678 : NullType::singleton() == NullType::singleton(),
679 : "NullType must have a constexpr singleton() method");
680 : #endif
681 : static_assert(
682 : std::is_base_of<
683 : TTarget,
684 : typename std::remove_pointer<decltype(NullType::singleton())>::type>::
685 : value,
686 : "NullType::singleton() must return a element_type* pointer");
687 :
688 : TTarget* target_;
689 :
690 : template <class TTarget2, class NullType2>
691 : friend class weak_intrusive_ptr;
692 :
693 : void retain_() {
694 : if (target_ != NullType::singleton()) {
695 : size_t new_weakcount =
696 : detail::atomic_weakcount_increment(target_->weakcount_);
697 : TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
698 : new_weakcount != 1,
699 : "weak_intrusive_ptr: Cannot increase weakcount after it reached zero.");
700 : }
701 : }
702 :
703 : void reset_() noexcept {
704 : if (target_ != NullType::singleton() &&
705 : detail::atomic_weakcount_decrement(target_->weakcount_) == 0) {
706 : // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
707 : delete target_;
708 : }
709 : target_ = NullType::singleton();
710 : }
711 :
712 : constexpr explicit weak_intrusive_ptr(TTarget* target) : target_(target) {}
713 :
714 : public:
715 : using element_type = TTarget;
716 :
717 : explicit weak_intrusive_ptr(const intrusive_ptr<TTarget, NullType>& ptr)
718 : : weak_intrusive_ptr(ptr.get()) {
719 : retain_();
720 : }
721 :
722 : weak_intrusive_ptr(weak_intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
723 : rhs.target_ = NullType::singleton();
724 : }
725 :
726 : template <class From, class FromNullType>
727 : /* implicit */ weak_intrusive_ptr(
728 : weak_intrusive_ptr<From, FromNullType>&& rhs) noexcept
729 : : target_(
730 : detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
731 : static_assert(
732 : std::is_convertible<From*, TTarget*>::value,
733 : "Type mismatch. weak_intrusive_ptr move constructor got pointer of wrong type.");
734 : rhs.target_ = FromNullType::singleton();
735 : }
736 :
737 : weak_intrusive_ptr(const weak_intrusive_ptr& rhs) : target_(rhs.target_) {
738 : retain_();
739 : }
740 :
741 : template <class From, class FromNullType>
742 : /* implicit */ weak_intrusive_ptr(
743 : const weak_intrusive_ptr<From, FromNullType>& rhs)
744 : : target_(
745 : detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
746 : static_assert(
747 : std::is_convertible<From*, TTarget*>::value,
748 : "Type mismatch. weak_intrusive_ptr copy constructor got pointer of wrong type.");
749 : retain_();
750 : }
751 :
752 : ~weak_intrusive_ptr() noexcept {
753 : reset_();
754 : }
755 :
756 : weak_intrusive_ptr& operator=(weak_intrusive_ptr&& rhs) & noexcept {
757 : return operator=<TTarget, NullType>(std::move(rhs));
758 : }
759 :
760 : template <class From, class FromNullType>
761 : weak_intrusive_ptr& operator=(
762 : weak_intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
763 : static_assert(
764 : std::is_convertible<From*, TTarget*>::value,
765 : "Type mismatch. weak_intrusive_ptr move assignment got pointer of wrong type.");
766 : weak_intrusive_ptr tmp = std::move(rhs);
767 : swap(tmp);
768 : return *this;
769 : }
770 :
771 : weak_intrusive_ptr& operator=(const weak_intrusive_ptr& rhs) & noexcept {
772 : return operator=<TTarget, NullType>(rhs);
773 : }
774 :
775 : weak_intrusive_ptr& operator=(
776 : const intrusive_ptr<TTarget, NullType>& rhs) & noexcept {
777 : weak_intrusive_ptr tmp(rhs);
778 : swap(tmp);
779 : return *this;
780 : }
781 :
782 : template <class From, class FromNullType>
783 : weak_intrusive_ptr& operator=(
784 : const weak_intrusive_ptr<From, NullType>& rhs) & {
785 : static_assert(
786 : std::is_convertible<From*, TTarget*>::value,
787 : "Type mismatch. weak_intrusive_ptr copy assignment got pointer of wrong type.");
788 : weak_intrusive_ptr tmp = rhs;
789 : swap(tmp);
790 : return *this;
791 : }
792 :
793 : void reset() noexcept {
794 : reset_();
795 : }
796 :
797 : void swap(weak_intrusive_ptr& rhs) noexcept {
798 : TTarget* tmp = target_;
799 : target_ = rhs.target_;
800 : rhs.target_ = tmp;
801 : }
802 :
  // NB: This should ONLY be used by the std::hash implementation
  // for weak_intrusive_ptr. Another way you could do this is
  // friend std::hash<weak_intrusive_ptr>, but this triggers two
  // bugs:
  //
  // (1) It triggers an nvcc bug, where std::hash in a friend class
  // declaration gets preprocessed into hash, which then cannot
  // actually be found. The error in this case looks like:
  //
  //   error: no template named 'hash'; did you mean 'std::hash'?
  //
  // (2) On OS X, std::hash is declared as a struct, not a class.
  // This triggers:
  //
  //   error: class 'hash' was previously declared as a struct
  //   [-Werror,-Wmismatched-tags]
  //
  // Both of these are work-aroundable, but on the whole, I decided
  // it would be simpler and easier to make work if we just expose
  // an unsafe getter for target_
  //
  // Returns the raw observed pointer with NO refcount bookkeeping and NO
  // liveness check — the pointee may already be destructed.
  TTarget* _unsafe_get_target() const noexcept {
    return target_;
  }
827 :
828 : size_t use_count() const noexcept {
829 : if (target_ == NullType::singleton()) {
830 : return 0;
831 : }
832 : return target_->refcount_.load(
833 : std::memory_order_acquire); // refcount, not weakcount!
834 : }
835 :
836 : size_t weak_use_count() const noexcept {
837 : if (target_ == NullType::singleton()) {
838 : return 0;
839 : }
840 : return target_->weakcount_.load(std::memory_order_acquire);
841 : }
842 :
  // True when the observed object has no strong references left (already
  // destructed) or when this weak_intrusive_ptr observes nothing.
  bool expired() const noexcept {
    return use_count() == 0;
  }
846 :
  // Attempts to promote this weak reference to a strong intrusive_ptr.
  // Returns a null intrusive_ptr if the object is already gone (or we
  // observe nothing). The CAS loop ensures the refcount is only
  // incremented while it is still nonzero, so a destructed object is
  // never resurrected.
  intrusive_ptr<TTarget, NullType> lock() const noexcept {
    if (expired()) {
      return intrusive_ptr<TTarget, NullType>();
    } else {
      auto refcount = target_->refcount_.load(std::memory_order_seq_cst);
      do {
        if (refcount == 0) {
          // Object already destructed, no strong references left anymore.
          // Return nullptr.
          return intrusive_ptr<TTarget, NullType>();
        }
      } while (
          !target_->refcount_.compare_exchange_weak(refcount, refcount + 1));
      // Refcount was successfully bumped by us, so hand the pointer to an
      // intrusive_ptr without incrementing again.
      return intrusive_ptr<TTarget, NullType>(
          target_, raw::DontIncreaseRefcount{});
    }
  }
864 :
865 : /**
866 : * Returns an owning (but still only weakly referenced) pointer to the
867 : * underlying object and makes the weak_intrusive_ptr instance invalid.
868 : * That means the weakcount is not decreased.
869 : * You *must* put the returned pointer back into a weak_intrusive_ptr using
870 : * weak_intrusive_ptr::reclaim(ptr) to properly destruct it.
871 : * This is helpful for C APIs.
872 : */
873 : TTarget* release() noexcept {
874 : TTarget* result = target_;
875 : target_ = NullType::singleton();
876 : return result;
877 : }
878 :
  /**
   * Takes an owning (but must be weakly referenced) pointer to TTarget* and
   * creates a weak_intrusive_ptr that takes over ownership.
   * This means that the weakcount is not increased.
   * This is the counter-part to weak_intrusive_ptr::release() and the pointer
   * passed in *must* have been created using weak_intrusive_ptr::release().
   */
  static weak_intrusive_ptr reclaim(TTarget* owning_weak_ptr) {
    // See Note [Stack allocated intrusive_ptr_target safety]
    // if refcount > 0, weakcount must be >1 for weak references to exist.
    // see weak counting explanation at top of this file.
    // if refcount == 0, weakcount only must be >0.
    // (Debug-only check: a stack-allocated or non-released pointer fails it.)
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        owning_weak_ptr == NullType::singleton() ||
            owning_weak_ptr->weakcount_.load() > 1 ||
            (owning_weak_ptr->refcount_.load() == 0 &&
             owning_weak_ptr->weakcount_.load() > 0),
        "weak_intrusive_ptr: Can only weak_intrusive_ptr::reclaim() owning pointers that were created using weak_intrusive_ptr::release().");
    return weak_intrusive_ptr(owning_weak_ptr);
  }
899 :
900 : /**
901 : * Takes a pointer to TTarget* (may be weak or strong) and creates a
902 : * new weak_intrusive_ptr representing a new weak reference, i.e.
903 : * the raw pointer retains ownership.
904 : */
905 : static weak_intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
906 : auto ret = reclaim(owning_ptr);
907 : ret.retain_();
908 : return ret;
909 : }
910 :
  // The comparison operators compare raw target_ pointers directly, so
  // they need friend access rather than going through a public getter.
  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
  friend bool operator<(
      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
  friend bool operator==(
      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
919 : };
920 :
921 : template <class TTarget, class NullType>
922 : inline void swap(
923 : weak_intrusive_ptr<TTarget, NullType>& lhs,
924 : weak_intrusive_ptr<TTarget, NullType>& rhs) noexcept {
925 : lhs.swap(rhs);
926 : }
927 :
// To allow weak_intrusive_ptr inside std::map or std::set, we need operator<
// This orders by raw target pointer (identity), not by pointee value.
// NOTE(review): `<` on pointers into unrelated objects is unspecified by
// the standard; std::less would guarantee a total order — confirm whether
// that matters for the containers used with this type.
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator<(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ < rhs.target_;
}
935 :
936 : template <class TTarget1, class NullType1, class TTarget2, class NullType2>
937 : inline bool operator==(
938 : const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
939 : const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
940 : return lhs.target_ == rhs.target_;
941 : }
942 :
943 : template <class TTarget1, class NullType1, class TTarget2, class NullType2>
944 : inline bool operator!=(
945 : const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
946 : const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
947 : return !operator==(lhs, rhs);
948 : }
949 :
// Alias for documentary purposes, to more easily distinguish
// weak raw intrusive pointers from (strong) intrusive pointers in the
// raw:: helper function signatures below.
using weak_intrusive_ptr_target = intrusive_ptr_target;
953 :
954 : // This namespace provides some methods for working with
955 : // raw pointers that subclass intrusive_ptr_target. They are not provided
956 : // as methods on intrusive_ptr_target, because ideally you would not need these
957 : // methods at all (use smart pointers), but if you are dealing with legacy code
958 : // that still needs to pass around raw pointers, you may find these quite
959 : // useful.
960 : //
961 : // An important usage note: some functions are only valid if you have a
962 : // strong raw pointer to the object, while others are only valid if you
963 : // have a weak raw pointer to the object. ONLY call intrusive_ptr namespace
964 : // functions on strong pointers, and weak_intrusive_ptr namespace functions
965 : // on weak pointers. If you mix it up, you may get an assert failure.
966 : namespace raw {
967 :
968 : namespace intrusive_ptr {
969 :
970 : // WARNING: Unlike the reclaim() API, it is NOT valid to pass
971 : // NullType::singleton to this function
972 0 : inline void incref(intrusive_ptr_target* self) {
973 0 : if (self) {
974 0 : detail::atomic_refcount_increment(self->refcount_);
975 : }
976 0 : }
977 :
// WARNING: Unlike the reclaim() API, it is NOT valid to pass
// NullType::singleton to this function
inline void decref(intrusive_ptr_target* self) {
  // Let it die: adopting into a temporary intrusive_ptr and immediately
  // dropping it decrements the refcount (destructing the object at zero).
  c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: Caller still has 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::intrusive_ptr class
}
986 :
987 : template <typename T>
988 : inline T* make_weak(T* self) {
989 : // NB: 'this' is a strong pointer, but we return a weak pointer
990 : auto ptr = c10::intrusive_ptr<T>::reclaim(self);
991 : c10::weak_intrusive_ptr<T> wptr(ptr);
992 : ptr.release();
993 : return wptr.release();
994 : }
995 :
996 : inline size_t use_count(intrusive_ptr_target* self) {
997 : auto ptr = c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
998 : auto r = ptr.use_count();
999 : ptr.release();
1000 : return r;
1001 : }
1002 :
1003 : } // namespace intrusive_ptr
1004 :
1005 : namespace weak_intrusive_ptr {
1006 :
// Increments the weakcount of an object held via a weak raw pointer.
// NOTE(review): unlike raw::intrusive_ptr::incref there is no null check
// here, so passing nullptr appears to be invalid — confirm with callers.
inline void incref(weak_intrusive_ptr_target* self) {
  detail::atomic_weakcount_increment(self->weakcount_);
}
1010 :
inline void decref(weak_intrusive_ptr_target* self) {
  // Let it die: adopting into a temporary weak_intrusive_ptr and dropping
  // it decrements the weakcount (freeing the control block at zero).
  c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: You still "have" the 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::weak_intrusive_ptr class
}
1017 :
1018 : template <typename T>
1019 : inline T* lock(T* self) {
1020 : auto wptr = c10::weak_intrusive_ptr<T>::reclaim(self);
1021 : auto ptr = wptr.lock();
1022 : wptr.release();
1023 : return ptr.release();
1024 : }
1025 :
1026 : // This gives the STRONG refcount of a WEAK pointer
1027 : inline size_t use_count(weak_intrusive_ptr_target* self) {
1028 : auto wptr = c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
1029 : auto r = wptr.use_count();
1030 : wptr.release();
1031 : return r;
1032 : }
1033 :
1034 : } // namespace weak_intrusive_ptr
1035 :
1036 : } // namespace raw
1037 :
1038 : } // namespace c10
1039 :
1040 : namespace std {
1041 : // To allow intrusive_ptr and weak_intrusive_ptr inside std::unordered_map or
1042 : // std::unordered_set, we need std::hash
template <class TTarget, class NullType>
struct hash<c10::intrusive_ptr<TTarget, NullType>> {
  // Hashes by raw pointer identity, consistent with operator== above.
  size_t operator()(const c10::intrusive_ptr<TTarget, NullType>& x) const {
    return std::hash<TTarget*>()(x.get());
  }
};
template <class TTarget, class NullType>
struct hash<c10::weak_intrusive_ptr<TTarget, NullType>> {
  // Hashes by raw pointer identity via the unsafe getter (see the NB
  // comment on _unsafe_get_target for why friend std::hash isn't used).
  size_t operator()(const c10::weak_intrusive_ptr<TTarget, NullType>& x) const {
    return std::hash<TTarget*>()(x._unsafe_get_target());
  }
};
1055 : } // namespace std
|