PCB Environment 2
Loading...
Searching...
No Matches
PyArray.hpp
1
2#ifndef GYM_PCB_PYARRAY_H
3#define GYM_PCB_PYARRAY_H
4
5#include "Py.hpp"
6#include <vector>
7
8namespace py {
9
10template<typename T> class NPArray
11{
12public:
13 static NPArray<T> *create(PyObject *);
14 NPArray() : mPy(0) { }
15 NPArray<T>& operator=(const NPArray<T> &avoid) { assert(!avoid.mPy); mPy = 0; return *this; }
16 NPArray<T>& operator=(NPArray<T> &&move) { own(move.mPy); move.mPy = 0; return *this; }
17 NPArray(const NPArray<T> &avoid) : mPy(0) { assert(!avoid.mPy); }
18 NPArray(NPArray<T> &&move) : mPy(move.mPy) { move.mPy = 0; }
19 NPArray(const std::vector<T>&);
20 NPArray(const std::vector<T>&, npy_intp d1, npy_intp d2);
21 NPArray(T *, npy_intp const *dims, uint num_dims);
22 NPArray(T *, npy_intp d1);
23 NPArray(T *, npy_intp d1, npy_intp d2);
24 NPArray(T *, npy_intp d1, npy_intp d2, npy_intp d3);
25 NPArray(T *, npy_intp d1, npy_intp d2, npy_intp d3, npy_intp d4);
26 NPArray(PyObject *py) : mPy(0) { own(py); Assert(); }
27 bool own(PyObject *);
28 bool own(PyArrayObject *);
29 NPArray<T>& ownData();
30 ~NPArray() { own((PyArrayObject *)0); }
31 bool valid() const { return mPy; }
32 bool isPacked() const;
33 void Assert() const { if (!valid()) throw std::runtime_error("Invalid numpy array"); }
34 PyObject *release() { auto rv = py(); mPy = 0; return rv; }
35 PyObject *py() const { return reinterpret_cast<PyObject *>(mPy); }
36 PyArrayObject *pyArray() const { return mPy; }
37 T *data() const { return static_cast<T *>(PyArray_DATA(mPy)); }
38 intptr_t count() const { return PyArray_SIZE(mPy); }
39 intptr_t stride(uint d) const { return PyArray_STRIDE(mPy, d); }
40 intptr_t size() const { return PyArray_NBYTES(mPy); }
41 uint dim(uint d) const { return PyArray_DIM(mPy, d); }
42 uint degree() const { return PyArray_NDIM(mPy); }
43 T operator()(uint x) const { return *(T *)PyArray_GETPTR1(mPy, x); }
44 T operator()(uint x, uint y) const { return *(T *)PyArray_GETPTR2(mPy, x, y); }
45 T operator()(uint x, uint y, uint z) const { return *(T *)PyArray_GETPTR3(mPy, x, y, z); }
46 T operator()(uint x, uint y, uint z, uint i) const { return *(T *)PyArray_GETPTR4(mPy, x, y, z, i); }
47 void copyto(std::vector<T>&) const;
48 std::vector<T> vector() const { std::vector<T> v; copyto(v); return v; }
49 static void checkAPI() { assert(PY_ARRAY_UNIQUE_SYMBOL); }
50 static bool checkType(PyArrayObject *);
51 static int classTypenum();
52private:
53 PyArrayObject *mPy;
54};
55
56template<typename T> NPArray<T>::NPArray(T *data, npy_intp const *dims, uint num_dims)
57{
58 checkAPI();
59 mPy = reinterpret_cast<PyArrayObject *>(PyArray_SimpleNewFromData(num_dims, dims, classTypenum(), data));
60}
61template<typename T> NPArray<T>::NPArray(const std::vector<T> &v)
62{
63 const auto size = v.size() * sizeof(T);
64 const npy_intp dims[1] = { (uint)v.size() };
65 checkAPI();
66 auto data = (T *)malloc(size);
67 std::memcpy(data, v.data(), size);
68 mPy = reinterpret_cast<PyArrayObject *>(PyArray_SimpleNewFromData(1, dims, classTypenum(), data));
69 ownData();
70}
71template<typename T> NPArray<T>::NPArray(const std::vector<T> &v, npy_intp d1, npy_intp d2)
72{
73 auto srcSize = v.size() * sizeof(T);
74 auto dstSize = d1 * d2 * sizeof(T);
75 assert(srcSize == dstSize);
76 auto data = (T *)malloc(dstSize);
77 std::memcpy(data, v.data(), std::min(srcSize, dstSize));
78 const npy_intp dims[2] = { d1, d2 };
79 checkAPI();
80 mPy = reinterpret_cast<PyArrayObject *>(PyArray_SimpleNewFromData(2, dims, classTypenum(), data));
81 ownData();
82}
83template<typename T> NPArray<T>::NPArray(T *data, npy_intp d1)
84{
85 checkAPI();
86 mPy = reinterpret_cast<PyArrayObject *>(PyArray_SimpleNewFromData(1, &d1, classTypenum(), data));
87}
88template<typename T> NPArray<T>::NPArray(T *data, npy_intp d1, npy_intp d2)
89{
90 const npy_intp dims[2] = { d1, d2 };
91 checkAPI();
92 mPy = reinterpret_cast<PyArrayObject *>(PyArray_SimpleNewFromData(2, dims, classTypenum(), data));
93}
94template<typename T> NPArray<T>::NPArray(T *data, npy_intp d1, npy_intp d2, npy_intp d3)
95{
96 const npy_intp dims[3] = { d1, d2, d3 };
97 checkAPI();
98 mPy = reinterpret_cast<PyArrayObject *>(PyArray_SimpleNewFromData(3, dims, classTypenum(), data));
99}
100template<typename T> NPArray<T>::NPArray(T *data, npy_intp d1, npy_intp d2, npy_intp d3, npy_intp d4)
101{
102 const npy_intp dims[4] = { d1, d2, d3, d4 };
103 checkAPI();
104 mPy = reinterpret_cast<PyArrayObject *>(PyArray_SimpleNewFromData(4, dims, classTypenum(), data));
105}
106
107inline void _capsuleDtor(PyObject *capsule)
108{
109 std::free(PyCapsule_GetPointer(capsule, 0));
110}
111template<typename T> NPArray<T>& NPArray<T>::ownData()
112{
113 if (!PyArray_BASE(mPy))
114 PyArray_SetBaseObject(mPy, PyCapsule_New(data(), 0, _capsuleDtor));
115 return *this;
116}
117
118template<typename T> NPArray<T> *NPArray<T>::create(PyObject *py)
119{
120 if (py && PyArray_Check(py) && checkType(reinterpret_cast<PyArrayObject *>(py)))
121 return new NPArray(py);
122 return 0;
123}
124
125template<typename T> bool NPArray<T>::own(PyArrayObject *py)
126{
127 checkAPI();
128 if (mPy)
129 Py_DECREF(mPy);
130 mPy = (py && checkType(py)) ? py : 0;
131 return mPy;
132}
133template<typename T> bool NPArray<T>::own(PyObject *py)
134{
135 checkAPI();
136 if (mPy)
137 Py_DECREF(mPy);
138 mPy = (py && PyArray_Check(py)) ? reinterpret_cast<PyArrayObject *>(py) : 0;
139 if (mPy && !checkType(mPy))
140 mPy = 0;
141 return mPy;
142}
143
144template<typename T> int NPArray<T>::classTypenum()
145{
146 if constexpr (std::is_same_v<float, T>)
147 return NPY_FLOAT32;
148 if constexpr (std::is_same_v<double, T>)
149 return NPY_FLOAT64;
150 if constexpr (std::is_same_v<int32_t, T>)
151 return NPY_INT32;
152 if constexpr (std::is_same_v<uint32_t, T>)
153 return NPY_UINT32;
154 if constexpr (std::is_same_v<uint64_t, T>)
155 return NPY_UINT64;
156 if constexpr (std::is_same_v<int8_t, T>)
157 return NPY_INT8;
158 if constexpr (std::is_same_v<uint8_t, T>)
159 return NPY_UBYTE;
160 throw std::runtime_error("FIXME");
161 return NPY_VOID;
162}
163
164template<typename T> bool NPArray<T>::checkType(PyArrayObject *py)
165{
166 if (PyArray_TYPE(py) == classTypenum())
167 return true;
168 if constexpr (std::is_same_v<uint8_t, T>)
169 return PyArray_TYPE(py) == NPY_BOOL ||
170 PyArray_TYPE(py) == NPY_BYTE ||
171 PyArray_TYPE(py) == NPY_UINT8;
172 return false;
173}
174
175template<typename T> bool NPArray<T>::isPacked() const
176{
177 uint sz = sizeof(T);
178 for (int d = degree() - 1; d >= 0; --d) {
179 if (sz != stride(d))
180 return false;
181 sz *= dim(d);
182 }
183 if (size() != sz)
184 return false;
185 return true;
186}
187
188template<typename T> void NPArray<T>::copyto(std::vector<T> &v) const
189{
190 assert(isPacked());
191 v.resize(count());
192 std::memcpy(v.data(), data(), v.size() * sizeof(T));
193}
194
195} // namespace py
196
// Global-namespace shorthand for the NPArray wrapper instantiated with float elements.
197using FloatNPArray = py::NPArray<float>;
198
199#endif // GYM_PCB_PYARRAY_H
Definition PyArray.hpp:11