2 years ago
#40089
Jahnavi
Style loss raises TypeError: 'function' object is not subscriptable
I'm trying to implement neural style transfer using the pre-trained VGG19 model in Google Colab, and I'm getting an error when I run this section of code. The content loss prints the correct value, but I'm not sure what is wrong with the style loss.
def content_loss(target_conv4_2,content_conv4_2):
    loss=torch.mean((target_conv4_2-content_conv4_2)**2)
    return loss

style_grams = {layer : gram_matrix(style_f[layer]) for layer in style_f}

def style_loss(style_weights,target_features,style_grams):
    loss = 0
    for layer in style_weights:
        target_f = target_features[layer]
        target_gram = gram_matrix[target_f]
        style_gram = style_grams[layer]
        b,c,h,w = target_f.shape
        layer.loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        loss += layer_loss/(c*h*w)
    return loss

target = content_p.clone().requires_grad_(True).to(device)
target_f = get_features(target,vgg)
print("Content Loss: ",content_loss(target_f['conv4_2'],content_f['conv4_2']))
print("Style Loss: ",style_loss(style_weights,target_f,style_grams))
This is the error:
Content Loss: tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-101-0b8ad8e6d456> in <module>()
      3 #style_grams = {layer : gram_matrix(style_f[layer]) for layer in style_f}
      4 print("Content Loss: ",content_loss(target_frittata['conv4_2'],content_f['conv4_2']))
----> 5 print("Style Loss: ",style_loss(style_weights,target_frittata,style_grams))

<ipython-input-98-25679c9fd886> in style_loss(style_weights, target_features, style_grams)
      5     for layer in style_weights:
      6         target_f = target_features[layer]
----> 7         target_gram = gram_matrix[target_f]
      8         style_gram = style_grams[layer]
      9         b,c,h,w = target_f.shape

TypeError: 'function' object is not subscriptable
According to this answer, the error happens when two objects share the same name, but I have no idea where that is happening in my code.
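The traceback points at `target_gram = gram_matrix[target_f]`. Square brackets ask Python to index the object, and a function object does not support indexing, so the line needs parentheses to call the function: `gram_matrix(target_f)`. A name clash is not required to trigger this particular error. Once that line is fixed, the following line would fail too: `layer` is a string key such as `'conv1_1'`, so `layer.loss = ...` raises an `AttributeError`, and `layer_loss` is never defined. The sketch below fixes both lines. It assumes a common `gram_matrix` definition (batch size 1, feature map reshaped to `(c, h*w)` and multiplied by its transpose), because the question does not show that function; the layer names, weights, and random feature maps in the usage part are hypothetical stand-ins for the real VGG19 features.

```python
import torch


def gram_matrix(tensor):
    # Assumed definition (not shown in the question): a (1, c, h, w) feature
    # map is flattened to (c, h*w) and multiplied by its own transpose.
    _, c, h, w = tensor.shape
    flat = tensor.reshape(c, h * w)
    return flat @ flat.t()


def style_loss(style_weights, target_features, style_grams):
    loss = 0
    for layer in style_weights:
        target_f = target_features[layer]
        # Call the function with parentheses; gram_matrix[target_f] tries to index it.
        target_gram = gram_matrix(target_f)
        style_gram = style_grams[layer]
        _, c, h, w = target_f.shape
        # Use a local variable; layer is a string, so layer.loss = ... fails.
        layer_loss = style_weights[layer] * torch.mean((target_gram - style_gram) ** 2)
        loss += layer_loss / (c * h * w)
    return loss


# Hypothetical usage: random tensors stand in for VGG19 feature maps.
torch.manual_seed(0)
style_f = {"conv1_1": torch.rand(1, 4, 8, 8), "conv2_1": torch.rand(1, 8, 4, 4)}
style_grams = {layer: gram_matrix(style_f[layer]) for layer in style_f}
style_weights = {"conv1_1": 1.0, "conv2_1": 0.75}

# Identical features give identical Gram matrices, so the loss is zero.
same_loss = style_loss(style_weights, style_f, style_grams)

# Shifted features change the Gram matrices, so the loss is positive.
target_f = {layer: feat + 0.1 for layer, feat in style_f.items()}
diff_loss = style_loss(style_weights, target_f, style_grams)
print(same_loss.item(), diff_loss.item())
```

The per-layer division by `c*h*w` is kept exactly as in the question's code; only the two failing lines differ.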
python
numpy
tensorflow
vgg-net
torchvision