LCOV - code coverage report
Current view: top level - libtorch/include/c10/util - OptionalArrayRef.h (source / functions) Coverage Total Hit
Test: coverage.info Lines: 100.0 % 3 3
Test Date: 2024-04-30 13:17:26 Functions: 100.0 % 2 2

            Line data    Source code
       1              : // This file defines OptionalArrayRef<T>, a class that has almost the same
       2              : // exact functionality as c10::optional<ArrayRef<T>>, except that its
       3              : // converting constructor fixes a dangling pointer issue.
       4              : //
       5              : // The implicit converting constructor of both c10::optional<ArrayRef<T>> and
       6              : // std::optional<ArrayRef<T>> can cause the underlying ArrayRef<T> to store
       7              : // a dangling pointer. OptionalArrayRef<T> prevents this by wrapping
       8              : // a c10::optional<ArrayRef<T>> and fixing the constructor implementation.
       9              : //
      10              : // See https://github.com/pytorch/pytorch/issues/63645 for more on this.
      11              : 
      12              : #pragma once
      13              : 
      14              : #include <c10/util/ArrayRef.h>
      15              : #include <c10/util/Optional.h>
      16              : 
      17              : namespace c10 {
      18              : 
      19              : template <typename T>
      20              : class OptionalArrayRef final {
      21              :  public:
      22              :   // Constructors
      23              : 
      24              :   constexpr OptionalArrayRef() noexcept = default;
      25              : 
      26        16480 :   constexpr OptionalArrayRef(nullopt_t) noexcept {}
      27              : 
      28              :   OptionalArrayRef(const OptionalArrayRef& other) = default;
      29              : 
      30              :   OptionalArrayRef(OptionalArrayRef&& other) = default;
      31              : 
      32              :   constexpr OptionalArrayRef(const optional<ArrayRef<T>>& other) noexcept
      33              :       : wrapped_opt_array_ref(other) {}
      34              : 
      35              :   constexpr OptionalArrayRef(optional<ArrayRef<T>>&& other) noexcept
      36              :       : wrapped_opt_array_ref(other) {}
      37              : 
      38        14194 :   constexpr OptionalArrayRef(const T& value) noexcept
      39        14194 :       : wrapped_opt_array_ref(value) {}
      40              : 
      41              :   template <
      42              :       typename U = ArrayRef<T>,
      43              :       std::enable_if_t<
      44              :           !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
      45              :               !std::is_same<std::decay_t<U>, in_place_t>::value &&
      46              :               std::is_constructible<ArrayRef<T>, U&&>::value &&
      47              :               std::is_convertible<U&&, ArrayRef<T>>::value &&
      48              :               !std::is_convertible<U&&, T>::value,
      49              :           bool> = false>
      50              :   constexpr OptionalArrayRef(U&& value) noexcept(
      51              :       std::is_nothrow_constructible<ArrayRef<T>, U&&>::value)
      52              :       : wrapped_opt_array_ref(value) {}
      53              : 
      54              :   template <
      55              :       typename U = ArrayRef<T>,
      56              :       std::enable_if_t<
      57              :           !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
      58              :               !std::is_same<std::decay_t<U>, in_place_t>::value &&
      59              :               std::is_constructible<ArrayRef<T>, U&&>::value &&
      60              :               !std::is_convertible<U&&, ArrayRef<T>>::value,
      61              :           bool> = false>
      62              :   constexpr explicit OptionalArrayRef(U&& value) noexcept(
      63              :       std::is_nothrow_constructible<ArrayRef<T>, U&&>::value)
      64              :       : wrapped_opt_array_ref(value) {}
      65              : 
      66              :   template <typename... Args>
      67              :   constexpr explicit OptionalArrayRef(in_place_t ip, Args&&... args) noexcept
      68              :       : wrapped_opt_array_ref(ip, args...) {}
      69              : 
      70              :   template <typename U, typename... Args>
      71              :   constexpr explicit OptionalArrayRef(
      72              :       in_place_t ip,
      73              :       std::initializer_list<U> il,
      74              :       Args&&... args)
      75              :       : wrapped_opt_array_ref(ip, il, args...) {}
      76              : 
      77              :   constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
      78              :       : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
      79              : 
      80              :   // Destructor
      81              : 
      82              :   ~OptionalArrayRef() = default;
      83              : 
      84              :   // Assignment
      85              : 
      86              :   constexpr OptionalArrayRef& operator=(nullopt_t) noexcept {
      87              :     wrapped_opt_array_ref = c10::nullopt;
      88              :     return *this;
      89              :   }
      90              : 
      91              :   OptionalArrayRef& operator=(const OptionalArrayRef& other) = default;
      92              : 
      93              :   OptionalArrayRef& operator=(OptionalArrayRef&& other) = default;
      94              : 
      95              :   constexpr OptionalArrayRef& operator=(
      96              :       const optional<ArrayRef<T>>& other) noexcept {
      97              :     wrapped_opt_array_ref = other;
      98              :     return *this;
      99              :   }
     100              : 
     101              :   constexpr OptionalArrayRef& operator=(
     102              :       optional<ArrayRef<T>>&& other) noexcept {
     103              :     wrapped_opt_array_ref = other;
     104              :     return *this;
     105              :   }
     106              : 
     107              :   template <typename U = ArrayRef<T>>
     108              :   constexpr std::enable_if_t<
     109              :       !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
     110              :           std::is_constructible<ArrayRef<T>, U&&>::value &&
     111              :           std::is_assignable<ArrayRef<T>&, U&&>::value,
     112              :       OptionalArrayRef&>
     113              :   operator=(U&& value) noexcept(
     114              :       std::is_nothrow_constructible<ArrayRef<T>, U&&>::value&&
     115              :           std::is_nothrow_assignable<ArrayRef<T>&, U&&>::value) {
     116              :     wrapped_opt_array_ref = value;
     117              :     return *this;
     118              :   }
     119              : 
     120              :   // Observers
     121              : 
     122              :   constexpr ArrayRef<T>* operator->() noexcept {
     123              :     return &wrapped_opt_array_ref.value();
     124              :   }
     125              : 
     126              :   constexpr const ArrayRef<T>* operator->() const noexcept {
     127              :     return &wrapped_opt_array_ref.value();
     128              :   }
     129              : 
     130              :   constexpr ArrayRef<T>& operator*() & noexcept {
     131              :     return wrapped_opt_array_ref.value();
     132              :   }
     133              : 
     134              :   constexpr const ArrayRef<T>& operator*() const& noexcept {
     135              :     return wrapped_opt_array_ref.value();
     136              :   }
     137              : 
     138              :   constexpr ArrayRef<T>&& operator*() && noexcept {
     139              :     return std::move(wrapped_opt_array_ref.value());
     140              :   }
     141              : 
     142              :   constexpr const ArrayRef<T>&& operator*() const&& noexcept {
     143              :     return std::move(wrapped_opt_array_ref.value());
     144              :   }
     145              : 
     146              :   constexpr explicit operator bool() const noexcept {
     147              :     return wrapped_opt_array_ref.has_value();
     148              :   }
     149              : 
     150              :   constexpr bool has_value() const noexcept {
     151              :     return wrapped_opt_array_ref.has_value();
     152              :   }
     153              : 
     154              :   constexpr ArrayRef<T>& value() & {
     155              :     return wrapped_opt_array_ref.value();
     156              :   }
     157              : 
     158              :   constexpr const ArrayRef<T>& value() const& {
     159              :     return wrapped_opt_array_ref.value();
     160              :   }
     161              : 
     162              :   constexpr ArrayRef<T>&& value() && {
     163              :     return std::move(wrapped_opt_array_ref.value());
     164              :   }
     165              : 
     166              :   constexpr const ArrayRef<T>&& value() const&& {
     167              :     return std::move(wrapped_opt_array_ref.value());
     168              :   }
     169              : 
     170              :   template <typename U>
     171              :   constexpr std::
     172              :       enable_if_t<std::is_convertible<U&&, ArrayRef<T>>::value, ArrayRef<T>>
     173              :       value_or(U&& default_value) const& {
     174              :     return wrapped_opt_array_ref.value_or(default_value);
     175              :   }
     176              : 
     177              :   template <typename U>
     178              :   constexpr std::
     179              :       enable_if_t<std::is_convertible<U&&, ArrayRef<T>>::value, ArrayRef<T>>
     180              :       value_or(U&& default_value) && {
     181              :     return wrapped_opt_array_ref.value_or(default_value);
     182              :   }
     183              : 
     184              :   // Modifiers
     185              : 
     186              :   constexpr void swap(OptionalArrayRef& other) noexcept {
     187              :     std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref);
     188              :   }
     189              : 
     190              :   constexpr void reset() noexcept {
     191              :     wrapped_opt_array_ref.reset();
     192              :   }
     193              : 
     194              :   template <typename... Args>
     195              :   constexpr std::enable_if_t<
     196              :       std::is_constructible<ArrayRef<T>, Args&&...>::value,
     197              :       ArrayRef<T>&>
     198              :   emplace(Args&&... args) noexcept(
     199              :       std::is_nothrow_constructible<ArrayRef<T>, Args&&...>::value) {
     200              :     return wrapped_opt_array_ref.emplace(args...);
     201              :   }
     202              : 
     203              :   template <typename U, typename... Args>
     204              :   constexpr ArrayRef<T>& emplace(
     205              :       std::initializer_list<U> il,
     206              :       Args&&... args) noexcept {
     207              :     return wrapped_opt_array_ref.emplace(il, args...);
     208              :   }
     209              : 
     210              :  private:
     211              :   optional<ArrayRef<T>> wrapped_opt_array_ref;
     212              : };
     213              : 
     214              : using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
     215              : 
     216              : inline bool operator==(
     217              :     const OptionalIntArrayRef& a1,
     218              :     const IntArrayRef& other) {
     219              :   if (!a1.has_value()) {
     220              :     return false;
     221              :   }
     222              :   return a1.value() == other;
     223              : }
     224              : 
     225              : inline bool operator==(
     226              :     const c10::IntArrayRef& a1,
     227              :     const c10::OptionalIntArrayRef& a2) {
     228              :   return a2 == a1;
     229              : }
     230              : 
     231              : } // namespace c10
        

Generated by: LCOV version 2.0-1