Assume you have the following data in the form of a `csv` file. The content looks something like this:

```
,Action,Comedy,Horror
1,650,819,
,76,63,
2,,462,19
,,18,96
3,652,457,18
,75,36,89
```

which can be interpreted as a table of the form

```
   Action  Comedy  Horror
1     650     819
       76      63
2             462      19
               18      96
3     652     457      18
       75      36      89
```
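To make that layout concrete, here is a minimal sketch using only the standard library. It assumes that every index label owns exactly two consecutive rows, as in the sample; `SAMPLE` and `rows_by_index` are illustrative names and not part of the code under review.

```python
import csv
import io

# In-memory copy of the sample shown above (stand-in for the real file).
SAMPLE = """\
,Action,Comedy,Horror
1,650,819,
,76,63,
2,,462,19
,,18,96
3,652,457,18
,75,36,89
"""


def rows_by_index(text):
    """Map each index label to its (first_row, second_row) pair of dicts."""
    rows = list(csv.reader(io.StringIO(text)))
    genres = rows[0][1:]  # header without the empty index cell
    body = rows[1:]
    paired = {}
    # Rows come in pairs: the first carries the index label, the second does not.
    for first, second in zip(body[0::2], body[1::2]):
        paired[first[0]] = (
            dict(zip(genres, first[1:])),
            dict(zip(genres, second[1:])),
        )
    return paired


paired = rows_by_index(SAMPLE)
for index, (first, second) in paired.items():
    print(index, second)
```

Missing values stay as empty strings here; converting them to numbers or dropping them is left out on purpose.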

The goal was to write a function that takes a list `lst` of genre names (each a `str`) and produces a scatter plot of the data. The data that should appear on the scatter plot is in the second row of every index (`76, 63,` and `, 18, 96` and `75, 36, 89`). The function should be able to distinguish between two-dimensional and three-dimensional scatter plots depending on the input. The code I wrote for this is:

```python
from csv import reader

from pandas import DataFrame
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def genre_scatter(lst):
    """
    Creates a scatter plot using the data from genre_scores.csv.
    :param lst: a list with names of the genres considered
    :return: saves a pdf-file to the folder Fig with the name gen_1_ge_2.pdf
    """
    # First we need to determine the right columns of genre_user_scores.
    first_row = [row for row in reader(open('genre_user_scores.csv', 'r'))][0]
    index = [first_row.index(x) for x in lst]

    # Get the relevant data in the form of a DataFrame.
    # Please note that the first row of data for every index is not necessary for this task.
    data = DataFrame.from_csv('genre_user_scores.csv')
    gen_scores = [data.dropna().iloc[1::2, ind - 1].transpose() for ind in index]

    # rewrite the values in a flattened array for plotting
    coordinates = [gen.as_matrix().flatten() for gen in gen_scores]

    # Plot the results
    fig = plt.figure()
    if len(coordinates) == 2:
        plt.scatter(*coordinates)
        plt.text(70, 110, "pearson={}".format(round(pearson_coeff(coordinates[0], coordinates[1]), 3)))
        plt.xlabel(lst[0])
        plt.ylabel(lst[1])
        plt.savefig("Fig/{}_{}.pdf".format(*lst))
    else:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(*coordinates)
        ax.update({'xlabel': lst[0], 'ylabel': lst[1], 'zlabel': lst[2]})
        plt.savefig("Fig/{}_{}_{}.pdf".format(*lst))
    plt.show()
    plt.close("all")


if __name__ == "__main__":
    genre_scatter(['Action', 'Horror', 'Comedy'])
```

The code works and I’m happy with the output, but a few things bug me, and I’m not sure I used them correctly.

- I’m not very familiar with list comprehensions (I think that is what expressions of the form `[x for x in list]` are called; please correct me if I’m wrong) and haven’t used them often, so I’m not sure whether they were the right approach here. My biggest concern is the first use of this kind of expression: I only need the first row of the csv file, yet I build a list of all the rows just to use the first. Is there a smarter way to do this?
- Is there a better way to label the axes? Ideally some function to which I could just pass the `*lst` argument?
- I’d like to make sure that `lst` isn’t longer than three elements (since four-dimensional plots aren’t really a thing). The only way I know to do this is `assert len(lst) <= 3`, which gets the job done, but it would be nice if it could also raise a useful error message. Any tips on how to pull that off?
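For reference, the three points above could be approached roughly as in the following sketch. It is only one possible direction, not a definitive answer. The helper names `read_header`, `check_genres` and `axis_labels` are hypothetical, the lower bound of two genres is an assumption based on the 2D/3D branches in the code, and the in-memory string stands in for `genre_user_scores.csv`.

```python
import csv
import io

AXIS_NAMES = ("xlabel", "ylabel", "zlabel")


def read_header(csv_file):
    """Return only the first row of an open CSV file, without reading the rest."""
    return next(csv.reader(csv_file))


def check_genres(lst):
    """Raise a descriptive error unless lst holds two or three genre names."""
    # Assumption: fewer than two genres is also invalid, since the plot is 2D or 3D.
    if len(lst) not in (2, 3):
        raise ValueError(
            "expected 2 or 3 genres, got {}: {!r}".format(len(lst), lst)
        )


def axis_labels(lst):
    """Pair genre names with axis-label keywords, e.g. {'xlabel': 'Action'}."""
    return dict(zip(AXIS_NAMES, lst))


# Usage with an in-memory stand-in for the real file.
header = read_header(io.StringIO(",Action,Comedy,Horror\n1,650,819,\n"))
genres = ["Action", "Horror"]
check_genres(genres)
columns = [header.index(g) for g in genres]
labels = axis_labels(genres)
# With matplotlib, the labels could then be applied in one call: ax.set(**labels)
```

Unlike `assert`, an explicit `raise ValueError(...)` is not stripped when Python runs with `-O`, and the message can name the offending input.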

Any further comments about the code are of course also welcome.

**Note:** I’m not sure if the beginner tag is appropriate here. I usually only do computational math/physics stuff, so I’m completely new to the world of DataFrames, dictionaries, list comprehension, etc.