Skip to content

Commit 54055d4

Browse files
committed
fixes to get promote_affine, zero_expand to work in hlsl
1 parent 0448a2c commit 54055d4

File tree

1 file changed

+53
-35
lines changed

1 file changed

+53
-35
lines changed

include/nbl/builtin/hlsl/math/linalg/transform.hlsl

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
#ifndef _NBL_BUILTIN_HLSL_MATH_LINALG_TRANSFORM_INCLUDED_
55
#define _NBL_BUILTIN_HLSL_MATH_LINALG_TRANSFORM_INCLUDED_
66

7-
87
#include <nbl/builtin/hlsl/mpl.hlsl>
98
#include <nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl>
109
#include <nbl/builtin/hlsl/concepts.hlsl>
1110

12-
1311
namespace nbl
1412
{
1513
namespace hlsl
@@ -26,50 +24,70 @@ namespace linalg
2624
///
2725
/// @tparam T A floating-point scalar type
2826
template <typename T>
29-
matrix<T, 3, 3> rotation_mat(T angle, vector<T, 3> const& axis)
27+
matrix<T, 3, 3> rotation_mat(T angle, const vector<T, 3> axis)
3028
{
31-
T const a = angle;
32-
T const c = cos(a);
33-
T const s = sin(a);
29+
const T a = angle;
30+
const T c = cos(a);
31+
const T s = sin(a);
32+
33+
vector<T, 3> temp = hlsl::promote<vector<T, 3> >((T(1.0) - c) * axis);
3434

35-
vector<T, 3> temp((T(1) - c) * axis);
35+
matrix<T, 3, 3> rotation;
36+
rotation[0][0] = c + temp[0] * axis[0];
37+
rotation[0][1] = temp[1] * axis[0] - s * axis[2];
38+
rotation[0][2] = temp[2] * axis[0] + s * axis[1];
3639

37-
matrix<T, 3, 3> rotation;
38-
rotation[0][0] = c + temp[0] * axis[0];
39-
rotation[0][1] = temp[1] * axis[0] - s * axis[2];
40-
rotation[0][2] = temp[2] * axis[0] + s * axis[1];
40+
rotation[1][0] = temp[0] * axis[1] + s * axis[2];
41+
rotation[1][1] = c + temp[1] * axis[1];
42+
rotation[1][2] = temp[2] * axis[1] - s * axis[0];
4143

42-
rotation[1][0] = temp[0] * axis[1] + s * axis[2];
43-
rotation[1][1] = c + temp[1] * axis[1];
44-
rotation[1][2] = temp[2] * axis[1] - s * axis[0];
44+
rotation[2][0] = temp[0] * axis[2] - s * axis[1];
45+
rotation[2][1] = temp[1] * axis[2] + s * axis[0];
46+
rotation[2][2] = c + temp[2] * axis[2];
47+
48+
return rotation;
49+
}
4550

46-
rotation[2][0] = temp[0] * axis[2] - s * axis[1];
47-
rotation[2][1] = temp[1] * axis[2] + s * axis[0];
48-
rotation[2][2] = c + temp[2] * axis[2];
51+
namespace impl
52+
{
53+
template<uint16_t MOut, uint16_t MIn, typename T>
54+
struct zero_expand_helper
55+
{
56+
static vector<T, MOut> __call(vector<T, MIn> inVec)
57+
{
58+
return vector<T, MOut>(inVec, vector<T, MOut - MIn>(0));
59+
}
60+
};
61+
template<uint16_t M, typename T>
62+
struct zero_expand_helper<M,M,T>
63+
{
64+
static vector<T, M> __call(vector<T, M> inVec)
65+
{
66+
return inVec;
67+
}
68+
};
69+
}
4970

50-
return rotation;
71+
template<uint16_t MOut, uint16_t MIn, typename T NBL_FUNC_REQUIRES(MOut >= MIn)
72+
vector<T, MOut> zero_expand(vector<T, MIn> inVec)
73+
{
74+
return impl::zero_expand_helper<MOut, MIn, T>::__call(inVec);
5175
}
5276

53-
template <uint16_t NOut, uint16_t MOut, uint16_t NIn, uint16_t MIn, typename T>
54-
requires(NOut >= NIn && MOut >= MIn)
55-
matrix <T, NOut, MOut> promote_affine(const matrix<T, NIn, MIn> inMatrix)
77+
template <uint16_t NOut, uint16_t MOut, uint16_t NIn, uint16_t MIn, typename T NBL_FUNC_REQUIRES(NOut >= NIn && MOut >= MIn)
78+
matrix<T, NOut, MOut> promote_affine(const matrix<T, NIn, MIn> inMatrix)
5679
{
57-
matrix<T, NOut, MOut> retval;
80+
matrix<T, NOut, MOut> retval;
5881

59-
using out_row_t = hlsl::vector<T, MOut>;
60-
auto expandVec = [](vector<T, MIn> inVec) -> vector<T, MOut>
61-
{
62-
if constexpr (MIn == MOut) return inVec;
63-
return vector<T, MOut>(inVec, vector<T, MOut - MIn>(0));
64-
};
82+
using out_row_t = hlsl::vector<T, MOut>;
6583

66-
for (auto row_i = 0u; row_i < NOut; row_i++)
67-
{
68-
retval[row_i] = row_i < NIn ? expandVec(inMatrix[row_i]) : promote<out_row_t>(0.0f);
69-
if ((row_i >= NIn || row_i >= MIn) && row_i < MOut) retval[row_i][row_i] = T(1);
70-
71-
}
72-
return retval;
84+
for (uint32_t row_i = 0; row_i < NOut; row_i++)
85+
{
86+
retval[row_i] = hlsl::mix(promote<out_row_t>(0.0), zero_expand<MOut, MIn>(inMatrix[row_i]), row_i < NIn);
87+
if ((row_i >= NIn || row_i >= MIn) && row_i < MOut)
88+
retval[row_i][row_i] = T(1.0);
89+
}
90+
return retval;
7391
}
7492

7593
}

0 commit comments

Comments
 (0)