11#ifndef CUBBYFLOW_CUDA_ARRAY_BASE_IMPL_HPP
12#define CUBBYFLOW_CUDA_ARRAY_BASE_IMPL_HPP
14#ifdef CUBBYFLOW_USE_CUDA
18template <
typename T,
size_t N,
typename Derived>
19size_t CUDAArrayBase<T, N, Derived>::Index(
size_t i)
const
24template <
typename T,
size_t N,
typename Derived>
25template <
typename...
Args>
26size_t CUDAArrayBase<T, N, Derived>::Index(
size_t i,
Args...
args)
const
28 static_assert(
sizeof...(args) ==
N - 1,
"Invalid number of indices.");
29 return i + m_size[0] * IndexInternal(1,
args...);
32template <
typename T,
size_t N,
typename Derived>
34size_t CUDAArrayBase<T, N, Derived>::Index(
37 return IndexInternal(
idx, std::make_index_sequence<N>{});
40template <
typename T,
size_t N,
typename Derived>
46template <
typename T,
size_t N,
typename Derived>
52template <
typename T,
size_t N,
typename Derived>
58template <
typename T,
size_t N,
typename Derived>
60std::enable_if_t<(
M > 0),
size_t> CUDAArrayBase<T, N, Derived>::Width()
const
65template <
typename T,
size_t N,
typename Derived>
67std::enable_if_t<(
M > 1),
size_t> CUDAArrayBase<T, N, Derived>::Height()
const
72template <
typename T,
size_t N,
typename Derived>
74std::enable_if_t<(
M > 2),
size_t> CUDAArrayBase<T, N, Derived>::Depth()
const
79template <
typename T,
size_t N,
typename Derived>
84 for (
size_t i = 1; i <
N; ++i)
93template <
typename T,
size_t N,
typename Derived>
95CUDAArrayBase<T, N, Derived>::At(
size_t i)
100template <
typename T,
size_t N,
typename Derived>
102CUDAArrayBase<T, N, Derived>::At(
size_t i)
const
107template <
typename T,
size_t N,
typename Derived>
108template <
typename...
Args>
110CUDAArrayBase<T, N, Derived>::At(
size_t i,
Args...
args)
112 return At(Index(i,
args...));
115template <
typename T,
size_t N,
typename Derived>
116template <
typename...
Args>
118CUDAArrayBase<T, N, Derived>::At(
size_t i,
Args...
args)
const
120 return At(Index(i,
args...));
123template <
typename T,
size_t N,
typename Derived>
127 return At(Index(
idx));
130template <
typename T,
size_t N,
typename Derived>
134 return At(Index(
idx));
137template <
typename T,
size_t N,
typename Derived>
144template <
typename T,
size_t N,
typename Derived>
151template <
typename T,
size_t N,
typename Derived>
152template <
typename...
Args>
156 return At(i,
args...);
159template <
typename T,
size_t N,
typename Derived>
160template <
typename...
Args>
164 return At(i,
args...);
167template <
typename T,
size_t N,
typename Derived>
174template <
typename T,
size_t N,
typename Derived>
182template <
typename T,
size_t N,
typename Derived>
183typename CUDAArrayBase<T, N, Derived>::HostReference
184CUDAArrayBase<T, N, Derived>::At(
size_t i)
189template <
typename T,
size_t N,
typename Derived>
190T CUDAArrayBase<T, N, Derived>::At(
size_t i)
const
195template <
typename T,
size_t N,
typename Derived>
196template <
typename...
Args>
197typename CUDAArrayBase<T, N, Derived>::HostReference
198CUDAArrayBase<T, N, Derived>::At(
size_t i,
Args...
args)
200 return At(Index(i,
args...));
203template <
typename T,
size_t N,
typename Derived>
204template <
typename...
Args>
205T CUDAArrayBase<T, N, Derived>::At(
size_t i,
Args...
args)
const
207 return At(Index(i,
args...));
210template <
typename T,
size_t N,
typename Derived>
211typename CUDAArrayBase<T, N, Derived>::HostReference
214 return At(Index(
idx));
217template <
typename T,
size_t N,
typename Derived>
220 return At(Index(
idx));
223template <
typename T,
size_t N,
typename Derived>
224typename CUDAArrayBase<T, N, Derived>::HostReference
230template <
typename T,
size_t N,
typename Derived>
236template <
typename T,
size_t N,
typename Derived>
237template <
typename...
Args>
238typename CUDAArrayBase<T, N, Derived>::HostReference
241 return At(i,
args...);
244template <
typename T,
size_t N,
typename Derived>
245template <
typename...
Args>
248 return At(i,
args...);
251template <
typename T,
size_t N,
typename Derived>
252typename CUDAArrayBase<T, N, Derived>::HostReference
258template <
typename T,
size_t N,
typename Derived>
266template <
typename T,
size_t N,
typename Derived>
267CUDAArrayBase<T, N, Derived>::CUDAArrayBase() : m_size{}
272template <
typename T,
size_t N,
typename Derived>
273CUDAArrayBase<T, N, Derived>::CUDAArrayBase(
const CUDAArrayBase& other)
275 SetPtrAndSize(other.m_ptr, other.m_size);
278template <
typename T,
size_t N,
typename Derived>
279CUDAArrayBase<T, N, Derived>::CUDAArrayBase(CUDAArrayBase&& other)
noexcept
281 *
this = std::move(other);
284template <
typename T,
size_t N,
typename Derived>
285CUDAArrayBase<T, N, Derived>& CUDAArrayBase<T, N, Derived>::operator=(
286 const CUDAArrayBase& other)
288 SetPtrAndSize(other.m_ptr, other.m_size);
292template <
typename T,
size_t N,
typename Derived>
293CUDAArrayBase<T, N, Derived>& CUDAArrayBase<T, N, Derived>::operator=(
294 CUDAArrayBase&& other)
noexcept
296 SetPtrAndSize(other.m_ptr, other.m_size);
297 other.SetPtrAndSize(
nullptr, CUDAStdArray<size_t, N>{});
301template <
typename T,
size_t N,
typename Derived>
302template <
typename... Args>
303void CUDAArrayBase<T, N, Derived>::SetPtrAndSize(Pointer ptr,
size_t ni,
306 SetPtrAndSize(ptr, CUDAStdArray<size_t, N>{ ni, args... });
309template <
typename T,
size_t N,
typename Derived>
310void CUDAArrayBase<T, N, Derived>::SetPtrAndSize(Pointer ptr,
311 CUDAStdArray<size_t, N> size)
317template <
typename T,
size_t N,
typename Derived>
318void CUDAArrayBase<T, N, Derived>::SwapPtrAndSize(CUDAArrayBase& other)
320 CUDASwap(m_ptr, other.m_ptr);
321 CUDASwap(m_size, other.m_size);
324template <
typename T,
size_t N,
typename Derived>
325void CUDAArrayBase<T, N, Derived>::ClearPtrAndSize()
327 SetPtrAndSize(
nullptr, CUDAStdArray<size_t, N>{});
330template <
typename T,
size_t N,
typename Derived>
331template <
typename... Args>
332size_t CUDAArrayBase<T, N, Derived>::IndexInternal(
size_t d,
size_t i,
335 return i + m_size[d] * IndexInternal(d + 1, args...);
338template <
typename T,
size_t N,
typename Derived>
339size_t CUDAArrayBase<T, N, Derived>::IndexInternal(
size_t,
size_t i)
const
344template <
typename T,
size_t N,
typename Derived>
345template <
size_t... I>
346size_t CUDAArrayBase<T, N, Derived>::IndexInternal(
347 const CUDAStdArray<size_t, N>& idx, std::index_sequence<I...>)
const
349 return Index(idx[I]...);
Reference operator()(size_t i, size_t j)
Definition MatrixDenseBase-Impl.hpp:107
ValueType Length() const
Definition MatrixExpression-Impl.hpp:278
T & Reference
Definition Matrix.hpp:40
Pointer data()
Definition Matrix-Impl.hpp:298
Reference operator[](size_t i)
Definition Matrix-Impl.hpp:311
const T & ConstReference
Definition Matrix.hpp:41
Definition pybind11Utils.hpp:21
Matrix< T, Rows, 1 > Vector
Definition Matrix.hpp:738