@@ -129,6 +129,9 @@ MslShaderGenerator::MslShaderGenerator(TypeSystemPtr typeSystem) :
129129 _lightSamplingNodes.push_back (ShaderNode::create (nullptr , " sampleLightSource" , HwLightSamplerNode::create ()));
130130
131131 _tokenSubstitutions[HW::T_CLOSURE_DATA_CONSTRUCTOR] = " {closureType, L, V, N, P, occlusion}" ;
132+
133+ _tokenSubstitutions[HW::T_TEX_SAMPLER_SAMPLER2D] = HW::TEX_SAMPLER_SAMPLER2D_MSL;
134+ _tokenSubstitutions[HW::T_TEX_SAMPLER_SIGNATURE] = HW::TEX_SAMPLER_SIGNATURE_MSL;
132135}
133136
134137ShaderPtr MslShaderGenerator::generate (const string& name, ElementPtr element, GenContext& context) const
@@ -157,6 +160,8 @@ ShaderPtr MslShaderGenerator::generate(const string& name, ElementPtr element, G
157160 emitVertexStage (shader->getGraph (), context, vs);
158161 replaceTokens (_tokenSubstitutions, vs);
159162
163+ MetalizeGeneratedShader (vs);
164+
160165 // Emit code for pixel shader stage
161166 ShaderStage& ps = shader->getStage (Stage::PIXEL);
162167 emitPixelStage (shader->getGraph (), context, ps);
@@ -222,7 +227,7 @@ void MslShaderGenerator::MetalizeGeneratedShader(ShaderStage& shaderStage) const
222227 }
223228 else
224229 {
225- sourceCode.replace (beg, typename_end - beg, " thread " + typeName + " &" );
230+ sourceCode.replace (beg, typename_end - beg, " thread " + typeName + " &" );
226231 }
227232 }
228233 pos = sourceCode.find (keyword, pos);
@@ -236,14 +241,30 @@ void MslShaderGenerator::MetalizeGeneratedShader(ShaderStage& shaderStage) const
236241 replaceTokens[" dFdy" ] = " dfdy" ;
237242 replaceTokens[" dFdx" ] = " dfdx" ;
238243
244+ replaceTokens[" vec2" ] = " float2" ;
245+ replaceTokens[" vec3" ] = " float3" ;
246+ replaceTokens[" vec4" ] = " float4" ;
247+ replaceTokens[" ivec2" ] = " int2" ;
248+ replaceTokens[" ivec3" ] = " int3" ;
249+ replaceTokens[" ivec4" ] = " int4" ;
250+ replaceTokens[" uvec2" ] = " uint2" ;
251+ replaceTokens[" uvec3" ] = " uint3" ;
252+ replaceTokens[" uvec4" ] = " uint4" ;
253+ replaceTokens[" bvec2" ] = " bool2" ;
254+ replaceTokens[" bvec3" ] = " bool3" ;
255+ replaceTokens[" bvec4" ] = " bool4" ;
256+ replaceTokens[" mat2" ] = " float2x2" ;
257+ replaceTokens[" mat3" ] = " float3x3" ;
258+ replaceTokens[" mat4" ] = " float4x4" ;
259+
239260 auto isAllowedAfterToken = [](char ch) -> bool
240261 {
241262 return std::isspace (ch) || ch == ' (' || ch == ' )' || ch == ' ,' ;
242263 };
243264
244265 auto isAllowedBeforeToken = [](char ch) -> bool
245266 {
246- return std::isspace (ch) || ch == ' (' || ch == ' ,' || ch == ' -' ;
267+ return std::isspace (ch) || ch == ' (' || ch == ' ,' || ch == ' + ' || ch == ' -' ;
247268 };
248269
249270 for (const auto & t : replaceTokens)
@@ -295,7 +316,7 @@ void MslShaderGenerator::emitGlobalVariables(GenContext& context,
295316 {
296317 if (globalContextMembers)
297318 {
298- emitLine (" vec4 gl_FragCoord" , stage);
319+ emitLine (" float4 gl_FragCoord" , stage);
299320 }
300321 if (globalContextConstructorInit)
301322 {
@@ -593,9 +614,9 @@ void MslShaderGenerator::emitVertexStage(const ShaderGraph& graph, GenContext& c
593614 emitString (" \t GlobalContext ctx {" , stage);
594615 emitGlobalVariables (context, stage, EMIT_GLOBAL_SCOPE_CONTEXT_MEMBER_INIT, true , false );
595616 emitLine (" }" , stage, true );
596- emitLine (vertexData.getName () + " out = ctx.VertexMain()" , stage, true );
597- emitLine (" out .pos.y = -out .pos.y" , stage, true );
598- emitLine (" return out " , stage, true );
617+ emitLine (vertexData.getName () + " outVertex = ctx.VertexMain()" , stage, true );
618+ emitLine (" outVertex .pos.y = -outVertex .pos.y" , stage, true );
619+ emitLine (" return outVertex " , stage, true );
599620 }
600621 emitScopeEnd (stage);
601622 emitLineBreak (stage);
@@ -650,22 +671,6 @@ void MslShaderGenerator::emitDirectives(GenContext&, ShaderStage& stage) const
650671 emitLine (" #include <simd/simd.h>" , stage, false );
651672 emitLine (" using namespace metal;" , stage, false );
652673
653- emitLine (" #define vec2 float2" , stage, false );
654- emitLine (" #define vec3 float3" , stage, false );
655- emitLine (" #define vec4 float4" , stage, false );
656- emitLine (" #define ivec2 int2" , stage, false );
657- emitLine (" #define ivec3 int3" , stage, false );
658- emitLine (" #define ivec4 int4" , stage, false );
659- emitLine (" #define uvec2 uint2" , stage, false );
660- emitLine (" #define uvec3 uint3" , stage, false );
661- emitLine (" #define uvec4 uint4" , stage, false );
662- emitLine (" #define bvec2 bool2" , stage, false );
663- emitLine (" #define bvec3 bool3" , stage, false );
664- emitLine (" #define bvec4 bool4" , stage, false );
665- emitLine (" #define mat2 float2x2" , stage, false );
666- emitLine (" #define mat3 float3x3" , stage, false );
667- emitLine (" #define mat4 float4x4" , stage, false );
668-
669674 emitLineBreak (stage);
670675}
671676
0 commit comments