Skip to content

Commit 55e8204

Browse files
authored
Improve Metal syntax (#2622)
The current Metal shader generator is highly derived from the Glsl shader generator. We then use a function (`MetalizeGeneratedShader`) to string process the GLSL source to convert it to legal Metal shader source. Currently there is still a lot of GLSL syntax left, and we rely on `#define` statements and the Metal compiler preprocessor to create correct syntax. This PR does the following: * Adds more string processing to `MetalizeGeneratedShader` replacing the preprocessor defines. * Correct some incorrect Syntax registrations, that were still emitting Glsl types. * Fix `mx_math.metal` to be legal Metal syntax, and not rely on the `MetalizeGeneratedShader` function. * Register token substitutions for `T_TEX_SAMPLER_SAMPLER2D` and `T_TEX_SAMPLER_SIGNATURE` specific to the required Metal syntax. Ideally all metal source files should be legal metal syntax. This work is a pre-cursor to other work to improve Metal support, and it would be helpful to get this piece merged and tested to make that work easier moving forwards.
1 parent 5d9df63 commit 55e8204

6 files changed

Lines changed: 77 additions & 68 deletions

File tree

libraries/stdlib/genmsl/lib/mx_math.metal

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,27 @@
66
#define mx_asin metal::asin
77
#define mx_acos metal::acos
88

9-
vec2 mx_matrix_mul(vec2 v, mat2 m) { return v * m; }
10-
vec3 mx_matrix_mul(vec3 v, mat3 m) { return v * m; }
11-
vec4 mx_matrix_mul(vec4 v, mat4 m) { return v * m; }
12-
vec2 mx_matrix_mul(mat2 m, vec2 v) { return m * v; }
13-
vec3 mx_matrix_mul(mat3 m, vec3 v) { return m * v; }
14-
vec4 mx_matrix_mul(mat4 m, vec4 v) { return m * v; }
15-
mat2 mx_matrix_mul(mat2 m1, mat2 m2) { return m1 * m2; }
16-
mat3 mx_matrix_mul(mat3 m1, mat3 m2) { return m1 * m2; }
17-
mat4 mx_matrix_mul(mat4 m1, mat4 m2) { return m1 * m2; }
9+
float2 mx_matrix_mul(float2 v, float2x2 m) { return v * m; }
10+
float3 mx_matrix_mul(float3 v, float3x3 m) { return v * m; }
11+
float4 mx_matrix_mul(float4 v, float4x4 m) { return v * m; }
12+
float2 mx_matrix_mul(float2x2 m, float2 v) { return m * v; }
13+
float3 mx_matrix_mul(float3x3 m, float3 v) { return m * v; }
14+
float4 mx_matrix_mul(float4x4 m, float4 v) { return m * v; }
15+
float2x2 mx_matrix_mul(float2x2 m1, float2x2 m2) { return m1 * m2; }
16+
float3x3 mx_matrix_mul(float3x3 m1, float3x3 m2) { return m1 * m2; }
17+
float4x4 mx_matrix_mul(float4x4 m1, float4x4 m2) { return m1 * m2; }
1818

1919
float mx_square(float x)
2020
{
2121
return x*x;
2222
}
2323

24-
vec2 mx_square(vec2 x)
24+
float2 mx_square(float2 x)
2525
{
2626
return x*x;
2727
}
2828

29-
vec3 mx_square(vec3 x)
29+
float3 mx_square(float3 x)
3030
{
3131
return x*x;
3232
}
@@ -118,17 +118,17 @@ float mx_atan(float y, float x)
118118
return metal::atan2(y, x);
119119
}
120120

121-
vec2 mx_atan(vec2 y, vec2 x)
121+
float2 mx_atan(float2 y, float2 x)
122122
{
123123
return metal::atan2(y, x);
124124
}
125125

126-
vec3 mx_atan(vec3 y, vec3 x)
126+
float3 mx_atan(float3 y, float3 x)
127127
{
128128
return metal::atan2(y, x);
129129
}
130130

131-
vec4 mx_atan(vec4 y, vec4 x)
131+
float4 mx_atan(float4 y, float4 x)
132132
{
133133
return metal::atan2(y, x);
134134
}
@@ -138,7 +138,7 @@ float mx_radians(float degree)
138138
return (degree * M_PI_F / 180.0f);
139139
}
140140

141-
vec2 mx_radians(vec2 degree)
141+
float2 mx_radians(float2 degree)
142142
{
143143
return (degree * M_PI_F / 180.0f);
144144
}

source/MaterialXGenHw/HwConstants.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,10 @@ const string ENV_IRRADIANCE_SAMPLER2D_SPLIT = "sampler2D(u_envIradiance_textur
126126

127127
const string TEX_SAMPLER_SAMPLER2D = "tex_sampler";
128128
const string TEX_SAMPLER_SAMPLER2D_SPLIT = "sampler2D(tex_texture, tex_sampler)";
129+
const string TEX_SAMPLER_SAMPLER2D_MSL = "tex_sampler";
129130
const string TEX_SAMPLER_SIGNATURE = "sampler2D tex_sampler";
130131
const string TEX_SAMPLER_SIGNATURE_SPLIT = "texture2D tex_texture, sampler tex_sampler";
132+
const string TEX_SAMPLER_SIGNATURE_MSL = "MetalTexture tex_sampler";
131133

132134
const string ENV_LIGHT_INTENSITY = "u_envLightIntensity";
133135
const string ENV_PREFILTER_MIP = "u_envPrefilterMip";

source/MaterialXGenHw/HwConstants.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,10 @@ extern MX_GENHW_API const string LIGHT_DATA_MAX_LIGHT_SOURCES;
213213
/// Texture sampler parameters (for both combined and separate values)
214214
extern MX_GENHW_API const string TEX_SAMPLER_SAMPLER2D;
215215
extern MX_GENHW_API const string TEX_SAMPLER_SAMPLER2D_SPLIT;
216+
extern MX_GENHW_API const string TEX_SAMPLER_SAMPLER2D_MSL;
216217
extern MX_GENHW_API const string TEX_SAMPLER_SIGNATURE;
217218
extern MX_GENHW_API const string TEX_SAMPLER_SIGNATURE_SPLIT;
219+
extern MX_GENHW_API const string TEX_SAMPLER_SIGNATURE_MSL;
218220

219221
/// Variable blocks names.
220222
extern MX_GENHW_API const string VERTEX_INPUTS; // Geometric inputs for vertex stage.

source/MaterialXGenMsl/MslShaderGenerator.cpp

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ MslShaderGenerator::MslShaderGenerator(TypeSystemPtr typeSystem) :
129129
_lightSamplingNodes.push_back(ShaderNode::create(nullptr, "sampleLightSource", HwLightSamplerNode::create()));
130130

131131
_tokenSubstitutions[HW::T_CLOSURE_DATA_CONSTRUCTOR] = "{closureType, L, V, N, P, occlusion}";
132+
133+
_tokenSubstitutions[HW::T_TEX_SAMPLER_SAMPLER2D] = HW::TEX_SAMPLER_SAMPLER2D_MSL;
134+
_tokenSubstitutions[HW::T_TEX_SAMPLER_SIGNATURE] = HW::TEX_SAMPLER_SIGNATURE_MSL;
132135
}
133136

134137
ShaderPtr MslShaderGenerator::generate(const string& name, ElementPtr element, GenContext& context) const
@@ -157,6 +160,8 @@ ShaderPtr MslShaderGenerator::generate(const string& name, ElementPtr element, G
157160
emitVertexStage(shader->getGraph(), context, vs);
158161
replaceTokens(_tokenSubstitutions, vs);
159162

163+
MetalizeGeneratedShader(vs);
164+
160165
// Emit code for pixel shader stage
161166
ShaderStage& ps = shader->getStage(Stage::PIXEL);
162167
emitPixelStage(shader->getGraph(), context, ps);
@@ -222,7 +227,7 @@ void MslShaderGenerator::MetalizeGeneratedShader(ShaderStage& shaderStage) const
222227
}
223228
else
224229
{
225-
sourceCode.replace(beg, typename_end - beg, "thread " + typeName + "&");
230+
sourceCode.replace(beg, typename_end - beg, "thread " + typeName + " &");
226231
}
227232
}
228233
pos = sourceCode.find(keyword, pos);
@@ -236,14 +241,30 @@ void MslShaderGenerator::MetalizeGeneratedShader(ShaderStage& shaderStage) const
236241
replaceTokens["dFdy"] = "dfdy";
237242
replaceTokens["dFdx"] = "dfdx";
238243

244+
replaceTokens["vec2"] = "float2";
245+
replaceTokens["vec3"] = "float3";
246+
replaceTokens["vec4"] = "float4";
247+
replaceTokens["ivec2"] = "int2";
248+
replaceTokens["ivec3"] = "int3";
249+
replaceTokens["ivec4"] = "int4";
250+
replaceTokens["uvec2"] = "uint2";
251+
replaceTokens["uvec3"] = "uint3";
252+
replaceTokens["uvec4"] = "uint4";
253+
replaceTokens["bvec2"] = "bool2";
254+
replaceTokens["bvec3"] = "bool3";
255+
replaceTokens["bvec4"] = "bool4";
256+
replaceTokens["mat2"] = "float2x2";
257+
replaceTokens["mat3"] = "float3x3";
258+
replaceTokens["mat4"] = "float4x4";
259+
239260
auto isAllowedAfterToken = [](char ch) -> bool
240261
{
241262
return std::isspace(ch) || ch == '(' || ch == ')' || ch == ',';
242263
};
243264

244265
auto isAllowedBeforeToken = [](char ch) -> bool
245266
{
246-
return std::isspace(ch) || ch == '(' || ch == ',' || ch == '-';
267+
return std::isspace(ch) || ch == '(' || ch == ',' || ch == '+' || ch == '-';
247268
};
248269

249270
for (const auto& t : replaceTokens)
@@ -295,7 +316,7 @@ void MslShaderGenerator::emitGlobalVariables(GenContext& context,
295316
{
296317
if (globalContextMembers)
297318
{
298-
emitLine("vec4 gl_FragCoord", stage);
319+
emitLine("float4 gl_FragCoord", stage);
299320
}
300321
if (globalContextConstructorInit)
301322
{
@@ -593,9 +614,9 @@ void MslShaderGenerator::emitVertexStage(const ShaderGraph& graph, GenContext& c
593614
emitString("\tGlobalContext ctx {", stage);
594615
emitGlobalVariables(context, stage, EMIT_GLOBAL_SCOPE_CONTEXT_MEMBER_INIT, true, false);
595616
emitLine("}", stage, true);
596-
emitLine(vertexData.getName() + " out = ctx.VertexMain()", stage, true);
597-
emitLine("out.pos.y = -out.pos.y", stage, true);
598-
emitLine("return out", stage, true);
617+
emitLine(vertexData.getName() + " outVertex = ctx.VertexMain()", stage, true);
618+
emitLine("outVertex.pos.y = -outVertex.pos.y", stage, true);
619+
emitLine("return outVertex", stage, true);
599620
}
600621
emitScopeEnd(stage);
601622
emitLineBreak(stage);
@@ -650,22 +671,6 @@ void MslShaderGenerator::emitDirectives(GenContext&, ShaderStage& stage) const
650671
emitLine("#include <simd/simd.h>", stage, false);
651672
emitLine("using namespace metal;", stage, false);
652673

653-
emitLine("#define vec2 float2", stage, false);
654-
emitLine("#define vec3 float3", stage, false);
655-
emitLine("#define vec4 float4", stage, false);
656-
emitLine("#define ivec2 int2", stage, false);
657-
emitLine("#define ivec3 int3", stage, false);
658-
emitLine("#define ivec4 int4", stage, false);
659-
emitLine("#define uvec2 uint2", stage, false);
660-
emitLine("#define uvec3 uint3", stage, false);
661-
emitLine("#define uvec4 uint4", stage, false);
662-
emitLine("#define bvec2 bool2", stage, false);
663-
emitLine("#define bvec3 bool3", stage, false);
664-
emitLine("#define bvec4 bool4", stage, false);
665-
emitLine("#define mat2 float2x2", stage, false);
666-
emitLine("#define mat3 float3x3", stage, false);
667-
emitLine("#define mat4 float4x4", stage, false);
668-
669674
emitLineBreak(stage);
670675
}
671676

source/MaterialXGenMsl/MslSyntax.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,9 @@ MslSyntax::MslSyntax(TypeSystemPtr typeSystem) : Syntax(typeSystem)
176176
Type::COLOR3,
177177
std::make_shared<AggregateTypeSyntax>(
178178
this,
179-
"vec3",
180-
"vec3(0.0)",
181-
"vec3(0.0)",
179+
"float3",
180+
"float3(0.0)",
181+
"float3(0.0)",
182182
EMPTY_STRING,
183183
EMPTY_STRING,
184184
VEC3_MEMBERS));
@@ -187,9 +187,9 @@ MslSyntax::MslSyntax(TypeSystemPtr typeSystem) : Syntax(typeSystem)
187187
Type::COLOR4,
188188
std::make_shared<AggregateTypeSyntax>(
189189
this,
190-
"vec4",
191-
"vec4(0.0)",
192-
"vec4(0.0)",
190+
"float4",
191+
"float4(0.0)",
192+
"float4(0.0)",
193193
EMPTY_STRING,
194194
EMPTY_STRING,
195195
VEC4_MEMBERS));
@@ -198,9 +198,9 @@ MslSyntax::MslSyntax(TypeSystemPtr typeSystem) : Syntax(typeSystem)
198198
Type::VECTOR2,
199199
std::make_shared<AggregateTypeSyntax>(
200200
this,
201-
"vec2",
202-
"vec2(0.0)",
203-
"vec2(0.0)",
201+
"float2",
202+
"float2(0.0)",
203+
"float2(0.0)",
204204
EMPTY_STRING,
205205
EMPTY_STRING,
206206
VEC2_MEMBERS));
@@ -209,9 +209,9 @@ MslSyntax::MslSyntax(TypeSystemPtr typeSystem) : Syntax(typeSystem)
209209
Type::VECTOR3,
210210
std::make_shared<AggregateTypeSyntax>(
211211
this,
212-
"vec3",
213-
"vec3(0.0)",
214-
"vec3(0.0)",
212+
"float3",
213+
"float3(0.0)",
214+
"float3(0.0)",
215215
EMPTY_STRING,
216216
EMPTY_STRING,
217217
VEC3_MEMBERS));
@@ -220,9 +220,9 @@ MslSyntax::MslSyntax(TypeSystemPtr typeSystem) : Syntax(typeSystem)
220220
Type::VECTOR4,
221221
std::make_shared<AggregateTypeSyntax>(
222222
this,
223-
"vec4",
224-
"vec4(0.0)",
225-
"vec4(0.0)",
223+
"float4",
224+
"float4(0.0)",
225+
"float4(0.0)",
226226
EMPTY_STRING,
227227
EMPTY_STRING,
228228
VEC4_MEMBERS));
@@ -231,17 +231,17 @@ MslSyntax::MslSyntax(TypeSystemPtr typeSystem) : Syntax(typeSystem)
231231
Type::MATRIX33,
232232
std::make_shared<AggregateTypeSyntax>(
233233
this,
234-
"mat3",
235-
"mat3(1.0)",
236-
"mat3(1.0)"));
234+
"float3x3",
235+
"float3x3(1.0)",
236+
"float3x3(1.0)"));
237237

238238
registerTypeSyntax(
239239
Type::MATRIX44,
240240
std::make_shared<AggregateTypeSyntax>(
241241
this,
242-
"mat4",
243-
"mat4(1.0)",
244-
"mat4(1.0)"));
242+
"float4x4",
243+
"float4x4(1.0)",
244+
"float4x4(1.0)"));
245245

246246
registerTypeSyntax(
247247
Type::STRING,

source/MaterialXTest/MaterialXGenMsl/GenMsl.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ TEST_CASE("GenShader: MSL Syntax Check", "[genmsl]")
2626
mx::SyntaxPtr syntax = mx::MslSyntax::create(ts);
2727

2828
REQUIRE(syntax->getTypeName(mx::Type::FLOAT) == "float");
29-
REQUIRE(syntax->getTypeName(mx::Type::COLOR3) == "vec3");
30-
REQUIRE(syntax->getTypeName(mx::Type::VECTOR3) == "vec3");
29+
REQUIRE(syntax->getTypeName(mx::Type::COLOR3) == "float3");
30+
REQUIRE(syntax->getTypeName(mx::Type::VECTOR3) == "float3");
3131
REQUIRE(syntax->getTypeName(mx::Type::BSDF) == "BSDF");
3232
REQUIRE(syntax->getOutputTypeName(mx::Type::BSDF) == "thread BSDF&");
3333

@@ -38,13 +38,13 @@ TEST_CASE("GenShader: MSL Syntax Check", "[genmsl]")
3838
value = syntax->getDefaultValue(mx::Type::FLOAT);
3939
REQUIRE(value == "0.0");
4040
value = syntax->getDefaultValue(mx::Type::COLOR3);
41-
REQUIRE(value == "vec3(0.0)");
41+
REQUIRE(value == "float3(0.0)");
4242
value = syntax->getDefaultValue(mx::Type::COLOR3, true);
43-
REQUIRE(value == "vec3(0.0)");
43+
REQUIRE(value == "float3(0.0)");
4444
value = syntax->getDefaultValue(mx::Type::COLOR4);
45-
REQUIRE(value == "vec4(0.0)");
45+
REQUIRE(value == "float4(0.0)");
4646
value = syntax->getDefaultValue(mx::Type::COLOR4, true);
47-
REQUIRE(value == "vec4(0.0)");
47+
REQUIRE(value == "float4(0.0)");
4848
value = syntax->getDefaultValue(mx::Type::FLOATARRAY, true);
4949
REQUIRE(value.empty());
5050
value = syntax->getDefaultValue(mx::Type::INTEGERARRAY, true);
@@ -58,15 +58,15 @@ TEST_CASE("GenShader: MSL Syntax Check", "[genmsl]")
5858

5959
mx::ValuePtr color3Value = mx::Value::createValue<mx::Color3>(mx::Color3(1.0f, 2.0f, 3.0f));
6060
value = syntax->getValue(mx::Type::COLOR3, *color3Value);
61-
REQUIRE(value == "vec3(1.0, 2.0, 3.0)");
61+
REQUIRE(value == "float3(1.0, 2.0, 3.0)");
6262
value = syntax->getValue(mx::Type::COLOR3, *color3Value, true);
63-
REQUIRE(value == "vec3(1.0, 2.0, 3.0)");
63+
REQUIRE(value == "float3(1.0, 2.0, 3.0)");
6464

6565
mx::ValuePtr color4Value = mx::Value::createValue<mx::Color4>(mx::Color4(1.0f, 2.0f, 3.0f, 4.0f));
6666
value = syntax->getValue(mx::Type::COLOR4, *color4Value);
67-
REQUIRE(value == "vec4(1.0, 2.0, 3.0, 4.0)");
67+
REQUIRE(value == "float4(1.0, 2.0, 3.0, 4.0)");
6868
value = syntax->getValue(mx::Type::COLOR4, *color4Value, true);
69-
REQUIRE(value == "vec4(1.0, 2.0, 3.0, 4.0)");
69+
REQUIRE(value == "float4(1.0, 2.0, 3.0, 4.0)");
7070

7171
std::vector<float> floatArray = { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f };
7272
mx::ValuePtr floatArrayValue = mx::Value::createValue<std::vector<float>>(floatArray);

0 commit comments

Comments
 (0)