@@ -92,6 +92,15 @@ const Tensor& SimpleExecutionEngine::incremental_forward(VariableIndex i) {
9292 string current_node_name; // Optionally used for debugging (reused).
9393 vector<const Tensor*> xs (16 ); // Container for arguments to nodes (reused).
9494
95+ unsigned size = 0 ;
96+ void * begin;
97+ for (unsigned j = num_nodes_evaluated; j <= i; ++j) {
98+ const Node* node = cg.nodes [j];
99+ auto rounded_n = pool_fxs->round_up_align (node->dim .size () * sizeof (float ));
100+ size += rounded_n;
101+ }
102+ begin = pool_fxs->allocate (size);
103+
95104 for (; num_nodes_evaluated <= i; ++num_nodes_evaluated) {
96105 const Node* node = cg.nodes [num_nodes_evaluated];
97106 if (autobatch_debug_flag) {
@@ -116,19 +125,19 @@ const Tensor& SimpleExecutionEngine::incremental_forward(VariableIndex i) {
116125 " SimpleExecutionEngine::incremental_forward" );
117126 node_fx.device = node->device ;
118127 node_fx.mem_pool = DeviceMempool::FXS;
119- // Get the memory to store f(xs)
120- auto & node_fx_pools = node_fx. device -> pools ;
121- node_fx. v = static_cast < float *>(
122- node_fx_pools[( int )DeviceMempool::FXS]-> allocate (
123- node-> dim . size () * sizeof ( float )));
128+ // Get the memory
129+ node_fx. v = static_cast < float *>(begin) ;
130+ auto rounded_n = pool_fxs-> round_up_align (node-> dim . size () * sizeof ( float ));
131+ begin += rounded_n;
132+
124133 if (node_fx.v == nullptr )
125134 DYNET_RUNTIME_ERR (" Ran out of memory when executing node " <<
126135 num_nodes_evaluated);
127136 void * aux_mem = nullptr ;
128137 // Is the node requesting extra memory?
129138 size_t aux_size = node->aux_storage_size ();
130139 if (aux_size) {
131- aux_mem = node_fx_pools[( int )DeviceMempool::FXS] ->allocate (aux_size);
140+ aux_mem = pool_fxs ->allocate (aux_size);
132141 if (aux_mem == nullptr )
133142 DYNET_RUNTIME_ERR (" Ran out of auxiliary memory when executing node "
134143 << num_nodes_evaluated);
@@ -161,30 +170,37 @@ void SimpleExecutionEngine::backward(VariableIndex from_where, bool full) {
161170
162171 const unsigned num_nodes = from_where + 1 ;
163172 ndEdfs.resize (num_nodes);
164- const vector<Device*> &devices = device_manager->get_devices ();
165- for (Device* device : devices)
166- device->pools [(int )DeviceMempool::DEDFS]->free ();
173+ pool_dEdfs->free ();
167174
168175 // This loop allocates memory on the appropriate devices for the nodes whose
169176 // derivatives will be computed.
177+ // This assumes all of these use the same device!
178+ unsigned size = 0 ;
179+ void * begin;
180+ for (unsigned i = 0 ; i < num_nodes; ++i) {
181+ const Node* node = cg.nodes [i];
182+ auto rounded_n = pool_dEdfs->round_up_align (node->dim .size () * sizeof (float ));
183+ size += rounded_n;
184+ }
185+ begin = pool_dEdfs->allocate (size);
186+ pool_dEdfs->zero_allocated_memory ();
187+
170188 for (unsigned i = 0 ; i < num_nodes; ++i) {
171189 const auto dim = nfxs[i].d ;
172190 auto & node_dEdfx = ndEdfs[i];
173191 node_dEdfx.d = dim;
174192 node_dEdfx.device = nfxs[i].device ;
175193 node_dEdfx.mem_pool = DeviceMempool::DEDFS;
176- node_dEdfx.v = static_cast <float *>(
177- node_dEdfx.device ->pools [(int )DeviceMempool::DEDFS]->allocate (
178- dim.size () * sizeof (float )));
194+ node_dEdfx.v = static_cast <float *>(begin);
195+ auto rounded_n = pool_dEdfs->round_up_align (dim.size () * sizeof (float ));
196+ begin += rounded_n;
197+
179198 if (node_dEdfx.v == nullptr ) {
180199 DYNET_RUNTIME_ERR (
181200 " out of memory while attempting to allocate space for "
182201 " derivatives of node " << i);
183202 }
184203 }
185- // Zero all derivative memory (which is contiguous on each device)
186- for (Device* device : devices)
187- device->pools [(int )DeviceMempool::DEDFS]->zero_allocated_memory ();
188204
189205 // initialize dE/dE = 1
190206 ndEdfs.back ().v = cg.nodes .back ()->device ->kSCALAR_ONE ;
0 commit comments