Skip to content

Commit 6f37731

Browse files
authored
Merge pull request #2393 from DerThorsten/fix/compare_shape_of_different_types
changed impl computed_assign st shapes of different types are supported
2 parents f3c11b2 + 74d5650 commit 6f37731

File tree

3 files changed

+157
-3
lines changed

3 files changed

+157
-3
lines changed

include/xtensor/xassign.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <algorithm>
1414
#include <type_traits>
1515
#include <utility>
16+
#include <functional>
1617

1718
#include <xtl/xcomplex.hpp>
1819
#include <xtl/xsequence.hpp>
@@ -418,16 +419,20 @@ namespace xt
418419
inline void xexpression_assigner<Tag>::computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2)
419420
{
420421
using shape_type = typename E1::shape_type;
422+
using comperator_type = std::greater<typename shape_type::value_type>;
423+
421424
using size_type = typename E1::size_type;
422425

423426
E1& de1 = e1.derived_cast();
424427
const E2& de2 = e2.derived_cast();
425428

426-
size_type dim = de2.dimension();
427-
shape_type shape = uninitialized_shape<shape_type>(dim);
429+
size_type dim2 = de2.dimension();
430+
shape_type shape = uninitialized_shape<shape_type>(dim2);
431+
428432
bool trivial_broadcast = de2.broadcast_shape(shape, true);
429433

430-
if (dim > de1.dimension() || shape > de1.shape())
434+
auto && de1_shape = de1.shape();
435+
if (dim2 > de1.dimension() || std::lexicographical_compare(shape.begin(), shape.end(), de1_shape.begin(), de1_shape.end(), comperator_type()))
431436
{
432437
typename E1::temporary_type tmp(shape);
433438
base_type::assign_data(tmp, e2, trivial_broadcast);

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ set(XTENSOR_TESTS
202202
main.cpp
203203
test_xaccumulator.cpp
204204
test_xadapt.cpp
205+
test_xassign.cpp
205206
test_xaxis_iterator.cpp
206207
test_xaxis_slice_iterator.cpp
207208
test_xbuffer_adaptor.cpp

test/test_xassign.cpp

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/***************************************************************************
2+
* Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3+
* Copyright (c) QuantStack *
4+
* *
5+
* Distributed under the terms of the BSD 3-Clause License. *
6+
* *
7+
* The full license is in the file LICENSE, distributed with this software. *
8+
****************************************************************************/
9+
10+
#include "gtest/gtest.h"
11+
#include "xtensor/xarray.hpp"
12+
#include "xtensor/xtensor.hpp"
13+
14+
#include "xtensor/xassign.hpp"
15+
#include "xtensor/xnoalias.hpp"
16+
#include "test_common.hpp"
17+
18+
#include <type_traits>
19+
#include <vector>
20+
21+
22+
// a dummy shape *not derived* from std::vector but compatible
23+
template<class T>
24+
class my_vector
25+
{
26+
private:
27+
using vector_type = std::vector<T>;
28+
public:
29+
using value_type = T;
30+
using size_type = typename vector_type::size_type;
31+
template<class U>
32+
my_vector(std::initializer_list<U> vals)
33+
: m_data(vals.begin(), vals.end())
34+
{
35+
}
36+
my_vector(const std::size_t size = 0, const T & val = T())
37+
: m_data(size, val)
38+
{
39+
}
40+
auto resize(const std::size_t size)
41+
{
42+
return m_data.resize(size);
43+
}
44+
auto size()const
45+
{
46+
return m_data.size();
47+
}
48+
auto cend()const
49+
{
50+
return m_data.cend();
51+
}
52+
auto cbegin()const
53+
{
54+
return m_data.cbegin();
55+
}
56+
auto end()
57+
{
58+
return m_data.end();
59+
}
60+
auto end()const
61+
{
62+
return m_data.end();
63+
}
64+
auto begin()
65+
{
66+
return m_data.begin();
67+
}
68+
auto begin()const
69+
{
70+
return m_data.begin();
71+
}
72+
auto empty()const
73+
{
74+
return m_data.empty();
75+
}
76+
auto & back()
77+
{
78+
return m_data.back();
79+
}
80+
const auto & back()const
81+
{
82+
return m_data.back();
83+
}
84+
auto & front()
85+
{
86+
return m_data.front();
87+
}
88+
const auto & front()const
89+
{
90+
return m_data.front();
91+
}
92+
auto & operator[](const std::size_t i)
93+
{
94+
return m_data[i];
95+
}
96+
const auto & operator[](const std::size_t i)const
97+
{
98+
return m_data[i];
99+
}
100+
private:
101+
std::vector<T> m_data;
102+
};
103+
104+
105+
namespace xt
106+
{
107+
108+
template <class T, class C_T>
109+
struct rebind_container<T, my_vector<C_T>>
110+
{
111+
using type = my_vector<T>;
112+
};
113+
114+
TEST(xassign, mix_shape_types)
115+
{
116+
{
117+
// xarray like with custom shape
118+
using my_xarray = xt::xarray_container<
119+
std::vector<int>,
120+
xt::layout_type::row_major,
121+
my_vector<std::size_t>
122+
>;
123+
124+
auto a = my_xarray::from_shape({1,3});
125+
auto b = xt::xtensor<int,2>::from_shape({2,3});
126+
xt::noalias(a) += b;
127+
EXPECT_EQ(a.dimension(), 2);
128+
EXPECT_EQ(a.shape(0), 2);
129+
EXPECT_EQ(a.shape(1), 3);
130+
}
131+
{
132+
// xarray like with custom shape
133+
using my_xarray = xt::xarray_container<
134+
std::vector<int>,
135+
xt::layout_type::row_major,
136+
my_vector<std::size_t>
137+
>;
138+
139+
auto a = my_xarray::from_shape({3});
140+
auto b = xt::xtensor<int,2>::from_shape({2,3});
141+
xt::noalias(a) += b;
142+
EXPECT_EQ(a.dimension(), 2);
143+
EXPECT_EQ(a.shape(0), 2);
144+
EXPECT_EQ(a.shape(1), 3);
145+
}
146+
147+
}
148+
}

0 commit comments

Comments
 (0)