4
4
#ifndef _NBL_BUILTIN_HLSL_MATH_LINALG_TRANSFORM_INCLUDED_
5
5
#define _NBL_BUILTIN_HLSL_MATH_LINALG_TRANSFORM_INCLUDED_
6
6
7
-
8
7
#include <nbl/builtin/hlsl/mpl.hlsl>
9
8
#include <nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl>
10
9
#include <nbl/builtin/hlsl/concepts.hlsl>
11
10
12
-
13
11
namespace nbl
14
12
{
15
13
namespace hlsl
@@ -26,50 +24,70 @@ namespace linalg
26
24
///
27
25
/// @tparam T A floating-point scalar type
28
26
template <typename T>
29
- matrix <T, 3 , 3 > rotation_mat (T angle, vector <T, 3 > const & axis)
27
+ matrix <T, 3 , 3 > rotation_mat (T angle, const vector <T, 3 > axis)
30
28
{
31
- T const a = angle;
32
- T const c = cos (a);
33
- T const s = sin (a);
29
+ const T a = angle;
30
+ const T c = cos (a);
31
+ const T s = sin (a);
32
+
33
+ vector <T, 3 > temp = hlsl::promote<vector <T, 3 > >((T (1.0 ) - c) * axis);
34
34
35
- vector <T, 3 > temp ((T (1 ) - c) * axis);
35
+ matrix <T, 3 , 3 > rotation;
36
+ rotation[0 ][0 ] = c + temp[0 ] * axis[0 ];
37
+ rotation[0 ][1 ] = temp[1 ] * axis[0 ] - s * axis[2 ];
38
+ rotation[0 ][2 ] = temp[2 ] * axis[0 ] + s * axis[1 ];
36
39
37
- matrix <T, 3 , 3 > rotation;
38
- rotation[0 ][0 ] = c + temp[0 ] * axis[0 ];
39
- rotation[0 ][1 ] = temp[1 ] * axis[0 ] - s * axis[2 ];
40
- rotation[0 ][2 ] = temp[2 ] * axis[0 ] + s * axis[1 ];
40
+ rotation[1 ][0 ] = temp[0 ] * axis[1 ] + s * axis[2 ];
41
+ rotation[1 ][1 ] = c + temp[1 ] * axis[1 ];
42
+ rotation[1 ][2 ] = temp[2 ] * axis[1 ] - s * axis[0 ];
41
43
42
- rotation[1 ][0 ] = temp[0 ] * axis[1 ] + s * axis[2 ];
43
- rotation[1 ][1 ] = c + temp[1 ] * axis[1 ];
44
- rotation[1 ][2 ] = temp[2 ] * axis[1 ] - s * axis[0 ];
44
+ rotation[2 ][0 ] = temp[0 ] * axis[2 ] - s * axis[1 ];
45
+ rotation[2 ][1 ] = temp[1 ] * axis[2 ] + s * axis[0 ];
46
+ rotation[2 ][2 ] = c + temp[2 ] * axis[2 ];
47
+
48
+ return rotation;
49
+ }
45
50
46
- rotation[2 ][0 ] = temp[0 ] * axis[2 ] - s * axis[1 ];
47
- rotation[2 ][1 ] = temp[1 ] * axis[2 ] + s * axis[0 ];
48
- rotation[2 ][2 ] = c + temp[2 ] * axis[2 ];
51
+ namespace impl
52
+ {
53
+ template<uint16_t MOut, uint16_t MIn , typename T>
54
+ struct zero_expand_helper
55
+ {
56
+ static vector <T, MOut> __call (vector <T, MIn > inVec)
57
+ {
58
+ return vector <T, MOut>(inVec, vector <T, MOut - MIn >(0 ));
59
+ }
60
+ };
61
+ template<uint16_t M, typename T>
62
+ struct zero_expand_helper<M,M,T>
63
+ {
64
+ static vector <T, M> __call (vector <T, M> inVec)
65
+ {
66
+ return inVec;
67
+ }
68
+ };
69
+ }
49
70
50
- return rotation;
71
+ template<uint16_t MOut, uint16_t MIn , typename T NBL_FUNC_REQUIRES (MOut >= MIn )
72
+ vector <T, MOut> zero_expand (vector <T, MIn > inVec)
73
+ {
74
+ return impl::zero_expand_helper<MOut, MIn , T>::__call (inVec);
51
75
}
52
76
53
- template <uint16_t NOut, uint16_t MOut, uint16_t NIn, uint16_t MIn , typename T>
54
- requires (NOut >= NIn && MOut >= MIn )
55
- matrix <T, NOut, MOut> promote_affine (const matrix <T, NIn, MIn > inMatrix)
77
+ template <uint16_t NOut, uint16_t MOut, uint16_t NIn, uint16_t MIn , typename T NBL_FUNC_REQUIRES (NOut >= NIn && MOut >= MIn )
78
+ matrix <T, NOut, MOut> promote_affine (const matrix <T, NIn, MIn > inMatrix)
56
79
{
57
- matrix <T, NOut, MOut> retval;
80
+ matrix <T, NOut, MOut> retval;
58
81
59
- using out_row_t = hlsl::vector <T, MOut>;
60
- auto expandVec = [](vector <T, MIn > inVec) -> vector <T, MOut>
61
- {
62
- if constexpr (MIn == MOut) return inVec;
63
- return vector <T, MOut>(inVec, vector <T, MOut - MIn >(0 ));
64
- };
82
+ using out_row_t = hlsl::vector <T, MOut>;
65
83
66
- for (auto row_i = 0u ; row_i < NOut; row_i++)
67
- {
68
- retval[row_i] = row_i < NIn ? expandVec (inMatrix[row_i]) : promote<out_row_t>( 0.0f );
69
- if ((row_i >= NIn || row_i >= MIn ) && row_i < MOut) retval[row_i][row_i] = T ( 1 );
70
-
71
- }
72
- return retval;
84
+ for (uint32_t row_i = 0 ; row_i < NOut; row_i++)
85
+ {
86
+ retval[row_i] = hlsl:: mix (promote<out_row_t>( 0.0 ), zero_expand<MOut, MIn > (inMatrix[row_i]), row_i < NIn );
87
+ if ((row_i >= NIn || row_i >= MIn ) && row_i < MOut)
88
+ retval[row_i][row_i] = T ( 1.0 );
89
+ }
90
+ return retval;
73
91
}
74
92
75
93
}
0 commit comments