Skip to content

Commit b48cb28

Browse files
authored
Simplify OSL shader generator for surfaceshader output (AcademySoftwareFoundation#2509)
Remove custom OSL code injection in favor of creating an instance of the required node that exists in the data library.
1 parent 95073cb commit b48cb28

3 files changed

Lines changed: 122 additions & 50 deletions

File tree

source/MaterialXGenOsl/OslShaderGenerator.cpp

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -111,25 +111,9 @@ ShaderPtr OslShaderGenerator::generate(const string& name, ElementPtr element, G
111111
emitShaderInputs(stage.getInputBlock(OSL::INPUTS), stage);
112112
emitShaderInputs(stage.getUniformBlock(OSL::UNIFORMS), stage);
113113

114-
// Emit shader output
114+
// Emit shader outputs
115115
const VariableBlock& outputs = stage.getOutputBlock(OSL::OUTPUTS);
116-
const ShaderPort* singleOutput = outputs.size() == 1 ? outputs[0] : NULL;
117-
118-
const bool isSurfaceShaderOutput = singleOutput && singleOutput->getType() == Type::SURFACESHADER;
119-
120-
if (isSurfaceShaderOutput)
121-
{
122-
// Special case for having 'surfaceshader' as final output type.
123-
// This type is a struct internally (BSDF, EDF, opacity) so we must
124-
// declare this as a single closure color type in order for renderers
125-
// to understand this output.
126-
emitLine("output closure color " + singleOutput->getVariable() + " = 0", stage, false);
127-
}
128-
else
129-
{
130-
// Just emit all outputs the way they are declared.
131-
emitShaderOutputs(outputs, stage);
132-
}
116+
emitShaderOutputs(outputs, stage);
133117

134118
// End shader signature
135119
emitScopeEnd(stage);
@@ -178,29 +162,11 @@ ShaderPtr OslShaderGenerator::generate(const string& name, ElementPtr element, G
178162
}
179163
}
180164

181-
// Emit final outputs
182-
if (isSurfaceShaderOutput)
183-
{
184-
// Special case for having 'surfaceshader' as final output type.
185-
// This type is a struct internally (BSDF, EDF, opacity) so we must
186-
// convert this to a single closure color type in order for renderers
187-
// to understand this output.
188-
const ShaderGraphOutputSocket* socket = graph.getOutputSocket(0);
189-
const string result = getUpstreamResult(socket, context);
190-
emitScopeBegin(stage);
191-
emitLine("float opacity_weight = clamp(" + result + ".opacity, 0.0, 1.0)", stage);
192-
emitLine(singleOutput->getVariable() + " = (" + result + ".bsdf + " + result + ".edf) * opacity_weight + transparent() * (1.0 - opacity_weight)", stage);
193-
emitScopeEnd(stage);
194-
}
195-
else
165+
// Assign results to final outputs.
166+
for (size_t i = 0; i < outputs.size(); ++i)
196167
{
197-
// Assign results to final outputs.
198-
for (size_t i = 0; i < outputs.size(); ++i)
199-
{
200-
const ShaderGraphOutputSocket* outputSocket = graph.getOutputSocket(i);
201-
const string result = getUpstreamResult(outputSocket, context);
202-
emitLine(outputSocket->getVariable() + " = " + result, stage);
203-
}
168+
const ShaderGraphOutputSocket* outputSocket = graph.getOutputSocket(i);
169+
emitLine(outputSocket->getVariable() + " = " + getUpstreamResult(outputSocket, context), stage);
204170
}
205171

206172
// End shader body
@@ -249,6 +215,20 @@ ShaderPtr OslShaderGenerator::createShader(const string& name, ElementPtr elemen
249215
{
250216
// Create the root shader graph
251217
ShaderGraphPtr graph = ShaderGraph::create(nullptr, name, element, context);
218+
219+
// Special handling for surfaceshader type output - if we have a material
220+
// that outputs a single surfaceshader then we will implicitly add a surfacematerial
221+
// node to create the final closure color - the surfaceshader type is a struct and needs
222+
// flattening to a single closure in the surfacematerial node.
223+
const auto& outputSockets = graph->getOutputSockets();
224+
const auto* singleOutput = outputSockets.size() == 1 ? outputSockets[0] : NULL;
225+
226+
const bool isSurfaceShaderOutput = singleOutput && singleOutput->getType() == Type::SURFACESHADER;
227+
if (isSurfaceShaderOutput)
228+
{
229+
graph->inlineNodeBeforeOutput(outputSockets[0], "_surfacematerial_", "ND_surfacematerial", "surfaceshader", "out", context);
230+
}
231+
252232
ShaderPtr shader = std::make_shared<Shader>(name, graph);
253233

254234
// Create our stage.

source/MaterialXGenShader/ShaderGraph.cpp

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const string& name
629629
}
630630

631631
// Apply color and unit transforms to each input.
632-
graph->applyInputTransforms(node, newNode, context);
632+
graph->applyInputTransforms(node, newNode.get(), context);
633633

634634
// Set root for upstream dependency traversal
635635
root = node;
@@ -651,7 +651,7 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const string& name
651651
return graph;
652652
}
653653

654-
void ShaderGraph::applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNode, GenContext& context)
654+
void ShaderGraph::applyInputTransforms(ConstNodePtr node, ShaderNode* shaderNode, GenContext& context)
655655
{
656656
ColorManagementSystemPtr colorManagementSystem = context.getShaderGenerator().getColorManagementSystem();
657657
UnitSystemPtr unitSystem = context.getShaderGenerator().getUnitSystem();
@@ -701,20 +701,29 @@ void ShaderGraph::applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNo
701701
}
702702
}
703703

704-
ShaderNode* ShaderGraph::createNode(ConstNodePtr node, GenContext& context)
704+
ShaderNode* ShaderGraph::createNode(const string& name, ConstNodeDefPtr nodeDef, GenContext& context)
705705
{
706-
NodeDefPtr nodeDef = node->getNodeDef();
707706
if (!nodeDef)
708707
{
709-
throw ExceptionShaderGenError("Could not find a nodedef for node '" + node->getName() + "'");
708+
throw ExceptionShaderGenError("Could not find a nodedef for node '" + name + "'");
710709
}
711710

711+
// Create this node in the graph.
712+
ShaderNodePtr newNode = ShaderNode::create(this, name, *nodeDef, context);
713+
_nodeMap[name] = newNode;
714+
_nodeOrder.push_back(newNode.get());
715+
716+
return newNode.get();
717+
}
718+
719+
ShaderNode* ShaderGraph::createNode(ConstNodePtr node, GenContext& context)
720+
{
721+
ConstNodeDefPtr nodeDef = node->getNodeDef();
722+
712723
// Create this node in the graph.
713724
context.pushParentNode(node);
714-
ShaderNodePtr newNode = ShaderNode::create(this, node->getName(), *nodeDef, context);
725+
ShaderNode* newNode = createNode(node->getName(), nodeDef, context);
715726
newNode->initialize(*node, *nodeDef, context);
716-
_nodeMap[node->getName()] = newNode;
717-
_nodeOrder.push_back(newNode.get());
718727
context.popParentNode();
719728

720729
// Check if any of the node inputs should be connected to the graph interface
@@ -757,7 +766,78 @@ ShaderNode* ShaderGraph::createNode(ConstNodePtr node, GenContext& context)
757766
// Apply color and unit transforms to each input.
758767
applyInputTransforms(node, newNode, context);
759768

760-
return newNode.get();
769+
return newNode;
770+
}
771+
772+
// Insert a new node between the output of the graph and its upstream connection, reconnecting the upstream to the specified input on
773+
// the new node, if present.
774+
ShaderNode* ShaderGraph::inlineNodeBeforeOutput(ShaderGraphOutputSocket* output, const std::string& newNodeName, const std::string& nodeDefName, const std::string& inputName, const std::string& outputName, GenContext& context)
775+
{
776+
auto nodeDef = _document->getNodeDef(nodeDefName);
777+
if (!nodeDef)
778+
{
779+
throw ExceptionShaderGenError("Cannot find NodeDef '"+nodeDefName+"' when inserting node '"+newNodeName+"'");
780+
}
781+
782+
// Check to see if the nodedef has the specified input/output ports
783+
OutputPtr nodeDefOutput = nodeDef->getOutput(outputName);
784+
if (!nodeDefOutput)
785+
{
786+
throw ExceptionShaderGenError("Output '"+outputName+"' not found on NodeDef '"+nodeDefName+"'");
787+
}
788+
789+
InputPtr nodeDefInput = nullptr;
790+
if (!inputName.empty())
791+
{
792+
// Only look for the input if we are given an inputName. It's valid to insert the node
793+
// without any upstream connection, and so an input name is not required.
794+
nodeDefInput = nodeDef->getInput(inputName);
795+
if (!nodeDefInput)
796+
{
797+
throw ExceptionShaderGenError("Input '"+inputName+"' not found on NodeDef '"+nodeDefName+"'");
798+
}
799+
}
800+
801+
// record the previously connected upstream
802+
auto originalUpstream = output->getConnection();
803+
804+
if (nodeDefInput && originalUpstream)
805+
{
806+
// if we're going to attempt to connect these - we need to check the types match
807+
// we do this before creating any new data
808+
if (nodeDefInput->getType() != originalUpstream->getType().getName())
809+
{
810+
throw ExceptionShaderGenError("Cannot connect ports of mismatched types '"+nodeDefInput->getType()+"' and '"+originalUpstream->getType().getName()+"' when inserting node");
811+
}
812+
}
813+
814+
// create the new node, and connect its output to the provided graph output
815+
auto newNode = createNode(newNodeName, nodeDef, context);
816+
if (!newNode)
817+
{
818+
throw ExceptionShaderGenError("Error while creating node '"+newNodeName+"' of type '"+nodeDefName+"'");
819+
}
820+
auto newNodeOutput = newNode->getOutput(outputName);
821+
newNodeOutput->setVariable(newNodeOutput->getFullName());
822+
output->makeConnection(newNodeOutput);
823+
824+
// update the type of the graph output port to match the new node output
825+
output->setType(newNodeOutput->getType());
826+
827+
// if there was an original upstream node connected to graph output
828+
// and we found the named input port - which means we were given a inputName (it was not empty string)
829+
// connect this to the new node at the provided input name.
830+
if (originalUpstream && nodeDefInput)
831+
{
832+
// update the variable name for the input and connect the original upstream
833+
// we already validated the types match above.
834+
auto newNodeInput = newNode->getInput(inputName);
835+
newNodeInput->setVariable(newNodeInput->getFullName());
836+
837+
originalUpstream->makeConnection(newNodeInput);
838+
}
839+
840+
return newNode;
761841
}
762842

763843
ShaderGraphInputSocket* ShaderGraph::addInputSocket(const string& name, TypeDesc type)

source/MaterialXGenShader/ShaderGraph.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,18 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode
9393
const vector<ShaderGraphOutputSocket*>& getOutputSockets() const { return _inputOrder; }
9494

9595
/// Apply color and unit transforms to each input of a node.
96-
void applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNode, GenContext& context);
96+
void applyInputTransforms(ConstNodePtr node, ShaderNode* shaderNode, GenContext& context);
9797

9898
/// Create a new node in the graph
9999
ShaderNode* createNode(ConstNodePtr node, GenContext& context);
100100

101+
ShaderNode* inlineNodeBeforeOutput(ShaderGraphOutputSocket* output,
102+
const std::string& newNodeName,
103+
const std::string& nodeDefName,
104+
const std::string& inputName,
105+
const std::string& outputName,
106+
GenContext& context);
107+
101108
/// Add input sockets
102109
ShaderGraphInputSocket* addInputSocket(const string& name, TypeDesc type);
103110
[[deprecated]] ShaderGraphInputSocket* addInputSocket(const string& name, const TypeDesc* type) { return addInputSocket(name, *type); }
@@ -129,6 +136,11 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode
129136
ElementPtr connectingElement,
130137
GenContext& context);
131138

139+
/// Create a new node in a graph from a node definition.
140+
/// Note - this does not initialize the node instance with any concrete values, but
141+
/// instead creates an empty instance of the provided node definition
142+
ShaderNode* createNode(const string& name, ConstNodeDefPtr nodeDef, GenContext& context);
143+
132144
/// Add a node to the graph
133145
void addNode(ShaderNodePtr node);
134146

0 commit comments

Comments
 (0)