2#ifndef GYM_PCB_PYARRAY_H
3#define GYM_PCB_PYARRAY_H
// NPArray<T>: C++ handle around a NumPy array (PyArrayObject *) whose dtype is
// checked against T via classTypenum()/checkType().
// Copy construction/assignment is only allowed from an empty handle (both
// assert on a non-empty source); moves transfer the array pointer.
// NOTE(review): this is a lossy extraction of the header — the class's opening
// brace and access specifiers, the declaration between NPArray(PyObject *) and
// own(PyArrayObject *), and the private section declaring mPy are not shown.
10template<
typename T>
class NPArray
13 static NPArray<T> *create(PyObject *);
14 NPArray() : mPy(0) { }
15 NPArray<T>& operator=(
const NPArray<T> &avoid) { assert(!avoid.mPy); mPy = 0;
return *
this; }
16 NPArray<T>& operator=(NPArray<T> &&move) { own(move.mPy); move.mPy = 0;
return *
this; }
17 NPArray(
const NPArray<T> &avoid) : mPy(0) { assert(!avoid.mPy); }
18 NPArray(NPArray<T> &&move) : mPy(move.mPy) { move.mPy = 0; }
19 NPArray(
const std::vector<T>&);
20 NPArray(
const std::vector<T>&, npy_intp d1, npy_intp d2);
21 NPArray(T *, npy_intp
const *dims, uint num_dims);
22 NPArray(T *, npy_intp d1);
23 NPArray(T *, npy_intp d1, npy_intp d2);
24 NPArray(T *, npy_intp d1, npy_intp d2, npy_intp d3);
25 NPArray(T *, npy_intp d1, npy_intp d2, npy_intp d3, npy_intp d4);
26 NPArray(PyObject *py) : mPy(0) { own(py); Assert(); }
28 bool own(PyArrayObject *);
29 NPArray<T>& ownData();
30 ~NPArray() { own((PyArrayObject *)0); }
31 bool valid()
const {
return mPy; }
32 bool isPacked()
const;
33 void Assert()
const {
if (!valid())
throw std::runtime_error(
"Invalid numpy array"); }
34 PyObject *release() {
auto rv = py(); mPy = 0;
return rv; }
35 PyObject *py()
const {
return reinterpret_cast<PyObject *
>(mPy); }
36 PyArrayObject *pyArray()
const {
return mPy; }
37 T *data()
const {
return static_cast<T *
>(PyArray_DATA(mPy)); }
38 intptr_t count()
const {
return PyArray_SIZE(mPy); }
39 intptr_t stride(uint d)
const {
return PyArray_STRIDE(mPy, d); }
40 intptr_t size()
const {
return PyArray_NBYTES(mPy); }
41 uint dim(uint d)
const {
return PyArray_DIM(mPy, d); }
42 uint degree()
const {
return PyArray_NDIM(mPy); }
43 T operator()(uint x)
const {
return *(T *)PyArray_GETPTR1(mPy, x); }
44 T operator()(uint x, uint y)
const {
return *(T *)PyArray_GETPTR2(mPy, x, y); }
45 T operator()(uint x, uint y, uint z)
const {
return *(T *)PyArray_GETPTR3(mPy, x, y, z); }
46 T operator()(uint x, uint y, uint z, uint i)
const {
return *(T *)PyArray_GETPTR4(mPy, x, y, z, i); }
47 void copyto(std::vector<T>&)
const;
48 std::vector<T> vector()
const { std::vector<T> v; copyto(v);
return v; }
49 static void checkAPI() { assert(PY_ARRAY_UNIQUE_SYMBOL); }
50 static bool checkType(PyArrayObject *);
51 static int classTypenum();
// Wraps the caller's buffer `data` as a `num_dims`-dimensional array with
// extents read from `dims` and T's NumPy type number.
// PyArray_SimpleNewFromData neither copies nor frees `data`. If it fails it
// returns NULL, so mPy ends up null.
// NOTE(review): the listing jumps from line 56 to 59 and the closing line is
// absent, so part of this body (at least its braces) is not shown.
56template<
typename T> NPArray<T>::NPArray(T *data, npy_intp
const *dims, uint num_dims)
59 mPy =
reinterpret_cast<PyArrayObject *
>(PyArray_SimpleNewFromData(num_dims, dims, classTypenum(), data));
// Builds a 1-D array over a malloc'd copy of `v`'s elements.
// NOTE(review): `(uint)v.size()` narrows the element count to unsigned int
// before it is widened to npy_intp, truncating sizes above UINT_MAX;
// `static_cast<npy_intp>(v.size())` would avoid that.
// NOTE(review): malloc's result is not checked before memcpy. For an empty
// vector, malloc(0) may return null, and memcpy with a null pointer is
// undefined even for zero bytes.
// NOTE(review): nothing shown here hands the malloc'd buffer to an owner
// (e.g. ownData()). The listing skips lines 62, 65 and 69-70, so check the
// full source to confirm the buffer is not leaked.
61template<
typename T> NPArray<T>::NPArray(
const std::vector<T> &v)
63 const auto size = v.size() *
sizeof(T);
64 const npy_intp dims[1] = { (uint)v.size() };
66 auto data = (T *)malloc(size);
67 std::memcpy(data, v.data(), size);
68 mPy =
reinterpret_cast<PyArrayObject *
>(PyArray_SimpleNewFromData(1, dims, classTypenum(), data));
// Builds a d1 x d2 array over a malloc'd copy of `v`'s elements.
// NOTE(review): the size match is only assert()ed. With NDEBUG:
//   - a short `v` leaves the tail of the buffer uninitialized (memcpy copies
//     min(srcSize, dstSize) bytes);
//   - a long `v` is silently truncated.
// Consider throwing instead. Negative d1/d2 also wrap when converted to the
// unsigned byte count.
// NOTE(review): malloc's result is not checked, and the listing skips lines
// 72, 79 and 81-82, so ownership of the buffer cannot be confirmed here.
71template<
typename T> NPArray<T>::NPArray(
const std::vector<T> &v, npy_intp d1, npy_intp d2)
73 auto srcSize = v.size() *
sizeof(T);
74 auto dstSize = d1 * d2 *
sizeof(T);
75 assert(srcSize == dstSize);
76 auto data = (T *)malloc(dstSize);
77 std::memcpy(data, v.data(), std::min(srcSize, dstSize));
78 const npy_intp dims[2] = { d1, d2 };
80 mPy =
reinterpret_cast<PyArrayObject *
>(PyArray_SimpleNewFromData(2, dims, classTypenum(), data));
// Wraps the caller's buffer `data` (not copied) as a 1-D array of length d1.
// NOTE(review): the listing skips line 85 and the braces, so the body shown
// may be incomplete.
83template<
typename T> NPArray<T>::NPArray(T *data, npy_intp d1)
86 mPy =
reinterpret_cast<PyArrayObject *
>(PyArray_SimpleNewFromData(1, &d1, classTypenum(), data));
// Wraps the caller's buffer `data` (not copied) as a d1 x d2 array.
// NOTE(review): the listing skips line 91 and the braces, so the body shown
// may be incomplete.
88template<
typename T> NPArray<T>::NPArray(T *data, npy_intp d1, npy_intp d2)
90 const npy_intp dims[2] = { d1, d2 };
92 mPy =
reinterpret_cast<PyArrayObject *
>(PyArray_SimpleNewFromData(2, dims, classTypenum(), data));
// Wraps the caller's buffer `data` (not copied) as a d1 x d2 x d3 array.
// NOTE(review): the listing skips line 97 and the braces, so the body shown
// may be incomplete.
94template<
typename T> NPArray<T>::NPArray(T *data, npy_intp d1, npy_intp d2, npy_intp d3)
96 const npy_intp dims[3] = { d1, d2, d3 };
98 mPy =
reinterpret_cast<PyArrayObject *
>(PyArray_SimpleNewFromData(3, dims, classTypenum(), data));
// Wraps the caller's buffer `data` (not copied) as a d1 x d2 x d3 x d4 array.
// NOTE(review): the listing skips line 103 and the braces, so the body shown
// may be incomplete.
100template<
typename T> NPArray<T>::NPArray(T *data, npy_intp d1, npy_intp d2, npy_intp d3, npy_intp d4)
102 const npy_intp dims[4] = { d1, d2, d3, d4 };
104 mPy =
reinterpret_cast<PyArrayObject *
>(PyArray_SimpleNewFromData(4, dims, classTypenum(), data));
107inline void _capsuleDtor(PyObject *capsule)
109 std::free(PyCapsule_GetPointer(capsule, 0));
// If the array has no base object yet, installs an unnamed PyCapsule wrapping
// data() as its base. _capsuleDtor then std::free()s the buffer when the
// capsule is destroyed, so this is only correct for malloc'd buffers.
// NOTE(review): both failures below go unnoticed:
//   - PyCapsule_New can return NULL (with a Python error set);
//   - PyArray_SetBaseObject's int result is ignored.
// NOTE(review): the listing skips lines 112 and 115-117 (including the return
// of *this), so the body shown is incomplete.
111template<
typename T>
NPArray<T>& NPArray<T>::ownData()
113 if (!PyArray_BASE(mPy))
114 PyArray_SetBaseObject(mPy, PyCapsule_New(data(), 0, _capsuleDtor));
// Factory. Its success branch requires `py` to be non-null, a NumPy array
// (PyArray_Check), and of a dtype that passes checkType().
// NOTE(review): the branch bodies and returns (listing lines 119 and 121-124)
// are not shown, so what is returned in each case, and who owns it, cannot
// be confirmed from here.
118template<
typename T>
NPArray<T> *NPArray<T>::create(PyObject *py)
120 if (py && PyArray_Check(py) && checkType(
reinterpret_cast<PyArrayObject *
>(py)))
// Makes `py` the held array if it is non-null and passes checkType();
// otherwise the handle becomes empty. The destructor calls this with a null
// pointer.
// NOTE(review): listing lines 126-129 and 131-132 are missing, so any
// reference-count handling and the returned bool are not visible here.
125template<
typename T>
bool NPArray<T>::own(PyArrayObject *py)
130 mPy = (py && checkType(py)) ? py : 0;
// PyObject overload: takes `py` only if it is non-null and a NumPy array
// (PyArray_Check), then re-checks the dtype with checkType().
// NOTE(review): this overload's declaration does not appear in the class body
// as extracted. The listing also omits lines 134-137 and 140-143, so the
// handling of a dtype mismatch and the return value are not visible.
133template<
typename T>
bool NPArray<T>::own(PyObject *py)
138 mPy = (py && PyArray_Check(py)) ?
reinterpret_cast<PyArrayObject *
>(py) : 0;
139 if (mPy && !checkType(mPy))
// Maps T to its NumPy type number through a chain of `if constexpr` tests for
// float, double, int32_t, uint32_t, uint64_t, int8_t and uint8_t. Any other T
// falls through to the std::runtime_error("FIXME") throw, assuming each branch
// returns.
// NOTE(review): the per-branch return statements (listing lines 147, 149, ...,
// 159) are not shown in this listing.
// NOTE(review): types such as int64_t, int16_t and uint16_t are unmapped, and
// misuse only surfaces at run time; a static_assert over the supported set
// would catch it at compile time.
144template<
typename T>
int NPArray<T>::classTypenum()
146 if constexpr (std::is_same_v<float, T>)
148 if constexpr (std::is_same_v<double, T>)
150 if constexpr (std::is_same_v<int32_t, T>)
152 if constexpr (std::is_same_v<uint32_t, T>)
154 if constexpr (std::is_same_v<uint64_t, T>)
156 if constexpr (std::is_same_v<int8_t, T>)
158 if constexpr (std::is_same_v<uint8_t, T>)
160 throw std::runtime_error(
"FIXME");
// Accepts an array whose type number equals classTypenum(). For T = uint8_t it
// also accepts NPY_BOOL, NPY_BYTE and NPY_UINT8 arrays.
// NOTE(review): NPY_BYTE is NumPy's signed 8-bit type, so reading such data as
// uint8_t reinterprets negative values. Confirm this is intended.
// NOTE(review): listing lines 165, 167 and 172-174 are missing, including the
// result of the first test and the final return.
164template<
typename T>
bool NPArray<T>::checkType(PyArrayObject *py)
166 if (PyArray_TYPE(py) == classTypenum())
168 if constexpr (std::is_same_v<uint8_t, T>)
169 return PyArray_TYPE(py) == NPY_BOOL ||
170 PyArray_TYPE(py) == NPY_BYTE ||
171 PyArray_TYPE(py) == NPY_UINT8;
// Layout check that iterates over the dimensions from the last to the first.
// NOTE(review): only the loop header survives in this listing (lines 176-177
// and 179-187 are missing), so the packing criterion cannot be documented
// from here.
// NOTE(review): `degree() - 1` is evaluated in unsigned arithmetic. For a 0-D
// array it wraps, and the loop relies on the conversion back to int giving -1.
175template<
typename T>
bool NPArray<T>::isPacked()
const
178 for (
int d = degree() - 1; d >= 0; --d) {
// Copies v.size() * sizeof(T) bytes from the array's data buffer into `v`.
// NOTE(review): memcpy reads data() as one contiguous block. That is only
// correct when the array's elements are contiguous in memory and
// v.size() <= count(). Listing lines 189-191 are missing, so any resize of `v`
// or layout check done there is not visible.
188template<
typename T>
void NPArray<T>::copyto(std::vector<T> &v)
const
192 std::memcpy(v.data(), data(), v.size() *
sizeof(T));
Definition PyArray.hpp:11