36#ifndef VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX
37#define VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX
39#ifndef NPY_NO_DEPRECATED_API
40# define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
44#include "array_vector.hxx"
45#include "python_utility.hxx"
46#include "axistags.hxx"
53python_ptr getArrayTypeObject()
55 python_ptr arraytype((PyObject*)&PyArray_Type);
56 python_ptr vigra(PyImport_ImportModule(
"vigra"));
59 return pythonGetAttr(vigra,
"standardArrayType", arraytype);
63std::string defaultOrder(std::string defaultValue =
"C")
65 python_ptr arraytype = getArrayTypeObject();
66 return pythonGetAttr(arraytype,
"defaultOrder", defaultValue);
70python_ptr defaultAxistags(
int ndim, std::string order =
"")
73 order = defaultOrder();
74 python_ptr arraytype = getArrayTypeObject();
75 python_ptr func(pythonFromData(
"defaultAxistags"));
76 python_ptr d(pythonFromData(ndim));
77 python_ptr o(pythonFromData(order));
78 python_ptr axistags(PyObject_CallMethodObjArgs(arraytype, func.get(), d.get(), o.get(), NULL),
79 python_ptr::keep_count);
87python_ptr emptyAxistags(
int ndim)
89 python_ptr arraytype = getArrayTypeObject();
90 python_ptr func(pythonFromData(
"_empty_axistags"));
91 python_ptr d(pythonFromData(ndim));
92 python_ptr axistags(PyObject_CallMethodObjArgs(arraytype, func.get(), d.get(), NULL),
93 python_ptr::keep_count);
102getAxisPermutationImpl(ArrayVector<npy_intp> & permute,
103 python_ptr
object,
const char * name,
104 AxisInfo::AxisType type,
bool ignoreErrors)
106 python_ptr func(pythonFromData(name));
107 python_ptr t(pythonFromData((
long)type));
108 python_ptr permutation(PyObject_CallMethodObjArgs(
object, func.get(), t.get(), NULL),
109 python_ptr::keep_count);
110 if(!permutation && ignoreErrors)
115 pythonToCppException(permutation);
117 if(!PySequence_Check(permutation))
121 std::string message = std::string(name) +
"() did not return a sequence.";
122 PyErr_SetString(PyExc_ValueError, message.c_str());
123 pythonToCppException(
false);
126 ArrayVector<npy_intp> res(PySequence_Length(permutation));
127 for(
int k=0; k<(int)res.size(); ++k)
129 python_ptr i(PySequence_GetItem(permutation, k), python_ptr::keep_count);
130#if PY_MAJOR_VERSION < 3
133 if (!PyLong_Check(i))
138 std::string message = std::string(name) +
"() did not return a sequence of int.";
139 PyErr_SetString(PyExc_ValueError, message.c_str());
140 pythonToCppException(
false);
142#if PY_MAJOR_VERSION < 3
143 res[k] = PyInt_AsLong(i);
145 res[k] = PyLong_AsLong(i);
153getAxisPermutationImpl(ArrayVector<npy_intp> & permute,
154 python_ptr
object,
const char * name,
bool ignoreErrors)
156 getAxisPermutationImpl(permute,
object, name, AxisInfo::AllAxes, ignoreErrors);
175 typedef PyObject * pointer;
179 PyAxisTags(python_ptr tags = python_ptr(),
bool createCopy =
false)
184 if(!PySequence_Check(tags))
186 PyErr_SetString(PyExc_TypeError,
187 "PyAxisTags(tags): tags argument must have type 'AxisTags'.");
188 pythonToCppException(
false);
190 else if(PySequence_Length(tags) == 0)
197 python_ptr func(pythonFromData(
"__copy__"));
198 axistags = python_ptr(PyObject_CallMethodObjArgs(tags, func.get(), NULL),
199 python_ptr::keep_count);
207 PyAxisTags& operator=(PyAxisTags
const & other) =
default;
208 PyAxisTags(PyAxisTags
const & other,
bool createCopy =
false)
214 python_ptr func(pythonFromData(
"__copy__"));
215 axistags = python_ptr(PyObject_CallMethodObjArgs(other.axistags, func.get(), NULL),
216 python_ptr::keep_count);
220 axistags = other.axistags;
224 PyAxisTags(
int ndim, std::string
const & order =
"")
227 axistags = detail::defaultAxistags(ndim, order);
229 axistags = detail::emptyAxistags(ndim);
235 ? PySequence_Length(axistags)
239 long channelIndex(
long defaultVal)
const
241 return pythonGetAttr(axistags,
"channelIndex", defaultVal);
244 long channelIndex()
const
246 return channelIndex(size());
249 bool hasChannelAxis()
const
251 return channelIndex() != size();
254 long innerNonchannelIndex(
long defaultVal)
const
256 return pythonGetAttr(axistags,
"innerNonchannelIndex", defaultVal);
259 long innerNonchannelIndex()
const
261 return innerNonchannelIndex(size());
264 void setChannelDescription(std::string
const & description)
268 python_ptr d(pythonFromData(description));
269 python_ptr func(pythonFromData(
"setChannelDescription"));
270 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), d.get(), NULL),
271 python_ptr::keep_count);
272 pythonToCppException(res);
275 double resolution(
long index)
279 python_ptr func(pythonFromData(
"resolution"));
280 python_ptr i(pythonFromData(index));
281 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), NULL),
282 python_ptr::keep_count);
283 pythonToCppException(res);
284 if(!PyFloat_Check(res))
286 PyErr_SetString(PyExc_TypeError,
"AxisTags.resolution() did not return float.");
287 pythonToCppException(
false);
289 return PyFloat_AsDouble(res);
292 void setResolution(
long index,
double resolution)
296 python_ptr func(pythonFromData(
"setResolution"));
297 python_ptr i(pythonFromData(index));
298 python_ptr r(PyFloat_FromDouble(resolution), python_ptr::keep_count);
299 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), r.get(), NULL),
300 python_ptr::keep_count);
301 pythonToCppException(res);
304 void scaleResolution(
long index,
double factor)
308 python_ptr func(pythonFromData(
"scaleResolution"));
309 python_ptr i(pythonFromData(index));
310 python_ptr f(PyFloat_FromDouble(factor), python_ptr::keep_count);
311 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), f.get(), NULL),
312 python_ptr::keep_count);
313 pythonToCppException(res);
316 void toFrequencyDomain(
long index,
int size,
int sign = 1)
320 python_ptr func(
sign == 1
321 ? pythonFromData(
"toFrequencyDomain")
322 : pythonFromData(
"fromFrequencyDomain"));
323 python_ptr i(pythonFromData(index));
324 python_ptr s(pythonFromData(size));
325 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), s.get(), NULL),
326 python_ptr::keep_count);
327 pythonToCppException(res);
330 void fromFrequencyDomain(
long index,
int size)
332 toFrequencyDomain(index, size, -1);
335 ArrayVector<npy_intp>
336 permutationToNormalOrder(
bool ignoreErrors =
false)
const
338 ArrayVector<npy_intp> permute;
339 detail::getAxisPermutationImpl(permute, axistags,
"permutationToNormalOrder", ignoreErrors);
343 ArrayVector<npy_intp>
344 permutationToNormalOrder(AxisInfo::AxisType types,
bool ignoreErrors =
false)
const
346 ArrayVector<npy_intp> permute;
347 detail::getAxisPermutationImpl(permute, axistags,
348 "permutationToNormalOrder", types, ignoreErrors);
352 ArrayVector<npy_intp>
353 permutationFromNormalOrder(
bool ignoreErrors =
false)
const
355 ArrayVector<npy_intp> permute;
356 detail::getAxisPermutationImpl(permute, axistags,
357 "permutationFromNormalOrder", ignoreErrors);
361 ArrayVector<npy_intp>
362 permutationFromNormalOrder(AxisInfo::AxisType types,
bool ignoreErrors =
false)
const
364 ArrayVector<npy_intp> permute;
365 detail::getAxisPermutationImpl(permute, axistags,
366 "permutationFromNormalOrder", types, ignoreErrors);
370 void dropChannelAxis()
374 python_ptr func(pythonFromData(
"dropChannelAxis"));
375 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), NULL),
376 python_ptr::keep_count);
377 pythonToCppException(res);
380 void insertChannelAxis()
384 python_ptr func(pythonFromData(
"insertChannelAxis"));
385 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), NULL),
386 python_ptr::keep_count);
387 pythonToCppException(res);
392 return axistags.get();
395 bool operator!()
const
410 enum ChannelAxis { first, last, none };
412 ArrayVector<npy_intp> shape, original_shape;
414 ChannelAxis channelAxis;
415 std::string channelDescription;
423 template <
class U,
int N>
424 TaggedShape(TinyVector<U, N>
const & sh, PyAxisTags tags)
425 : shape(sh.begin(), sh.end()),
426 original_shape(sh.begin(), sh.end()),
432 TaggedShape(ArrayVector<T>
const & sh, PyAxisTags tags)
433 : shape(sh.begin(), sh.end()),
434 original_shape(sh.begin(), sh.end()),
439 template <
class U,
int N>
440 explicit TaggedShape(TinyVector<U, N>
const & sh)
441 : shape(sh.begin(), sh.end()),
442 original_shape(sh.begin(), sh.end()),
447 explicit TaggedShape(ArrayVector<T>
const & sh)
448 : shape(sh.begin(), sh.end()),
449 original_shape(sh.begin(), sh.end()),
453 template <
class U,
int N>
454 TaggedShape & resize(TinyVector<U, N>
const & sh)
456 int start = channelAxis == first
459 stop = channelAxis == last
463 vigra_precondition(N == stop - start || size() == 0,
464 "TaggedShape.resize(): size mismatch.");
469 for(
int k=0; k<N; ++k)
470 shape[k+start] = sh[k];
477 return resize(TinyVector<MultiArrayIndex, 1>(v1));
482 return resize(TinyVector<MultiArrayIndex, 2>(v1, v2));
487 return resize(TinyVector<MultiArrayIndex, 3>(v1, v2, v3));
493 return resize(TinyVector<MultiArrayIndex, 4>(v1, v2, v3, v4));
496 npy_intp & operator[](
int i)
501 npy_intp operator[](
int i)
const
506 unsigned int size()
const
511 TaggedShape & operator+=(
int v)
513 int start = channelAxis == first
516 stop = channelAxis == last
519 for(
int k=start; k<stop; ++k)
525 TaggedShape & operator-=(
int v)
527 return operator+=(-v);
530 TaggedShape & operator*=(
int factor)
532 int start = channelAxis == first
535 stop = channelAxis == last
538 for(
int k=start; k<stop; ++k)
544 void rotateToNormalOrder()
546 if(axistags && channelAxis == last)
548 int ndim = (int)size();
550 npy_intp channelCount = shape[ndim-1];
551 for(
int k=ndim-1; k>0; --k)
552 shape[k] = shape[k-1];
553 shape[0] = channelCount;
555 channelCount = original_shape[ndim-1];
556 for(
int k=ndim-1; k>0; --k)
557 original_shape[k] = original_shape[k-1];
558 original_shape[0] = channelCount;
564 TaggedShape & setChannelDescription(std::string
const & description)
568 channelDescription = description;
572 TaggedShape & setChannelIndexLast()
580 template <
class U,
int N>
581 TaggedShape & transposeShape(TinyVector<U, N>
const & p)
585 int ntags = axistags.size();
586 ArrayVector<npy_intp> permute = axistags.permutationToNormalOrder();
588 int tstart = (axistags.channelIndex(ntags) < ntags)
591 int sstart = (channelAxis == first)
594 int ndim = ntags - tstart;
596 vigra_precondition(N == ndim,
597 "TaggedShape.transposeShape(): size mismatch.");
599 PyAxisTags newAxistags(axistags.axistags);
600 for(
int k=0; k<ndim; ++k)
602 original_shape[k+sstart] = shape[p[k]+sstart];
603 newAxistags.setResolution(permute[k+tstart], axistags.resolution(permute[p[k]+tstart]));
605 axistags = newAxistags;
609 for(
int k=0; k<N; ++k)
611 original_shape[k] = shape[p[k]];
614 shape = original_shape;
619 TaggedShape & toFrequencyDomain(
int sign = 1)
623 int ntags = axistags.size();
625 ArrayVector<npy_intp> permute = axistags.permutationToNormalOrder();
627 int tstart = (axistags.channelIndex(ntags) < ntags)
630 int sstart = (channelAxis == first)
633 int send = (channelAxis == last)
636 int size = send - sstart;
638 for(
int k=0; k<size; ++k)
640 axistags.toFrequencyDomain(permute[k+tstart], shape[k+sstart],
sign);
646 bool hasChannelAxis()
const
648 return channelAxis !=none;
651 TaggedShape & fromFrequencyDomain()
653 return toFrequencyDomain(-1);
656 bool compatible(TaggedShape
const & other)
const
658 if(channelCount() != other.channelCount())
661 int start = channelAxis == first
664 stop = channelAxis == last
667 int ostart = other.channelAxis == first
670 ostop = other.channelAxis == last
671 ? (int)other.size()-1
674 int len = stop - start;
675 if(len != ostop - ostart)
678 for(
int k=0; k<len; ++k)
679 if(shape[k+start] != other.shape[k+ostart])
684 TaggedShape & setChannelCount(
int count)
695 shape.erase(shape.begin());
696 original_shape.erase(original_shape.begin());
703 shape[size()-1] = count;
708 original_shape.pop_back();
715 shape.push_back(count);
716 original_shape.push_back(count);
724 int channelCount()
const
731 return shape[size()-1];
739void scaleAxisResolution(TaggedShape & tagged_shape)
741 if(tagged_shape.size() != tagged_shape.original_shape.size())
744 int ntags = tagged_shape.axistags.size();
748 int tstart = (tagged_shape.axistags.channelIndex(ntags) < ntags)
751 int sstart = (tagged_shape.channelAxis == TaggedShape::first)
754 int size = (int)tagged_shape.size() - sstart;
756 for(
int k=0; k<size; ++k)
759 if(tagged_shape.shape[sk] == tagged_shape.original_shape[sk])
761 double factor = (tagged_shape.original_shape[sk] - 1.0) / (tagged_shape.shape[sk] - 1.0);
762 tagged_shape.axistags.scaleResolution(permute[k+tstart], factor);
767void unifyTaggedShapeSize(TaggedShape & tagged_shape)
769 PyAxisTags axistags = tagged_shape.axistags;
772 int ndim = (int)shape.size();
773 int ntags = axistags.size();
775 long channelIndex = axistags.channelIndex();
777 if(tagged_shape.channelAxis == TaggedShape::none)
780 if(channelIndex == ntags)
784 vigra_precondition(ndim == ntags,
785 "constructArray(): size mismatch between shape and axistags.");
795 axistags.dropChannelAxis();
799 vigra_precondition(ndim == ntags,
800 "constructArray(): size mismatch between shape and axistags.");
807 if(channelIndex == ntags)
811 vigra_precondition(ndim == ntags+1,
812 "constructArray(): size mismatch between shape and axistags.");
818 shape.erase(shape.begin());
825 axistags.insertChannelAxis();
832 vigra_precondition(ndim == ntags,
833 "constructArray(): size mismatch between shape and axistags.");
841 if(tagged_shape.axistags)
843 tagged_shape.rotateToNormalOrder();
847 scaleAxisResolution(tagged_shape);
851 unifyTaggedShapeSize(tagged_shape);
853 if(tagged_shape.channelDescription !=
"")
854 tagged_shape.axistags.setChannelDescription(tagged_shape.channelDescription);
856 return tagged_shape.shape;
Definition array_vector.hxx:514
T sign(T t)
The sign function.
Definition mathutil.hxx:591
std::ptrdiff_t MultiArrayIndex
Definition multi_fwd.hxx:60