Skip to content

Commit 0fd40b1

Browse files
authored
Consolidate ShaderGenerator::getImplementation (AcademySoftwareFoundation#2604)
While thinking about AcademySoftwareFoundation#2603 - and how to refactor to avoid storing the node definition name in the `ShaderNode`, I noticed that `getImplementation()` is very similar, and in some cases identical across different shader generators. This PR consolidates the common implementation to the base `ShaderGenerator` class, and introduces smaller specialization points where necessary. The `getImplementation()` function is left `virtual` for now, as there may be downstream generators that are overriding it, but in a future breaking change we should probably consider changing that. The GLSL and MSL implementations were actually identical, so I moved that to the HwShaderGenerator common base class.
1 parent 29acc7b commit 0fd40b1

10 files changed

Lines changed: 76 additions & 213 deletions

File tree

source/MaterialXGenGlsl/GlslShaderGenerator.cpp

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <MaterialXGenGlsl/GlslSyntax.h>
99

10+
#include <MaterialXGenShader/Shader.h>
1011
#include <MaterialXGenShader/Nodes/MaterialNode.h>
1112
#include <MaterialXGenShader/Nodes/HwImageNode.h>
1213
#include <MaterialXGenShader/Nodes/HwGeomColorNode.h>
@@ -20,7 +21,6 @@
2021
#include <MaterialXGenShader/Nodes/HwFrameNode.h>
2122
#include <MaterialXGenShader/Nodes/HwTimeNode.h>
2223
#include <MaterialXGenShader/Nodes/HwViewDirectionNode.h>
23-
#include <MaterialXGenShader/Nodes/HwLightCompoundNode.h>
2424
#include <MaterialXGenShader/Nodes/HwLightNode.h>
2525
#include <MaterialXGenShader/Nodes/HwLightSamplerNode.h>
2626
#include <MaterialXGenShader/Nodes/HwLightShaderNode.h>
@@ -692,70 +692,4 @@ void GlslShaderGenerator::emitVariableDeclaration(const ShaderPort* variable, co
692692
}
693693
}
694694

695-
ShaderNodeImplPtr GlslShaderGenerator::getImplementation(const NodeDef& nodedef, GenContext& context) const
696-
{
697-
InterfaceElementPtr implElement = nodedef.getImplementation(getTarget());
698-
if (!implElement)
699-
{
700-
return nullptr;
701-
}
702-
703-
const string& name = implElement->getName();
704-
705-
// Check if it's created and cached already.
706-
ShaderNodeImplPtr impl = context.findNodeImplementation(name);
707-
if (impl)
708-
{
709-
return impl;
710-
}
711-
712-
vector<OutputPtr> outputs = nodedef.getActiveOutputs();
713-
if (outputs.empty())
714-
{
715-
throw ExceptionShaderGenError("NodeDef '" + nodedef.getName() + "' has no outputs defined");
716-
}
717-
718-
const TypeDesc outputType = context.getTypeDesc(outputs[0]->getType());
719-
720-
if (implElement->isA<NodeGraph>())
721-
{
722-
// Use a compound implementation.
723-
if (outputType == Type::LIGHTSHADER)
724-
{
725-
impl = HwLightCompoundNode::create();
726-
}
727-
else
728-
{
729-
impl = CompoundNode::create();
730-
}
731-
}
732-
else if (implElement->isA<Implementation>())
733-
{
734-
if (getColorManagementSystem() && getColorManagementSystem()->hasImplementation(name))
735-
{
736-
impl = getColorManagementSystem()->createImplementation(name);
737-
}
738-
else
739-
{
740-
// Try creating a new in the factory.
741-
impl = _implFactory.create(name);
742-
}
743-
if (!impl)
744-
{
745-
impl = SourceCodeNode::create();
746-
}
747-
}
748-
if (!impl)
749-
{
750-
return nullptr;
751-
}
752-
753-
impl->initialize(*implElement, context);
754-
755-
// Cache it.
756-
context.addNodeImplementation(name, impl);
757-
758-
return impl;
759-
}
760-
761695
MATERIALX_NAMESPACE_END

source/MaterialXGenGlsl/GlslShaderGenerator.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,6 @@ class MX_GENGLSL_API GlslShaderGenerator : public HwShaderGenerator
4949
void emitVariableDeclaration(const ShaderPort* variable, const string& qualifier, GenContext& context, ShaderStage& stage,
5050
bool assignValue = true) const override;
5151

52-
/// Return a registered shader node implementation given an implementation element.
53-
/// The element must be an Implementation or a NodeGraph acting as implementation.
54-
ShaderNodeImplPtr getImplementation(const NodeDef& nodedef, GenContext& context) const override;
55-
5652
/// Determine the prefix of vertex data variables.
5753
string getVertexDataPrefix(const VariableBlock& vertexData) const override;
5854

source/MaterialXGenMdl/MdlShaderGenerator.cpp

Lines changed: 27 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -329,23 +329,8 @@ ShaderPtr MdlShaderGenerator::generate(const string& name, ElementPtr element, G
329329
return shader;
330330
}
331331

332-
ShaderNodeImplPtr MdlShaderGenerator::getImplementation(const NodeDef& nodedef, GenContext& context) const
332+
ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(const NodeDef& nodedef) const
333333
{
334-
InterfaceElementPtr implElement = nodedef.getImplementation(getTarget());
335-
if (!implElement)
336-
{
337-
return nullptr;
338-
}
339-
340-
const string& name = implElement->getName();
341-
342-
// Check if it's created and cached already.
343-
ShaderNodeImplPtr impl = context.findNodeImplementation(name);
344-
if (impl)
345-
{
346-
return impl;
347-
}
348-
349334
vector<OutputPtr> outputs = nodedef.getActiveOutputs();
350335
if (outputs.empty())
351336
{
@@ -354,62 +339,38 @@ ShaderNodeImplPtr MdlShaderGenerator::getImplementation(const NodeDef& nodedef,
354339

355340
const TypeDesc outputType = _typeSystem->getType(outputs[0]->getType());
356341

357-
if (implElement->isA<NodeGraph>())
342+
ShaderNodeImplPtr impl;
343+
// Use a compound implementation.
344+
if (outputType.isClosure())
358345
{
359-
// Use a compound implementation.
360-
if (outputType.isClosure())
361-
{
362-
impl = ClosureCompoundNodeMdl::create();
363-
}
364-
else
365-
{
366-
impl = CompoundNodeMdl::create();
367-
}
368-
}
369-
else if (implElement->isA<Implementation>())
370-
{
371-
if (getColorManagementSystem() && getColorManagementSystem()->hasImplementation(name))
372-
{
373-
impl = getColorManagementSystem()->createImplementation(name);
374-
}
375-
else
376-
{
377-
// Try creating a new in the factory.
378-
impl = _implFactory.create(name);
379-
}
380-
if (!impl)
381-
{
382-
// When `file` and `function` are provided we consider this node a user node
383-
const string file = implElement->getTypedAttribute<string>("file");
384-
const string function = implElement->getTypedAttribute<string>("function");
385-
// Or, if `sourcecode` is provided we consider this node a user node with inline implementation
386-
// inline implementations are not supposed to have replacement markers
387-
const string sourcecode = implElement->getTypedAttribute<string>("sourcecode");
388-
if ((!file.empty() && !function.empty()) || (!sourcecode.empty() && sourcecode.find("{{") == string::npos))
389-
{
390-
impl = CustomCodeNodeMdl::create();
391-
}
392-
else if (file.empty() && sourcecode.empty())
393-
{
394-
throw ExceptionShaderGenError("No valid MDL implementation found for '" + name + "'");
395-
}
396-
else
397-
{
398-
impl = SourceCodeNodeMdl::create();
399-
}
400-
}
346+
return ClosureCompoundNodeMdl::create();
401347
}
402-
if (!impl)
348+
return CompoundNodeMdl::create();
349+
}
350+
351+
ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForImplementation(const NodeDef& nodedef) const
352+
{
353+
InterfaceElementPtr implElement = nodedef.getImplementation(getTarget());
354+
if (!implElement)
403355
{
404356
return nullptr;
405357
}
406358

407-
impl->initialize(*implElement, context);
408-
409-
// Cache it.
410-
context.addNodeImplementation(name, impl);
411-
412-
return impl;
359+
// When `file` and `function` are provided we consider this node a user node
360+
const string file = implElement->getTypedAttribute<string>("file");
361+
const string function = implElement->getTypedAttribute<string>("function");
362+
// Or, if `sourcecode` is provided we consider this node a user node with inline implementation
363+
// inline implementations are not supposed to have replacement markers
364+
const string sourcecode = implElement->getTypedAttribute<string>("sourcecode");
365+
if ((!file.empty() && !function.empty()) || (!sourcecode.empty() && sourcecode.find("{{") == string::npos))
366+
{
367+
return CustomCodeNodeMdl::create();
368+
}
369+
if (file.empty() && sourcecode.empty())
370+
{
371+
throw ExceptionShaderGenError("No valid MDL implementation found for '" + implElement->getName() + "'");
372+
}
373+
return SourceCodeNodeMdl::create();
413374
}
414375

415376
string MdlShaderGenerator::getUpstreamResult(const ShaderInput* input, GenContext& context) const

source/MaterialXGenMdl/MdlShaderGenerator.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,11 @@ class MX_GENMDL_API MdlShaderGenerator : public ShaderGenerator
7474
/// the element and all dependencies upstream into shader code.
7575
ShaderPtr generate(const string& name, ElementPtr element, GenContext& context) const override;
7676

77-
/// Return a registered shader node implementation given an implementation element.
78-
/// The element must be an Implementation or a NodeGraph acting as implementation.
79-
ShaderNodeImplPtr getImplementation(const NodeDef& nodedef, GenContext& context) const override;
77+
/// Create the shader node implementation for a nodedef that has a NodeGraph implementation.
78+
ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeDef& nodedef) const override;
79+
80+
/// Create the shader node implementation for a nodedef that has a Implementation implementation.
81+
ShaderNodeImplPtr createShaderNodeImplForImplementation(const NodeDef& nodedef) const override;
8082

8183
/// Return the result of an upstream connection or value for an input.
8284
string getUpstreamResult(const ShaderInput* input, GenContext& context) const override;

source/MaterialXGenMsl/MslShaderGenerator.cpp

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <MaterialXGenMsl/MslSyntax.h>
99

10+
#include <MaterialXGenShader/Shader.h>
1011
#include <MaterialXGenShader/Nodes/MaterialNode.h>
1112
#include <MaterialXGenShader/Nodes/HwImageNode.h>
1213
#include <MaterialXGenShader/Nodes/HwGeomColorNode.h>
@@ -20,7 +21,6 @@
2021
#include <MaterialXGenShader/Nodes/HwFrameNode.h>
2122
#include <MaterialXGenShader/Nodes/HwTimeNode.h>
2223
#include <MaterialXGenShader/Nodes/HwViewDirectionNode.h>
23-
#include <MaterialXGenShader/Nodes/HwLightCompoundNode.h>
2424
#include <MaterialXGenShader/Nodes/HwLightNode.h>
2525
#include <MaterialXGenShader/Nodes/HwLightSamplerNode.h>
2626
#include <MaterialXGenShader/Nodes/HwLightShaderNode.h>
@@ -1209,70 +1209,4 @@ void MslShaderGenerator::emitVariableDeclaration(const ShaderPort* variable, con
12091209
}
12101210
}
12111211

1212-
ShaderNodeImplPtr MslShaderGenerator::getImplementation(const NodeDef& nodedef, GenContext& context) const
1213-
{
1214-
InterfaceElementPtr implElement = nodedef.getImplementation(getTarget());
1215-
if (!implElement)
1216-
{
1217-
return nullptr;
1218-
}
1219-
1220-
const string& name = implElement->getName();
1221-
1222-
// Check if it's created and cached already.
1223-
ShaderNodeImplPtr impl = context.findNodeImplementation(name);
1224-
if (impl)
1225-
{
1226-
return impl;
1227-
}
1228-
1229-
vector<OutputPtr> outputs = nodedef.getActiveOutputs();
1230-
if (outputs.empty())
1231-
{
1232-
throw ExceptionShaderGenError("NodeDef '" + nodedef.getName() + "' has no outputs defined");
1233-
}
1234-
1235-
const TypeDesc outputType = _typeSystem->getType(outputs[0]->getType());
1236-
1237-
if (implElement->isA<NodeGraph>())
1238-
{
1239-
// Use a compound implementation.
1240-
if (outputType == Type::LIGHTSHADER)
1241-
{
1242-
impl = HwLightCompoundNode::create();
1243-
}
1244-
else
1245-
{
1246-
impl = CompoundNode::create();
1247-
}
1248-
}
1249-
else if (implElement->isA<Implementation>())
1250-
{
1251-
if (getColorManagementSystem() && getColorManagementSystem()->hasImplementation(name))
1252-
{
1253-
impl = getColorManagementSystem()->createImplementation(name);
1254-
}
1255-
else
1256-
{
1257-
// Try creating a new in the factory.
1258-
impl = _implFactory.create(name);
1259-
}
1260-
if (!impl)
1261-
{
1262-
impl = SourceCodeNode::create();
1263-
}
1264-
}
1265-
if (!impl)
1266-
{
1267-
return nullptr;
1268-
}
1269-
1270-
impl->initialize(*implElement, context);
1271-
1272-
// Cache it.
1273-
context.addNodeImplementation(name, impl);
1274-
1275-
return impl;
1276-
}
1277-
12781212
MATERIALX_NAMESPACE_END

source/MaterialXGenMsl/MslShaderGenerator.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ class MX_GENMSL_API MslShaderGenerator : public HwShaderGenerator
5252
void emitVariableDeclaration(const ShaderPort* variable, const string& qualifier, GenContext& context, ShaderStage& stage,
5353
bool assignValue = true) const override;
5454

55-
/// Return a registered shader node implementation given an implementation element.
56-
/// The element must be an Implementation or a NodeGraph acting as implementation.
57-
ShaderNodeImplPtr getImplementation(const NodeDef& nodedef, GenContext& context) const override;
58-
5955
/// Determine the prefix of vertex data variables.
6056
string getVertexDataPrefix(const VariableBlock& vertexData) const override;
6157

source/MaterialXGenShader/HwShaderGenerator.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include <MaterialXGenShader/HwShaderGenerator.h>
77
#include <MaterialXGenShader/GenContext.h>
88
#include <MaterialXGenShader/Shader.h>
9+
#include <MaterialXGenShader/Nodes/HwLightCompoundNode.h>
10+
#include <MaterialXGenShader/Nodes/CompoundNode.h>
911

1012
#include <MaterialXCore/Document.h>
1113
#include <MaterialXCore/Definition.h>
@@ -567,9 +569,28 @@ void HwShaderGenerator::addStageLightingUniforms(GenContext& context, ShaderStag
567569
}
568570
}
569571

572+
ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const NodeDef& nodedef) const
573+
{
574+
vector<OutputPtr> outputs = nodedef.getActiveOutputs();
575+
if (outputs.empty())
576+
{
577+
throw ExceptionShaderGenError("NodeDef '" + nodedef.getName() + "' has no outputs defined");
578+
}
579+
580+
const TypeDesc outputType = _typeSystem->getType(outputs[0]->getType());
581+
582+
// Use a compound implementation.
583+
if (outputType == Type::LIGHTSHADER)
584+
{
585+
return HwLightCompoundNode::create();
586+
}
587+
return CompoundNode::create();
588+
}
589+
570590
bool HwImplementation::isEditable(const ShaderInput& input) const
571591
{
572592
return IMMUTABLE_INPUTS.count(input.getName()) == 0;
573593
}
574594

595+
575596
MATERIALX_NAMESPACE_END

source/MaterialXGenShader/HwShaderGenerator.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ class MX_GENSHADER_API HwShaderGenerator : public ShaderGenerator
345345
/// Determine the prefix of vertex data variables.
346346
virtual string getVertexDataPrefix(const VariableBlock& vertexData) const = 0;
347347

348+
/// Create the shader node implementation for a nodedef that has a NodeGraph implementation.
349+
ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeDef& nodedef) const override;
350+
348351
// Note : the order must match the order defined in libraries/pbrlib/genglsl/lib/mx_closure_type.glsl
349352
// TODO : investigate build time mechanism for ensuring these stay in sync.
350353

0 commit comments

Comments
 (0)